# sglang_v0.5.2/sglang/sgl-kernel/CMakeLists.txt
#
# 535 lines
# 18 KiB
# CMake

cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)

# CMake policies.
# CMP0169 and CMP0177 were introduced after CMake 3.26 (the minimum declared
# above). Asking an older CMake to set a policy it does not know is a hard
# error, so each one is only touched when the running CMake recognizes it.
if(POLICY CMP0169)
  # OLD keeps the single-argument FetchContent_Populate(<name>) calls used
  # further down in this file available.
  cmake_policy(SET CMP0169 OLD)
endif()
if(POLICY CMP0177)
  cmake_policy(SET CMP0177 NEW)
endif()

# Project helper module (provides helper commands used later in this file).
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

# Global build behavior: colored diagnostics, verbose build command lines,
# position-independent code, and no "lib" prefix on shared libraries.
set(CMAKE_COLOR_DIAGNOSTICS ON)
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_SHARED_LIBRARY_PREFIX "")
# Python: the interpreter and the extension-module development component are
# mandatory; SKBUILD_SABI_COMPONENT (may be empty) adds the stable-ABI
# component when the build frontend requests it.
find_package(
  Python
  REQUIRED
  COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}
)

# C++: build host code as C++17 with -O3 appended to the existing flags.
set(CMAKE_CXX_STANDARD 17)
string(APPEND CMAKE_CXX_FLAGS " -O3")
# CUDA toolchain
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
# NOTE(review): CUDA_SEPARABLE_COMPILATION is documented by CMake as a target
# property; confirm that setting it as a GLOBAL property has the intended effect.
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)

# NOTE(review): CUDA_VERSION is not assigned anywhere in this file before this
# point; presumably it is provided by the toolchain or an included module.
message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")

# Print the highest known toolkit threshold that CUDA_VERSION satisfies.
# Thresholds are checked from newest to oldest and only the first hit is
# reported; nothing is printed when the version is below every threshold.
foreach(cuda_floor IN ITEMS "13.0" "12.8" "12.4" "12.1" "11.8")
  if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "${cuda_floor}")
    message("CUDA_VERSION ${CUDA_VERSION} >= ${cuda_floor}")
    break()
  endif()
endforeach()
# Torch
# Required: the extension targets below link against ${TORCH_LIBRARIES}.
find_package(Torch REQUIRED)
# Clean Torch flags.
# NOTE(review): clear_cuda_arches is not defined in this file; presumably it
# comes from cmake/utils.cmake and removes CUDA arch flags injected by Torch
# so that only the -gencode flags chosen below apply -- confirm.
clear_cuda_arches(CMAKE_FLAG)
include(FetchContent)

# Third-party source checkouts.
# Every dependency is declared first; the checkouts are then populated in the
# same order by the loop at the end of this section. Populating makes
# <name>_SOURCE_DIR available to the rest of this file. Full (non-shallow)
# clones are requested for all of them.

# cutlass
FetchContent_Declare(
  repo-cutlass
  GIT_REPOSITORY https://github.com/NVIDIA/cutlass
  GIT_TAG        a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
  GIT_SHALLOW    OFF
)

# DeepGEMM (tracks the "sgl" ref rather than a fixed commit)
FetchContent_Declare(
  repo-deepgemm
  GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
  GIT_TAG        sgl
  GIT_SHALLOW    OFF
)

# fmt
FetchContent_Declare(
  repo-fmt
  GIT_REPOSITORY https://github.com/fmtlib/fmt
  GIT_TAG        553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
  GIT_SHALLOW    OFF
)

# Triton
FetchContent_Declare(
  repo-triton
  GIT_REPOSITORY "https://github.com/triton-lang/triton"
  GIT_TAG        8f9f695ea8fde23a0c7c88e4ab256634ca27789f
  GIT_SHALLOW    OFF
)

# flashinfer
FetchContent_Declare(
  repo-flashinfer
  GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
  GIT_TAG        018b551825c8e5579206e6eb9d3229fa679202b3
  GIT_SHALLOW    OFF
)

# flash-attention (sgl-project fork, tracks the "sgl-kernel" ref)
FetchContent_Declare(
  repo-flash-attention
  GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
  GIT_TAG        sgl-kernel
  GIT_SHALLOW    OFF
)

# flash-attention origin (upstream Dao-AILab repository)
FetchContent_Declare(
  repo-flash-attention-origin
  GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
  GIT_TAG        203b9b3dba39d5d08dffb49c09aa622984dff07d
  GIT_SHALLOW    OFF
)

# mscclpp
FetchContent_Declare(
  repo-mscclpp
  GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
  GIT_TAG        51eca89d20f0cfb3764ccd764338d7b22cd486a6
  GIT_SHALLOW    OFF
)

# Download / update each declared checkout without adding it to the build.
foreach(sgl_dep IN ITEMS
    repo-cutlass
    repo-deepgemm
    repo-fmt
    repo-triton
    repo-flashinfer
    repo-flash-attention
    repo-flash-attention-origin
    repo-mscclpp)
  FetchContent_Populate(${sgl_dep})
endforeach()
# ccache
# The compiler/linker launcher is only installed when all three hold: the
# ENABLE_CCACHE option is on, a ccache executable was found, and the
# CCACHE_DIR environment variable is defined at configure time.
option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache)
if(ENABLE_CCACHE AND CCACHE_FOUND AND DEFINED ENV{CCACHE_DIR})
  message(STATUS "Building with CCACHE enabled")
  foreach(launch_rule IN ITEMS RULE_LAUNCH_COMPILE RULE_LAUNCH_LINK)
    set_property(GLOBAL PROPERTY ${launch_rule} "ccache")
  endforeach()
endif()
# Code generation for architectures older than SM90.
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
  message(STATUS "For aarch64, disable gencode below SM90 by default")
  # This is a normal variable, which takes precedence over the cache entry
  # created by option(): on aarch64 the value is OFF regardless of what was
  # passed with -D on the command line.
  set(ENABLE_BELOW_SM90 OFF)
endif()

# Directory-wide header search paths: project headers first, then the fetched
# cutlass, flashinfer and mscclpp checkouts. Being directory-scoped, these
# also apply to targets of subdirectories added later in this file.
include_directories(
  "${PROJECT_SOURCE_DIR}/include"
  "${PROJECT_SOURCE_DIR}/csrc"
  "${repo-cutlass_SOURCE_DIR}/include"
  "${repo-cutlass_SOURCE_DIR}/tools/util/include"
  "${repo-flashinfer_SOURCE_DIR}/include"
  "${repo-flashinfer_SOURCE_DIR}/csrc"
  "${repo-mscclpp_SOURCE_DIR}/include"
)
# Base nvcc options for the CUDA sources of the extension modules defined in
# this file. Later sections append to this list (compile threads, data-type
# defines, -gencode targets) before it is attached to targets through
# target_compile_options(... $<COMPILE_LANGUAGE:CUDA> ...).
# The order of entries is significant: "-Xcompiler" must stay directly in
# front of the host flag it forwards ("-fPIC").
set(SGL_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_90,code=sm_90"
"-std=c++17"
"-DFLASHINFER_ENABLE_F16"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
# A hard-coded "--threads=32" is deliberately not part of this list: combined
# with CMAKE_BUILD_PARALLEL_LEVEL it can trigger OOM on hosts with little
# memory. The thread count comes from the SGL_KERNEL_COMPILE_THREADS cache
# variable (default 32), and the resulting "--threads=<n>" flag is appended
# to this list separately.
# Suppress warnings
"-Xcompiler=-Wno-clang-format-violations"
"-Xcompiler=-Wno-conversion"
"-Xcompiler=-Wno-deprecated-declarations"
"-Xcompiler=-Wno-terminate"
"-Xcompiler=-Wfatal-errors"
"-Xcompiler=-ftemplate-backtrace-limit=1"
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
# Uncomment one of the following to debug register / local-memory usage:
# "--ptxas-options=-v"
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
)
# Number of threads handed to nvcc via --threads (user-configurable).
set(SGL_KERNEL_COMPILE_THREADS 32 CACHE STRING "Set compilation threads, default 32")

# Validate the setting: anything that is not a non-negative integer is a
# configuration error, and a value below 1 is raised to 1.
if (SGL_KERNEL_COMPILE_THREADS MATCHES "^[0-9]+$")
  if (SGL_KERNEL_COMPILE_THREADS LESS 1)
    message(STATUS "SGL_KERNEL_COMPILE_THREADS was set to a value less than 1. Using 1 instead.")
    set(SGL_KERNEL_COMPILE_THREADS 1)
  endif()
else()
  message(FATAL_ERROR "SGL_KERNEL_COMPILE_THREADS must be an integer, but was set to '${SGL_KERNEL_COMPILE_THREADS}'.")
endif()

list(APPEND SGL_KERNEL_CUDA_FLAGS "--threads=${SGL_KERNEL_COMPILE_THREADS}")
# User-facing feature switches.
option(SGL_KERNEL_ENABLE_BF16   "Enable BF16"   ON)
option(SGL_KERNEL_ENABLE_FP8    "Enable FP8"    ON)
option(SGL_KERNEL_ENABLE_FP4    "Enable FP4"    OFF)
option(SGL_KERNEL_ENABLE_FA3    "Enable FA3"    OFF)
option(SGL_KERNEL_ENABLE_SM90A  "Enable SM90A"  OFF)
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)

# BF16 switch -> flashinfer BF16 define.
if (SGL_KERNEL_ENABLE_BF16)
  list(APPEND SGL_KERNEL_CUDA_FLAGS "-DFLASHINFER_ENABLE_BF16")
endif()

# FP8 switch -> flashinfer FP8 defines (generic, E4M3 and E5M2).
if (SGL_KERNEL_ENABLE_FP8)
  list(APPEND SGL_KERNEL_CUDA_FLAGS
    "-DFLASHINFER_ENABLE_FP8"
    "-DFLASHINFER_ENABLE_FP8_E4M3"
    "-DFLASHINFER_ENABLE_FP8_E5M2")
endif()
# Pre-Hopper targets (sm_80 / sm_89), when enabled.
if (ENABLE_BELOW_SM90)
  list(APPEND SGL_KERNEL_CUDA_FLAGS
    "-gencode=arch=compute_80,code=sm_80"
    "-gencode=arch=compute_89,code=sm_89")
endif()

# SM100-family targets are added for CUDA >= 12.8 or when SGL_KERNEL_ENABLE_SM100A
# is set. When neither holds, -use_fast_math is added instead.
if (NOT ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A))
  list(APPEND SGL_KERNEL_CUDA_FLAGS "-use_fast_math")
else()
  list(APPEND SGL_KERNEL_CUDA_FLAGS
    "-gencode=arch=compute_100a,code=sm_100a"
    "-gencode=arch=compute_120a,code=sm_120a")
  # For sm_121, sm_110 and sm_101 see https://github.com/pytorch/pytorch/pull/156176
  if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
    # CUDA 13+: sm_103a / sm_110a / sm_121a, plus size-oriented compression.
    list(APPEND SGL_KERNEL_CUDA_FLAGS
      "-gencode=arch=compute_103a,code=sm_103a"
      "-gencode=arch=compute_110a,code=sm_110a"
      "-gencode=arch=compute_121a,code=sm_121a"
      "--compress-mode=size")
  else()
    # Before CUDA 13: sm_101a is used instead.
    list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_101a,code=sm_101a")
  endif()
endif()

# sm_90a: CUDA >= 12.4 or an explicit SGL_KERNEL_ENABLE_SM90A. This also turns
# the FA3 build on, overriding the SGL_KERNEL_ENABLE_FA3 option value.
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
  set(SGL_KERNEL_ENABLE_FA3 ON)
  list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_90a,code=sm_90a")
endif()

# NVFP4 define: CUDA >= 12.8 or an explicit SGL_KERNEL_ENABLE_FP4.
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
  list(APPEND SGL_KERNEL_CUDA_FLAGS "-DENABLE_NVFP4=1")
endif()
# Translation units compiled into the common_ops extension module: the
# in-tree sources under csrc/ (paths relative to this directory), followed by
# selected files taken from the fetched flashinfer and flash-attention
# checkouts.
set(SOURCES
# In-tree sources
"csrc/allreduce/custom_all_reduce.cu"
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/attention/cascade.cu"
"csrc/attention/cutlass_mla_kernel.cu"
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/copy.cu"
"csrc/elementwise/concat_mla.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/common_extension.cc"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/dsv3_router_gemm_bf16_out.cu"
"csrc/gemm/dsv3_router_gemm_entry.cu"
"csrc/gemm/dsv3_router_gemm_float_out.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
"csrc/gemm/nvfp4_expert_quant.cu"
"csrc/gemm/nvfp4_quant_entry.cu"
"csrc/gemm/nvfp4_quant_kernels.cu"
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
"csrc/gemm/nvfp4_scaled_mm_kernels.cu"
"csrc/gemm/per_tensor_quant_fp8.cu"
"csrc/gemm/per_token_group_quant_8bit.cu"
"csrc/gemm/per_token_quant_fp8.cu"
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
"csrc/gemm/qserve_w4a8_per_group_gemm.cu"
"csrc/gemm/marlin/gptq_marlin.cu"
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/mamba/causal_conv1d.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/moe/nvfp4_blockwise_moe.cu"
"csrc/moe/fp8_blockwise_moe_kernel.cu"
"csrc/moe/prepare_moe_input.cu"
"csrc/memory/store.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu"
# Sources compiled directly from the fetched flashinfer checkout
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
# Sparse flash-attention sources (hdim128, bf16/fp16, sm80) and their API
# file, compiled from the fetched flash-attention checkout
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)
# common_ops: the main Python extension module, built from SOURCES.
Python_add_library(
  common_ops
  MODULE
  USE_SABI ${SKBUILD_SABI_VERSION}
  WITH_SOABI
  ${SOURCES}
)

# The nvcc flag list applies to CUDA translation units only.
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)

# Extra private header locations: two cutlass example directories and the
# flash-attention kernel sources.
target_include_directories(
  common_ops
  PRIVATE
    "${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha"
    "${repo-cutlass_SOURCE_DIR}/examples/common"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src"
)
# Match the libstdc++ dual-ABI setting of the installed PyTorch.
# The interpreter is asked for torch._C._GLIBCXX_USE_CXX11_ABI and the result
# (0 or 1) is turned into -D_GLIBCXX_USE_CXX11_ABI=<n> for both C++ and CUDA.
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
  COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
  RESULT_VARIABLE TORCH_CXX11_ABI_RESULT
  OUTPUT_VARIABLE TORCH_CXX11_ABI
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

# A failed probe (torch not importable from this interpreter, unexpected
# output, ...) used to fall through silently to the new-ABI branch. The
# fallback is kept, but it is now reported so that an ABI mismatch can be
# traced back to this step.
if(NOT TORCH_CXX11_ABI_RESULT EQUAL 0 OR NOT "${TORCH_CXX11_ABI}" MATCHES "^[01]$")
  message(WARNING
    "Could not query torch._C._GLIBCXX_USE_CXX11_ABI with ${Python3_EXECUTABLE} "
    "(exit status: '${TORCH_CXX11_ABI_RESULT}', output: '${TORCH_CXX11_ABI}'). "
    "Falling back to -D_GLIBCXX_USE_CXX11_ABI=1.")
endif()

if(TORCH_CXX11_ABI STREQUAL "0")
  message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
else()
  message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
# mscclpp: configure the fetched checkout (CUDA backend, GPU check bypassed,
# tests off) and build it as a subdirectory in its own binary directory.
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
add_subdirectory("${repo-mscclpp_SOURCE_DIR}" "${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build")

# flash attention: feature switches for the flash-attention sources that are
# compiled into common_ops.
target_compile_definitions(
  common_ops
  PRIVATE
    FLASHATTENTION_DISABLE_BACKWARD
    FLASHATTENTION_DISABLE_DROPOUT
    FLASHATTENTION_DISABLE_UNEVEN_K
)

# Link common_ops against Torch, the CUDA driver / cuBLAS libraries and the
# static mscclpp library from the subdirectory above.
target_link_libraries(
  common_ops
  PRIVATE
    ${TORCH_LIBRARIES}
    c10
    cuda
    cublas
    cublasLt
    mscclpp_static
)

# The module is installed into the sgl_kernel package directory.
install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
# ============================ Optional Install ============================= #
# FlashAttention-3 extension module (flash_ops).
# Built only when SGL_KERNEL_ENABLE_FA3 is ON. Besides the user option, that
# variable is also switched ON earlier in this file when the sm_90a target is
# enabled.
# The module always targets sm_90a; sm_80 and sm_86 are added when
# ENABLE_BELOW_SM90 is ON.
if (SGL_KERNEL_ENABLE_FA3)
# nvcc flags for flash_ops. This list is independent of SGL_KERNEL_CUDA_FLAGS:
# it always uses --use_fast_math and turns -Wconversion on.
set(SGL_FLASH_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_90a,code=sm_90a"
"-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
if (ENABLE_BELOW_SM90)
list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_86,code=sm_86"
)
# SM8x kernel instantiations; FA3_SM8X_GEN_SRCS stays unset (expands to
# nothing below) when ENABLE_BELOW_SM90 is OFF.
file(GLOB FA3_SM8X_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
endif()
# SM90 kernel instantiations are collected per data type. For each type the
# "hdimall" files are globbed first and the "hdimdiff" files are appended.
# The globs run at configure time against the fetched flash-attention checkout.
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
# FP8 source files (e4m3 instantiations)
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
# All generated instantiation sources: BF16, FP16, FP8, then SM8x.
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})
# Full source list: the in-tree extension entry point, the FA3 scheduler /
# API / combine files from the checkout, and the instantiations from above.
set(FLASH_SOURCES
"csrc/flash_extension.cc"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)
# flash_ops Python extension module; the flag list applies to CUDA sources only.
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
target_include_directories(flash_ops PRIVATE
${repo-flash-attention_SOURCE_DIR}/hopper
)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
# Installed next to common_ops in the sgl_kernel package directory.
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
# Compile-time feature switches passed to the flash-attention sources.
set(FLASH_OPS_COMPILE_DEFS
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
FLASHATTENTION_VARLEN_ONLY
)
# Without the sm_80/sm_86 targets, FLASHATTENTION_DISABLE_SM8x is defined too.
if(NOT ENABLE_BELOW_SM90)
list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
endif()
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif()
# spatial_ops: green-context support, kept in its own extension module rather
# than inside common_ops. It is built unconditionally, reuses the common CUDA
# flag list, and is installed into the sgl_kernel package directory.
set(SPATIAL_SOURCES
  "csrc/spatial/greenctx_stream.cu"
  "csrc/spatial_extension.cc"
)
Python_add_library(
  spatial_ops
  MODULE
  USE_SABI ${SKBUILD_SABI_VERSION}
  WITH_SOABI
  ${SPATIAL_SOURCES}
)
target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_link_libraries(
  spatial_ops
  PRIVATE
    ${TORCH_LIBRARIES}
    c10
    cuda
)
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
# ============================ DeepGEMM (JIT) ============================= #
# DeepGEMM's Python API is compiled into its own extension module so that it
# stays isolated from common_ops.
set(DEEPGEMM_SOURCES "${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp")
Python_add_library(
  deep_gemm_cpp
  MODULE
  USE_SABI ${SKBUILD_SABI_VERSION}
  WITH_SOABI
  ${DEEPGEMM_SOURCES}
)

# Libraries: Torch, the CUDA driver, nvrtc (for JIT compilation) and the
# static mscclpp library.
target_link_libraries(
  deep_gemm_cpp
  PRIVATE
    ${TORCH_LIBRARIES}
    c10
    cuda
    nvrtc
    mscclpp_static
)

# Header locations needed by DeepGEMM: its own headers, cutlass and fmt.
target_include_directories(
  deep_gemm_cpp
  PRIVATE
    "${repo-deepgemm_SOURCE_DIR}/deep_gemm/include"
    "${repo-cutlass_SOURCE_DIR}/include"
    "${repo-fmt_SOURCE_DIR}/include"
)

# Same CUDA compile options as common_ops.
target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)

# An empty file is generated in the build tree and installed as
# deep_gemm/__init__.py so that the install destination is a Python package.
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py" "")
install(
  FILES "${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py"
  DESTINATION deep_gemm
  RENAME __init__.py
)

# The compiled API module goes into the same package.
install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)

# Sources DeepGEMM needs at runtime for JIT compilation: its own deep_gemm
# tree, followed by the cute and cutlass header trees from the cutlass checkout.
install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/" DESTINATION deep_gemm)
foreach(cutlass_hdr_pkg IN ITEMS cute cutlass)
  install(
    DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/${cutlass_hdr_pkg}/"
    DESTINATION "deep_gemm/include/${cutlass_hdr_pkg}"
  )
endforeach()
# triton_kernels: install the Python package from the Triton checkout,
# leaving out git metadata and bytecode caches.
install(
  DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
  DESTINATION "triton_kernels"
  PATTERN ".git*" EXCLUDE
  PATTERN "__pycache__" EXCLUDE
)

# flash attention 4
# TODO: find a better install condition.
# For now the flash_attn/cute Python sources from the upstream flash-attention
# checkout are installed under the same condition that enables the SM100
# targets: SGL_KERNEL_ENABLE_SM100A is set or CUDA is at least 12.8.
if (SGL_KERNEL_ENABLE_SM100A OR "${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
  install(
    DIRECTORY "${repo-flash-attention-origin_SOURCE_DIR}/flash_attn/cute/"
    DESTINATION "flash_attn/cute"
    PATTERN ".git*" EXCLUDE
    PATTERN "__pycache__" EXCLUDE
  )
endif()