sglang_v0.5.2/flashinfer_0.3.1/flashinfer/jit/attention/__init__.py

61 lines
2.8 KiB
Python

"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Bind the two sibling submodules as attributes of this package so callers
# can reach them as `<package>.pytorch` and `<package>.tvm`.
from . import pytorch, tvm

# Re-export selected names from the `.pytorch` submodule at package level.
# Every import below aliases the name to itself (`name as name`); the binding
# is the same as a plain import, and this redundant-alias form is the
# conventional marker for an intentional re-export that type checkers and
# linters recognize.
# NOTE(review): presumably that is why the aliases are kept -- confirm against
# the project's lint/type-check configuration before simplifying them.

# `gen_*_module` names.
from .pytorch import gen_cudnn_fmha_module as gen_cudnn_fmha_module
from .pytorch import gen_batch_attention_module as gen_batch_attention_module
from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .pytorch import gen_batch_decode_module as gen_batch_decode_module
from .pytorch import gen_batch_mla_module as gen_batch_mla_module
from .pytorch import gen_batch_prefill_module as gen_batch_prefill_module
from .pytorch import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .pytorch import (
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .pytorch import (
gen_customize_single_decode_module as gen_customize_single_decode_module,
)
from .pytorch import (
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
)
from .pytorch import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module
from .pytorch import gen_pod_module as gen_pod_module
from .pytorch import gen_single_decode_module as gen_single_decode_module
from .pytorch import gen_single_prefill_module as gen_single_prefill_module

# `get_*_uri` names.
from .pytorch import get_batch_attention_uri as get_batch_attention_uri
from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri
from .pytorch import get_batch_decode_uri as get_batch_decode_uri
from .pytorch import get_batch_mla_uri as get_batch_mla_uri
from .pytorch import get_batch_prefill_uri as get_batch_prefill_uri
from .pytorch import get_pod_uri as get_pod_uri
from .pytorch import get_single_decode_uri as get_single_decode_uri
from .pytorch import get_single_prefill_uri as get_single_prefill_uri

# Remaining `.pytorch` re-exports; the last statement pulls in a matching
# `gen_*_module` / `get_*_uri` pair with a single import.
from .pytorch import gen_trtllm_gen_fmha_module as gen_trtllm_gen_fmha_module
from .pytorch import (
gen_batch_prefill_attention_sink_module as gen_batch_prefill_attention_sink_module,
get_batch_prefill_attention_sink_uri as get_batch_prefill_attention_sink_uri,
)
# Re-export the `gen_*_tvm_binding` names from the `.tvm` submodule at package
# level. Each name is aliased to itself (`name as name`), the conventional
# marker for an intentional re-export.
# NOTE(review): presumably kept for type checkers/linters -- confirm before
# dropping the aliases.
from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .tvm import (
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .tvm import (
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
from .tvm import gen_sampling_tvm_binding as gen_sampling_tvm_binding