# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/tllm_utils.py
#
# 48 lines
# 1.5 KiB
# Python

import functools
from .jit import env as jit_env
from .jit import gen_jit_spec
def gen_trtllm_utils_module():
    """Create the JIT spec for the ``trtllm_utils`` extension.

    Calls ``gen_jit_spec`` with the module name ``"trtllm_utils"``, the
    source files listed below (all located under
    ``jit_env.FLASHINFER_CSRC_DIR``), and the ``nv_internal`` include
    directories passed as ``extra_include_paths``.

    Returns:
        Whatever ``gen_jit_spec`` returns; ``get_trtllm_utils_module`` in
        this file calls ``build_and_load()`` on it.
    """
    csrc_dir = jit_env.FLASHINFER_CSRC_DIR

    # Source files, relative to the csrc directory.
    relative_sources = (
        "nv_internal/tensorrt_llm/kernels/delayStream.cu",
        "nv_internal/cpp/common/envUtils.cpp",
        "nv_internal/cpp/common/logger.cpp",
        "nv_internal/cpp/common/stringUtils.cpp",
        "nv_internal/cpp/common/tllmException.cpp",
    )
    source_files = [csrc_dir / rel_path for rel_path in relative_sources]

    # Shared prefixes for the include directories.
    nv_internal_dir = csrc_dir / "nv_internal"
    tensorrt_llm_dir = nv_internal_dir / "tensorrt_llm"
    cutlass_kernels_dir = tensorrt_llm_dir / "kernels" / "cutlass_kernels"

    include_dirs = [
        nv_internal_dir,
        nv_internal_dir / "include",
        tensorrt_llm_dir / "cutlass_extensions" / "include",
        cutlass_kernels_dir / "include",
        cutlass_kernels_dir,
    ]

    return gen_jit_spec(
        "trtllm_utils",
        source_files,
        extra_include_paths=include_dirs,
    )
@functools.cache
def get_trtllm_utils_module():
    """Build and load the ``trtllm_utils`` module, memoizing the result.

    The ``functools.cache`` decorator stores the first successful return
    value, so later calls reuse the already-loaded module object instead
    of calling ``build_and_load()`` again.

    Returns:
        The object returned by ``build_and_load()`` on the spec produced
        by ``gen_trtllm_utils_module()``.
    """
    jit_spec = gen_trtllm_utils_module()
    return jit_spec.build_and_load()
def delay_kernel(stream_delay_micro_secs):
    """Invoke ``delay_kernel`` on the loaded ``trtllm_utils`` module.

    Args:
        stream_delay_micro_secs: Forwarded unchanged to the module's
            ``delay_kernel``; presumably a delay in microseconds, going
            by the parameter name -- TODO confirm against the kernel.

    Returns:
        None. The result of the underlying call is not propagated.
    """
    utils_module = get_trtllm_utils_module()
    utils_module.delay_kernel(stream_delay_micro_secs)