61 lines
2.8 KiB
Python
61 lines
2.8 KiB
Python
"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from . import pytorch, tvm
|
|
from .pytorch import gen_cudnn_fmha_module as gen_cudnn_fmha_module
|
|
from .pytorch import gen_batch_attention_module as gen_batch_attention_module
|
|
from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module
|
|
from .pytorch import gen_batch_decode_module as gen_batch_decode_module
|
|
from .pytorch import gen_batch_mla_module as gen_batch_mla_module
|
|
from .pytorch import gen_batch_prefill_module as gen_batch_prefill_module
|
|
from .pytorch import (
|
|
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
|
|
)
|
|
from .pytorch import (
|
|
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
|
|
)
|
|
from .pytorch import (
|
|
gen_customize_single_decode_module as gen_customize_single_decode_module,
|
|
)
|
|
from .pytorch import (
|
|
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
|
|
)
|
|
from .pytorch import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module
|
|
from .pytorch import gen_pod_module as gen_pod_module
|
|
from .pytorch import gen_single_decode_module as gen_single_decode_module
|
|
from .pytorch import gen_single_prefill_module as gen_single_prefill_module
|
|
from .pytorch import get_batch_attention_uri as get_batch_attention_uri
|
|
from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri
|
|
from .pytorch import get_batch_decode_uri as get_batch_decode_uri
|
|
from .pytorch import get_batch_mla_uri as get_batch_mla_uri
|
|
from .pytorch import get_batch_prefill_uri as get_batch_prefill_uri
|
|
from .pytorch import get_pod_uri as get_pod_uri
|
|
from .pytorch import get_single_decode_uri as get_single_decode_uri
|
|
from .pytorch import get_single_prefill_uri as get_single_prefill_uri
|
|
from .pytorch import gen_trtllm_gen_fmha_module as gen_trtllm_gen_fmha_module
|
|
from .pytorch import (
|
|
gen_batch_prefill_attention_sink_module as gen_batch_prefill_attention_sink_module,
|
|
get_batch_prefill_attention_sink_uri as get_batch_prefill_attention_sink_uri,
|
|
)
|
|
from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
|
|
from .tvm import (
|
|
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
|
|
)
|
|
from .tvm import (
|
|
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
|
|
)
|
|
from .tvm import gen_sampling_tvm_binding as gen_sampling_tvm_binding
|