# sglang0.4.5.post1/sgl-kernel/setup.py
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import shutil
import sys
from pathlib import Path

import torch
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

root = Path(__file__).parent.resolve()
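
# Tag wheels as manylinux2014_x86_64 so the prebuilt binary installs on most
# x86_64 Linux distributions.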
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])


def _get_cuda_version():
    """Return the CUDA version PyTorch was built with, as a tuple like (12, 4)."""
    if torch.version.cuda:
        return tuple(map(int, torch.version.cuda.split(".")))
    return (0, 0)


def _get_device_sm():
    """Return the current GPU's compute capability as an int (e.g. 90 for sm_90), or 0 without a GPU."""
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0


def _get_version():
    """Read the package version from the `version = "..."` line in pyproject.toml."""
    with open(root / "pyproject.toml") as f:
        for line in f:
            if line.startswith("version"):
                return line.split("=")[1].strip().strip('"')


operator_namespace = "sgl_kernel"
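# Third-party sources vendored under 3rdparty/; CUSTOM_CUTLASS_SRC_DIR (env
# var) can point at an external CUTLASS checkout instead.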
cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
deepgemm = root / "3rdparty" / "deepgemm"

include_dirs = [
    root / "include",
    root / "csrc",
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    flashinfer.resolve() / "include",
    flashinfer.resolve() / "include" / "gemm",
    flashinfer.resolve() / "csrc",
    "cublas",
]


class CustomBuildPy(build_py):
    def run(self):
        self.copy_deepgemm_to_build_lib()
        self.make_jit_include_symlinks()
        build_py.run(self)

    def make_jit_include_symlinks(self):
        # Make symbolic links of third-party include directories
        build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
        os.makedirs(build_include_dir, exist_ok=True)
        third_party_include_dirs = [
            cutlass.resolve() / "include" / "cute",
            cutlass.resolve() / "include" / "cutlass",
        ]
        for d in third_party_include_dirs:
            dirname = str(d).split("/")[-1]
            src_dir = d
            dst_dir = f"{build_include_dir}/{dirname}"
            assert os.path.exists(src_dir)
            if os.path.exists(dst_dir):
                assert os.path.islink(dst_dir)
                os.unlink(dst_dir)
            os.symlink(src_dir, dst_dir, target_is_directory=True)

    def copy_deepgemm_to_build_lib(self):
        """Copy the vendored DeepGEMM package into the build lib so it ships in the wheel."""
        dst_dir = os.path.join(self.build_lib, "deep_gemm")
        os.makedirs(dst_dir, exist_ok=True)
        # Copy deepgemm/deep_gemm to the build directory
        src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm")
        # Remove existing directory if it exists
        if os.path.exists(dst_dir):
            shutil.rmtree(dst_dir)
        # Copy the directory
        shutil.copytree(src_dir, dst_dir)
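

# Default SASS targets cover sm_75 (Turing), sm_80 (Ampere), sm_89 (Ada) and
# sm_90 (Hopper); arch-specific targets (sm_90a / sm_100a) are appended
# conditionally below.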
nvcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
    "-O3",
    "-Xcompiler",
    "-fPIC",
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    "-std=c++17",
    "-DFLASHINFER_ENABLE_F16",
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
    "-DCUTLASS_VERSIONS_GENERATED",
    "-DCUTE_USE_PACKED_TUPLE=1",
    "-DCUTLASS_TEST_LEVEL=0",
    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
    "--ptxas-options=-v",
    "--expt-relaxed-constexpr",
    "-Xcompiler=-Wconversion",
    "-Xcompiler=-fno-strict-aliasing",
]
nvcc_flags_fp8 = [
    "-DFLASHINFER_ENABLE_FP8",
    "-DFLASHINFER_ENABLE_FP8_E4M3",
    "-DFLASHINFER_ENABLE_FP8_E5M2",
]
sources = [
    "csrc/allreduce/trt_reduce_internal.cu",
    "csrc/allreduce/trt_reduce_kernel.cu",
    "csrc/attention/lightning_attention_decode_kernel.cu",
    "csrc/elementwise/activation.cu",
    "csrc/elementwise/fused_add_rms_norm_kernel.cu",
    "csrc/elementwise/rope.cu",
    "csrc/gemm/bmm_fp8.cu",
    "csrc/gemm/cublas_grouped_gemm.cu",
    "csrc/gemm/awq_kernel.cu",
    "csrc/gemm/fp8_gemm_kernel.cu",
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
    "csrc/gemm/int8_gemm_kernel.cu",
    "csrc/gemm/nvfp4_quant_entry.cu",
    "csrc/gemm/nvfp4_quant_kernels.cu",
    "csrc/gemm/nvfp4_scaled_mm_entry.cu",
    "csrc/gemm/nvfp4_scaled_mm_kernels.cu",
    "csrc/gemm/per_token_group_quant_8bit.cu",
    "csrc/gemm/per_token_quant_fp8.cu",
    "csrc/gemm/per_tensor_quant_fp8.cu",
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/speculative/eagle_utils.cu",
    "csrc/speculative/speculative_sampling.cu",
    "csrc/speculative/packbit.cu",
    "csrc/torch_extension.cc",
    "3rdparty/flashinfer/csrc/norm.cu",
    "3rdparty/flashinfer/csrc/renorm.cu",
    "3rdparty/flashinfer/csrc/sampling.cu",
]

enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
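
# Example (hypothetical invocation): cross-compiling a Hopper wheel on a
# machine without a visible GPU, where the flags above take effect:
#   SGL_KERNEL_ENABLE_BF16=1 SGL_KERNEL_ENABLE_FP8=1 SGL_KERNEL_ENABLE_SM90A=1 \
#       python setup.py bdist_wheel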
cuda_version = _get_cuda_version()
sm_version = _get_device_sm()
if torch.cuda.is_available():
    if cuda_version >= (12, 0) and sm_version >= 90:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if cuda_version >= (12, 8) and sm_version >= 100:
        nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
        nvcc_flags.append("-DENABLE_NVFP4=1")
    else:
        nvcc_flags.append("-use_fast_math")
    if sm_version >= 90:
        nvcc_flags.extend(nvcc_flags_fp8)
    if sm_version >= 80:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
else:
    # compilation environment without GPU
    if enable_sm100a:
        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
    if enable_sm90a:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if enable_fp4:
        nvcc_flags.append("-DENABLE_NVFP4=1")
    if enable_fp8:
        nvcc_flags.extend(nvcc_flags_fp8)
    if enable_bf16:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
for flag in [
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]:
    try:
        torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
    except ValueError:
        pass

cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda", "cublas"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]

ext_modules = [
    CUDAExtension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "nvcc": nvcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=True,
    ),
]

setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(where="python"),
    package_dir={"": "python"},
    ext_modules=ext_modules,
    cmdclass={
        "build_ext": BuildExtension.with_options(use_ninja=True),
        "build_py": CustomBuildPy,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
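
# Typical builds (assumed workflow, not part of the file's documented interface):
#   pip install -v .              # compile and install in one step
#   python setup.py bdist_wheel   # produce the abi3 (cp39+) manylinux wheel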