import dataclasses
import logging
import os
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional, Sequence, Union

import torch
from filelock import FileLock

from . import env as jit_env
from .cpp_ext import generate_ninja_build_for_op, run_ninja
from .utils import write_if_different
from ..compilation_context import CompilationContext

os.makedirs(jit_env.FLASHINFER_WORKSPACE_DIR, exist_ok=True)
os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True)


class FlashInferJITLogger(logging.Logger):
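    """Logger that writes to both stderr and flashinfer_jit.log in the workspace directory.

    The logging level is taken from the FLASHINFER_LOGGING_LEVEL environment
    variable (default: "info").
    """
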
    def __init__(self, name):
        super().__init__(name)
        logging_level = os.getenv("FLASHINFER_LOGGING_LEVEL", "info")
        self.setLevel(logging_level.upper())
        self.addHandler(logging.StreamHandler())
        log_path = jit_env.FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log"
        if not os.path.exists(log_path):
            # create an empty file
            with open(log_path, "w") as f:  # noqa: F841
                pass
        self.addHandler(logging.FileHandler(log_path))
        # set the format of the log for both the stream and the file handler
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - flashinfer.jit: %(message)s"
        )
        self.handlers[0].setFormatter(formatter)
        self.handlers[1].setFormatter(formatter)


logger = FlashInferJITLogger("flashinfer.jit")


def check_cuda_arch():
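    """Raise an error if no target CUDA architecture is sm75 or newer."""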
    # Check all target CUDA architectures for at least one that is sm75 or newer
    eligible = False
    for major, minor in current_compilation_context.TARGET_CUDA_ARCHS:
        if major >= 8:
            eligible = True
        elif major == 7 and minor.isdigit():
            if int(minor) >= 5:
                eligible = True

    # Raise an error only if all detected architectures are lower than sm75
    if not eligible:
        raise RuntimeError("FlashInfer requires GPUs with sm75 or higher")


def clear_cache_dir():
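    """Remove the on-disk JIT build directory if it exists."""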
    if os.path.exists(jit_env.FLASHINFER_JIT_DIR):
        import shutil

        shutil.rmtree(jit_env.FLASHINFER_JIT_DIR)


common_nvcc_flags = [
    "-DFLASHINFER_ENABLE_FP8_E8M0",
    "-DFLASHINFER_ENABLE_FP4_E2M1",
]
sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
sm110a_nvcc_flags = ["-gencode=arch=compute_110a,code=sm_110a"] + common_nvcc_flags
sm120a_nvcc_flags = ["-gencode=arch=compute_120a,code=sm_120a"] + common_nvcc_flags
sm121a_nvcc_flags = ["-gencode=arch=compute_121a,code=sm_121a"] + common_nvcc_flags
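
# Note: these per-architecture flag lists are typically passed verbatim as
# `extra_cuda_cflags` when declaring an op with gen_jit_spec, e.g.
# gen_jit_spec(..., extra_cuda_cflags=sm90a_nvcc_flags).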

current_compilation_context = CompilationContext()


@dataclasses.dataclass
class JitSpec:
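    """Build recipe for a single JIT-compiled op.

    A JitSpec records the sources and compiler flags for one op and knows
    where its generated build.ninja and compiled shared library live
    (either the JIT cache directory or a prebuilt AOT directory).
    """
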
    name: str
    sources: List[Path]
    extra_cflags: Optional[List[str]]
    extra_cuda_cflags: Optional[List[str]]
    extra_ldflags: Optional[List[str]]
    extra_include_dirs: Optional[List[Path]]
    is_class: bool = False
    needs_device_linking: bool = False

    @property
    def ninja_path(self) -> Path:
        return jit_env.FLASHINFER_JIT_DIR / self.name / "build.ninja"

    @property
    def jit_library_path(self) -> Path:
        return jit_env.FLASHINFER_JIT_DIR / self.name / f"{self.name}.so"

    def get_library_path(self) -> Path:
        if self.is_aot:
            return self.aot_path
        return self.jit_library_path

    @property
    def aot_path(self) -> Path:
        return jit_env.FLASHINFER_AOT_DIR / self.name / f"{self.name}.so"

    @property
    def is_aot(self) -> bool:
        return self.aot_path.exists()

    @property
    def lock_path(self) -> Path:
        return get_tmpdir() / f"{self.name}.lock"

    def write_ninja(self) -> None:
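        """Generate (or refresh) this op's build.ninja file in the JIT cache directory."""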
        ninja_path = self.ninja_path
        ninja_path.parent.mkdir(parents=True, exist_ok=True)
        content = generate_ninja_build_for_op(
            name=self.name,
            sources=self.sources,
            extra_cflags=self.extra_cflags,
            extra_cuda_cflags=self.extra_cuda_cflags,
            extra_ldflags=self.extra_ldflags,
            extra_include_dirs=self.extra_include_dirs,
            needs_device_linking=self.needs_device_linking,
        )
        write_if_different(ninja_path, content)

    def build(self, verbose: bool, need_lock: bool = True) -> None:
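        """Compile this op by running ninja, optionally guarded by a per-op file lock."""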
        lock = (
            FileLock(self.lock_path, thread_local=False) if need_lock else nullcontext()
        )
        with lock:
            run_ninja(jit_env.FLASHINFER_JIT_DIR, self.ninja_path, verbose)

    def load(self, so_path: Path, class_name: Optional[str] = None):
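        """Load a compiled library via torch.ops (or torch.classes when class_name is given).

        Returns the op namespace, or the Python wrapper of the registered custom class.
        """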
        load_class = class_name is not None
        loader = torch.classes if load_class else torch.ops
        loader.load_library(so_path)
        if load_class:
            cls = torch._C._get_custom_class_python_wrapper(self.name, class_name)
            return cls
        return getattr(loader, self.name)

    def build_and_load(self, class_name: Optional[str] = None):
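        """Return the loaded op, building it first unless a prebuilt AOT library exists."""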
        if self.is_aot:
            return self.load(self.aot_path, class_name)

        # Guard both build and load with the same lock to avoid race condition
        # where another process is building the library and removes the .so file.
        with FileLock(self.lock_path, thread_local=False):
            so_path = self.jit_library_path
            verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
            self.build(verbose, need_lock=False)
            result = self.load(so_path, class_name)

        return result


def gen_jit_spec(
    name: str,
    sources: Sequence[Union[str, Path]],
    extra_cflags: Optional[List[str]] = None,
    extra_cuda_cflags: Optional[List[str]] = None,
    extra_ldflags: Optional[List[str]] = None,
    extra_include_paths: Optional[List[Union[str, Path]]] = None,
    needs_device_linking: bool = False,
) -> JitSpec:
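    """Create a JitSpec for a CUDA op and write its build.ninja file.

    Default host and NVCC flags are applied first; `extra_cflags` and
    `extra_cuda_cflags` are appended after them. Setting FLASHINFER_JIT_VERBOSE=1
    compiles with debug and verbose PTXAS flags instead of -DNDEBUG.

    Example (illustrative only; the op name and source file are placeholders):

        spec = gen_jit_spec(
            "my_op",
            sources=[jit_env.FLASHINFER_CSRC_DIR / "my_op.cu"],
            extra_cuda_cflags=sm90a_nvcc_flags,
        )
        module = spec.build_and_load()
    """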
    check_cuda_arch()
    verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"

    cflags = ["-O3", "-std=c++17", "-Wno-switch-bool"]
    cuda_cflags = [
        "-O3",
        "-std=c++17",
        f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}",
        "-use_fast_math",
        "-DFLASHINFER_ENABLE_F16",
        "-DFLASHINFER_ENABLE_BF16",
        "-DFLASHINFER_ENABLE_FP8_E4M3",
        "-DFLASHINFER_ENABLE_FP8_E5M2",
    ]
    if verbose:
        cuda_cflags += [
            "-g",
            "-lineinfo",
            "--ptxas-options=-v",
            "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
            "-DCUTLASS_DEBUG_TRACE_LEVEL=2",
        ]
    else:
        # non-debug mode
        cuda_cflags += ["-DNDEBUG"]

    if extra_cflags is not None:
        cflags += extra_cflags
    if extra_cuda_cflags is not None:
        cuda_cflags += extra_cuda_cflags

    spec = JitSpec(
        name=name,
        sources=[Path(x) for x in sources],
        extra_cflags=cflags,
        extra_cuda_cflags=cuda_cflags,
        extra_ldflags=extra_ldflags,
        extra_include_dirs=(
            [Path(x) for x in extra_include_paths]
            if extra_include_paths is not None
            else None
        ),
        needs_device_linking=needs_device_linking,
    )
    spec.write_ninja()
    return spec


def get_tmpdir() -> Path:
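    """Return the tmp directory used for file locks under the JIT dir, creating it if needed."""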
    # TODO(lequn): Try /dev/shm first. This should help Lock on NFS.
    tmpdir = jit_env.FLASHINFER_JIT_DIR / "tmp"
    if not tmpdir.exists():
        tmpdir.mkdir(parents=True, exist_ok=True)
    return tmpdir


def build_jit_specs(
    specs: List[JitSpec],
    verbose: bool = False,
    skip_prebuilt: bool = True,
) -> None:
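    """Build several JitSpecs in a single ninja invocation.

    Specs whose AOT library already exists are skipped when skip_prebuilt is True.
    """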
    lines: List[str] = []
    for spec in specs:
        if skip_prebuilt and spec.aot_path.exists():
            continue
        lines.append(f"subninja {spec.ninja_path}")
    if not lines:
        return

    lines = ["ninja_required_version = 1.3"] + lines + [""]

    tmpdir = get_tmpdir()
    with FileLock(tmpdir / "flashinfer_jit.lock", thread_local=False):
        ninja_path = tmpdir / "flashinfer_jit.ninja"
        write_if_different(ninja_path, "\n".join(lines))
        run_ninja(jit_env.FLASHINFER_JIT_DIR, ninja_path, verbose)


def load_cuda_ops(
    name: str,
    sources: List[Union[str, Path]],
    extra_cflags: Optional[List[str]] = None,
    extra_cuda_cflags: Optional[List[str]] = None,
    extra_ldflags=None,
    extra_include_paths=None,
):
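    """Deprecated: build and load a CUDA op in one call; use JitSpec (via gen_jit_spec) directly."""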
    # TODO(lequn): Remove this function and use JitSpec directly.
    warnings.warn(
        "load_cuda_ops is deprecated. Use JitSpec directly.",
        DeprecationWarning,
        stacklevel=2,
    )
    spec = gen_jit_spec(
        name=name,
        sources=sources,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_ldflags=extra_ldflags,
        extra_include_paths=extra_include_paths,
    )
    return spec.build_and_load()