187 lines
6.0 KiB
Python
187 lines
6.0 KiB
Python
"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import ctypes
|
|
import hashlib
|
|
import os
|
|
import shutil
|
|
import time
|
|
|
|
import filelock
|
|
|
|
from .core import logger
|
|
from .env import FLASHINFER_CUBIN_DIR
|
|
|
|
# This is the storage path for the cubins, it can be replaced
|
|
# with a local path for testing.
|
|
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
|
|
"FLASHINFER_CUBINS_REPOSITORY",
|
|
"https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",
|
|
)
|
|
|
|
|
|
def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeout=30):
|
|
"""
|
|
Downloads a file from a URL or copies from a local path to a destination.
|
|
|
|
Parameters:
|
|
- source (str): The URL or local file path of the file to download.
|
|
- local_path (str): The local file path to save the downloaded/copied file.
|
|
- retries (int): Number of retry attempts for URL downloads (default: 3).
|
|
- delay (int): Delay in seconds between retries (default: 5).
|
|
- timeout (int): Timeout for the HTTP request in seconds (default: 10).
|
|
- lock_timeout (int): Timeout in seconds for the file lock (default: 30).
|
|
|
|
Returns:
|
|
- bool: True if download or copy is successful, False otherwise.
|
|
"""
|
|
|
|
import requests # type: ignore[import-untyped]
|
|
|
|
lock_path = f"{local_path}.lock" # Lock file path
|
|
lock = filelock.FileLock(lock_path, timeout=lock_timeout)
|
|
|
|
try:
|
|
with lock:
|
|
logger.info(f"Acquired lock for {local_path}")
|
|
|
|
# Handle local file copy
|
|
if os.path.exists(source):
|
|
try:
|
|
shutil.copy(source, local_path)
|
|
logger.info(f"File copied successfully: {local_path}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Failed to copy local file: {e}")
|
|
return False
|
|
|
|
# Handle URL downloads
|
|
for attempt in range(1, retries + 1):
|
|
try:
|
|
response = requests.get(source, timeout=timeout)
|
|
response.raise_for_status()
|
|
|
|
with open(local_path, "wb") as file:
|
|
file.write(response.content)
|
|
|
|
logger.info(
|
|
f"File downloaded successfully: {source} -> {local_path}"
|
|
)
|
|
return True
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
logger.warning(
|
|
f"Downloading {source}: attempt {attempt} failed: {e}"
|
|
)
|
|
|
|
if attempt < retries:
|
|
logger.info(f"Retrying in {delay} seconds...")
|
|
time.sleep(delay)
|
|
else:
|
|
logger.error("Max retries reached. Download failed.")
|
|
return False
|
|
|
|
except filelock.Timeout:
|
|
logger.error(
|
|
f"Failed to acquire lock for {local_path} within {lock_timeout} seconds."
|
|
)
|
|
return False
|
|
|
|
finally:
|
|
# Clean up the lock file
|
|
if os.path.exists(lock_path):
|
|
os.remove(lock_path)
|
|
logger.info(f"Lock file {lock_path} removed.")
|
|
|
|
|
|
def load_cubin(cubin_path, sha256) -> bytes:
|
|
"""
|
|
Load a cubin from the provide local path and
|
|
ensure that the sha256 signature matches.
|
|
|
|
Return None on failure.
|
|
"""
|
|
logger.debug(f"Loading from {cubin_path}")
|
|
try:
|
|
with open(cubin_path, mode="rb") as f:
|
|
cubin = f.read()
|
|
if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
|
|
return cubin
|
|
m = hashlib.sha256()
|
|
m.update(cubin)
|
|
actual_sha = m.hexdigest()
|
|
if sha256 == actual_sha:
|
|
return cubin
|
|
logger.warning(
|
|
f"sha256 mismatch (expected {sha256} actual {actual_sha}) for {cubin_path}"
|
|
)
|
|
except Exception:
|
|
pass
|
|
return b""
|
|
|
|
|
|
def get_cubin(name, sha256, file_extension=".cubin"):
|
|
"""
|
|
Load a cubin from the local cache directory with {name} and
|
|
ensure that the sha256 signature matches.
|
|
|
|
If the kernel does not exist in the cache, it will downloaded.
|
|
|
|
Returns:
|
|
None on failure.
|
|
"""
|
|
cubin_fname = name + file_extension
|
|
cubin_path = FLASHINFER_CUBIN_DIR / cubin_fname
|
|
cubin = load_cubin(cubin_path, sha256)
|
|
if cubin:
|
|
return cubin
|
|
# either the file does not exist or it is corrupted, we'll download a new one.
|
|
uri = FLASHINFER_CUBINS_REPOSITORY + "/" + cubin_fname
|
|
logger.info(f"Fetching cubin {name} from {uri}")
|
|
download_file(uri, cubin_path)
|
|
return load_cubin(cubin_path, sha256)
|
|
|
|
|
|
def convert_to_ctypes_char_p(data: bytes):
|
|
return ctypes.c_char_p(data)
|
|
|
|
|
|
# Keep a reference to the callback for each loaded library to prevent GC
|
|
dll_cubin_handlers = {}
|
|
|
|
|
|
def setup_cubin_loader(dll_path: str):
|
|
if dll_path in dll_cubin_handlers:
|
|
return
|
|
|
|
_LIB = ctypes.CDLL(dll_path)
|
|
|
|
# Define the correct callback type
|
|
CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)
|
|
|
|
def get_cubin_callback(name, sha256):
|
|
# Both name and sha256 are bytes (c_char_p)
|
|
cubin = get_cubin(name.decode("utf-8"), sha256.decode("utf-8"))
|
|
_LIB.FlashInferSetCurrentCubin(
|
|
convert_to_ctypes_char_p(cubin), ctypes.c_int(len(cubin))
|
|
)
|
|
|
|
# Create the callback and keep a reference to prevent GC
|
|
cb = CALLBACK_TYPE(get_cubin_callback)
|
|
dll_cubin_handlers[dll_path] = cb
|
|
|
|
_LIB.FlashInferSetCubinCallback(cb)
|