# flashinfer/jit/cpp_ext.py — vendored from sglang v0.5.2 / flashinfer v0.3.1
# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/utils/cpp_extension.py
import functools
import os
import re
import shlex
import shutil
import subprocess
import sys
import sysconfig
from pathlib import Path
from typing import List, Optional

import torch
from packaging.version import Version
from torch.utils.cpp_extension import (
    _TORCH_PATH,
    CUDA_HOME,
    _get_num_workers,
    _get_pybind11_abi_build_flags,
)

from . import env as jit_env
from ..compilation_context import CompilationContext
@functools.cache
def get_cuda_path() -> str:
    """Return the CUDA installation root (``CUDA_HOME``) when torch knows it,
    otherwise the path of the ``nvcc`` executable found on ``PATH``.

    Returns:
        ``CUDA_HOME`` if set, else the absolute path to ``nvcc``.

    Raises:
        RuntimeError: if ``CUDA_HOME`` is unset and ``nvcc`` is not on ``PATH``.
    """
    if CUDA_HOME is not None:
        return CUDA_HOME
    # Fall back to locating nvcc on PATH. shutil.which is portable and avoids
    # spawning a `which` subprocess (which does not exist on all platforms).
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        raise RuntimeError("Could not find nvcc")
    return nvcc
@functools.cache
def get_cuda_version() -> Version:
    """Parse the CUDA toolkit version out of ``nvcc --version`` output.

    Uses ``$CUDA_HOME/bin/nvcc`` when ``CUDA_HOME`` is set, otherwise
    whatever ``nvcc`` resolves to on ``PATH``.

    Returns:
        The toolkit release as a ``packaging.version.Version`` (e.g. ``12.8``).

    Raises:
        RuntimeError: if no ``release X.Y,`` marker is found in the output.
    """
    if CUDA_HOME is None:
        nvcc = "nvcc"
    else:
        # Join path components individually so the separator is correct on
        # every platform (previously hard-coded as "bin/nvcc").
        nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
    txt = subprocess.check_output([nvcc, "--version"], text=True)
    # nvcc prints e.g. "Cuda compilation tools, release 12.8, V12.8.93".
    matches = re.findall(r"release (\d+\.\d+),", txt)
    if not matches:
        raise RuntimeError(
            f"Could not parse CUDA version from nvcc --version output: {txt}"
        )
    return Version(matches[0])
def is_cuda_version_at_least(version_str: str) -> bool:
    """Return True iff the detected CUDA toolkit version is >= *version_str*."""
    return Version(version_str) <= get_cuda_version()
def _get_glibcxx_abi_build_flags() -> List[str]:
glibcxx_abi_cflags = [
"-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
]
return glibcxx_abi_cflags
def join_multiline(vs: List[str]) -> str:
    """Join *vs* into one ninja variable value: each entry on its own line,
    continued with ninja's ``$`` line-continuation marker."""
    continuation = " $\n "
    return continuation.join(vs)
def generate_ninja_build_for_op(
    name: str,
    sources: List[Path],
    extra_cflags: Optional[List[str]],
    extra_cuda_cflags: Optional[List[str]],
    extra_ldflags: Optional[List[str]],
    extra_include_dirs: Optional[List[Path]],
    needs_device_linking: bool = False,
) -> str:
    """Render a ninja build file that compiles *sources* into ``$name/$name.so``.

    Args:
        name: Extension name; used for ``TORCH_EXTENSION_NAME`` and as the
            output directory / shared-library stem.
        sources: Source files. ``.cu`` files are compiled with nvcc, all
            others with the host C++ compiler.
        extra_cflags: Extra host-compiler flags, or None.
        extra_cuda_cflags: Extra nvcc flags, or None. If any ``-gencode=``
            flags are present, they replace the globally configured
            architecture flags (non-arch global flags are kept).
        extra_ldflags: Extra linker flags, or None.
        extra_include_dirs: Extra ``-I`` include directories, or None.
        needs_device_linking: When True, link with nvcc so relocatable device
            code gets device-linked into the shared library.

    Returns:
        The complete ninja build file contents as a single string.
    """
    # Headers pulled in as -isystem: Python, torch, CUDA, plus the
    # flashinfer / CUTLASS / spdlog trees bundled with the package.
    system_includes = [
        sysconfig.get_path("include"),
        "$torch_home/include",
        "$torch_home/include/torch/csrc/api/include",
        "$cuda_home/include",
        "$cuda_home/include/cccl",
        jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
        jit_env.FLASHINFER_CSRC_DIR.resolve(),
    ]
    system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS]
    system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve())

    common_cflags = [
        "-DTORCH_EXTENSION_NAME=$name",
        "-DTORCH_API_INCLUDE_EXTENSION_H",
        # Stable-ABI limit: CPython 3.9.
        "-DPy_LIMITED_API=0x03090000",
    ]
    common_cflags += _get_pybind11_abi_build_flags()
    common_cflags += _get_glibcxx_abi_build_flags()
    if extra_include_dirs is not None:
        for extra_dir in extra_include_dirs:
            common_cflags.append(f"-I{extra_dir.resolve()}")
    for sys_dir in system_includes:
        common_cflags.append(f"-isystem {sys_dir}")

    cflags = [
        "$common_cflags",
        "-fPIC",
    ]
    if extra_cflags is not None:
        cflags += extra_cflags

    cuda_cflags: List[str] = []
    cc_env = os.environ.get("CC")
    if cc_env is not None:
        # Honor a user-selected host compiler for nvcc.
        cuda_cflags += ["-ccbin", cc_env]
    cuda_cflags += [
        "$common_cflags",
        "--compiler-options=-fPIC",
        "--expt-relaxed-constexpr",
    ]
    cuda_version = get_cuda_version()
    # enable -static-global-template-stub when cuda version >= 12.8
    if cuda_version >= Version("12.8"):
        cuda_cflags += [
            "-static-global-template-stub=false",
        ]

    cpp_ext_initial_compilation_context = CompilationContext()
    global_flags = cpp_ext_initial_compilation_context.get_nvcc_flags_list()
    if extra_cuda_cflags is not None:
        # Check if module provides architecture flags
        module_has_gencode = any(
            flag.startswith("-gencode=") for flag in extra_cuda_cflags
        )
        if module_has_gencode:
            # Use module's architecture flags, but keep global non-architecture flags
            global_non_arch_flags = [
                flag for flag in global_flags if not flag.startswith("-gencode=")
            ]
            cuda_cflags += global_non_arch_flags + extra_cuda_cflags
        else:
            # No module architecture flags, use both global and module flags
            cuda_cflags += global_flags + extra_cuda_cflags
    else:
        # No module flags, use global flags
        cuda_cflags += global_flags

    ldflags = [
        "-shared",
        "-L$torch_home/lib",
        "-L$cuda_home/lib64",
        "-lc10",
        "-lc10_cuda",
        "-ltorch_cpu",
        "-ltorch_cuda",
        "-ltorch",
        "-lcudart",
    ]
    env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS")
    if env_extra_ldflags:
        # Keep the try narrow: only shlex.split raises ValueError here (the
        # shlex import itself now lives at module scope, not inside the try).
        try:
            ldflags += shlex.split(env_extra_ldflags)
        except ValueError as e:
            print(
                f"Warning: Could not parse FLASHINFER_EXTRA_LDFLAGS with shlex: {e}. Falling back to simple split.",
                file=sys.stderr,
            )
            ldflags += env_extra_ldflags.split()
    if extra_ldflags is not None:
        ldflags += extra_ldflags

    cxx = os.environ.get("CXX", "c++")
    cuda_home = CUDA_HOME or "/usr/local/cuda"
    nvcc = os.environ.get("PYTORCH_NVCC", "$cuda_home/bin/nvcc")

    lines = [
        "ninja_required_version = 1.3",
        f"name = {name}",
        f"cuda_home = {cuda_home}",
        f"torch_home = {_TORCH_PATH}",
        f"cxx = {cxx}",
        f"nvcc = {nvcc}",
        "",
        "common_cflags = " + join_multiline(common_cflags),
        "cflags = " + join_multiline(cflags),
        "post_cflags =",
        "cuda_cflags = " + join_multiline(cuda_cflags),
        "cuda_post_cflags =",
        "ldflags = " + join_multiline(ldflags),
        "",
        "rule compile",
        " command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags",
        " depfile = $out.d",
        " deps = gcc",
        "",
        "rule cuda_compile",
        " command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags",
        " depfile = $out.d",
        " deps = gcc",
        "",
    ]
    # Add nvcc linking rule for device code
    if needs_device_linking:
        lines.extend(
            [
                "rule nvcc_link",
                " command = $nvcc -shared $in $ldflags -o $out",
                "",
            ]
        )
    else:
        lines.extend(
            [
                "rule link",
                " command = $cxx $in $ldflags -o $out",
                "",
            ]
        )

    objects = []
    for source in sources:
        is_cuda = source.suffix == ".cu"
        object_suffix = ".cuda.o" if is_cuda else ".o"
        cmd = "cuda_compile" if is_cuda else "compile"
        obj_name = source.with_suffix(object_suffix).name
        obj = f"$name/{obj_name}"
        objects.append(obj)
        lines.append(f"build {obj}: {cmd} {source.resolve()}")
    lines.append("")

    link_rule = "nvcc_link" if needs_device_linking else "link"
    lines.append(f"build $name/$name.so: {link_rule} " + " ".join(objects))
    lines.append("default $name/$name.so")
    lines.append("")

    return "\n".join(lines)
def run_ninja(workdir: Path, ninja_file: Path, verbose: bool) -> None:
    """Run ninja against *ninja_file* with *workdir* as the build directory.

    Creates *workdir* if needed. When *verbose* is False, ninja's combined
    stdout/stderr is captured and attached to the error on failure.

    Raises:
        RuntimeError: if the ninja invocation exits with a nonzero status.
    """
    workdir.mkdir(parents=True, exist_ok=True)
    workdir_str = str(workdir.resolve())
    command = ["ninja", "-v", "-C", workdir_str, "-f", str(ninja_file.resolve())]
    num_workers = _get_num_workers(verbose)
    if num_workers is not None:
        command.extend(["-j", str(num_workers)])
    # Flush our own streams so ninja's output interleaves sanely.
    sys.stdout.flush()
    sys.stderr.flush()
    try:
        subprocess.run(
            command,
            stdout=None if verbose else subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=workdir_str,
            check=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        msg = "Ninja build failed."
        if e.output:
            msg += " Ninja output:\n" + e.output
        raise RuntimeError(msg) from e