"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import ctypes
import functools
import os

# Re-export
from . import cubin_loader
from . import env as env
from .activation import gen_act_and_mul_module as gen_act_and_mul_module
from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
from .attention import gen_batch_attention_module as gen_batch_attention_module
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .attention import gen_batch_decode_module as gen_batch_decode_module
from .attention import gen_batch_mla_module as gen_batch_mla_module
from .attention import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .attention import gen_batch_prefill_module as gen_batch_prefill_module
from .attention import gen_cudnn_fmha_module as gen_cudnn_fmha_module
from .attention import (
    gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .attention import (
    gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .attention import (
    gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .attention import (
    gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
from .attention import (
    gen_customize_single_decode_module as gen_customize_single_decode_module,
)
from .attention import (
    gen_customize_single_prefill_module as gen_customize_single_prefill_module,
)
from .attention import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module
from .attention import gen_pod_module as gen_pod_module
from .attention import gen_sampling_tvm_binding as gen_sampling_tvm_binding
from .attention import gen_single_decode_module as gen_single_decode_module
from .attention import gen_single_prefill_module as gen_single_prefill_module
from .attention import gen_trtllm_gen_fmha_module as gen_trtllm_gen_fmha_module
from .attention import get_batch_attention_uri as get_batch_attention_uri
from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri
from .attention import get_batch_decode_uri as get_batch_decode_uri
from .attention import get_batch_mla_uri as get_batch_mla_uri
from .attention import get_batch_prefill_uri as get_batch_prefill_uri
from .attention import get_pod_uri as get_pod_uri
from .attention import get_single_decode_uri as get_single_decode_uri
from .attention import get_single_prefill_uri as get_single_prefill_uri
from .core import JitSpec as JitSpec
from .core import build_jit_specs as build_jit_specs
from .core import clear_cache_dir as clear_cache_dir
from .core import gen_jit_spec as gen_jit_spec
from .core import sm90a_nvcc_flags as sm90a_nvcc_flags
from .core import sm100a_nvcc_flags as sm100a_nvcc_flags
from .core import sm103a_nvcc_flags as sm103a_nvcc_flags
from .core import sm110a_nvcc_flags as sm110a_nvcc_flags
from .core import sm120a_nvcc_flags as sm120a_nvcc_flags
from .core import sm121a_nvcc_flags as sm121a_nvcc_flags
from .core import current_compilation_context as current_compilation_context
from .cubin_loader import setup_cubin_loader


@functools.cache
def get_cudnn_fmha_gen_module():
    """Build and load the JIT cuDNN FMHA module, memoized per process.

    ``functools.cache`` ensures the module is compiled and loaded only once;
    subsequent calls return the cached operator object.
    """
    mod = gen_cudnn_fmha_module()
    op = mod.build_and_load()
    # Register the freshly built library with the cubin loader before
    # handing the operator back to the caller.
    setup_cubin_loader(mod.get_library_path())
    return op
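
# Illustrative usage (a sketch, not part of the API surface): the memoized
# getter can be called from anywhere; only the first call pays the JIT
# compilation cost.
#
#   fmha_op = get_cudnn_fmha_gen_module()
#   fmha_op = get_cudnn_fmha_gen_module()  # cached; returns the same object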


# Preload the CUDA 12 runtime with RTLD_GLOBAL so that the symbols it exports
# are visible to JIT-compiled modules loaded later in the process. The lookup
# directory defaults to a standard x86_64 Linux CUDA install and can be
# overridden via the CUDA_LIB_PATH environment variable.
cuda_lib_path = os.environ.get(
    "CUDA_LIB_PATH", "/usr/local/cuda/targets/x86_64-linux/lib/"
)
if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"):
    ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
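
# Illustrative only: the ``gen_*_module`` helpers re-exported above return
# objects that follow the same build-and-load pattern used by
# ``get_cudnn_fmha_gen_module`` (exact generator arguments are
# kernel-specific and elided here):
#
#   spec = gen_batch_prefill_module(...)  # configure a kernel variant
#   op = spec.build_and_load()            # JIT-compile and load it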