import argparse
import os
import shutil
from itertools import product
from pathlib import Path
from typing import List, Tuple, Iterator

import torch
import torch.version

from .activation import act_func_def_str, gen_act_and_mul_module
from .fp8_quantization import gen_mxfp8_quantization_sm100_module
from .cascade import gen_cascade_module
from .fp4_quantization import (
    gen_fp4_quantization_sm100_module,
    gen_fp4_quantization_sm90_module,
)
from .fused_moe import (
    gen_cutlass_fused_moe_sm100_module,
    gen_cutlass_fused_moe_sm90_module,
    gen_trtllm_gen_fused_moe_sm100_module,
)
from .gemm import (
    gen_gemm_module,
    gen_gemm_sm90_module,
    gen_gemm_sm100_module,
    gen_gemm_sm100_module_cutlass_fp4,
    gen_gemm_sm100_module_cutlass_fp8,
    gen_trtllm_gen_gemm_module,
)
from .jit import JitSpec, build_jit_specs
from .jit import env as jit_env
from .jit import (
    gen_batch_decode_module,
    gen_batch_mla_module,
    gen_batch_prefill_module,
    gen_fmha_cutlass_sm100a_module,
    gen_single_decode_module,
    gen_single_prefill_module,
    gen_trtllm_gen_fmha_module,
)
from .mla import gen_mla_module
from .norm import gen_norm_module
from .page import gen_page_module
from .quantization import gen_quantization_module
from .rope import gen_rope_module
from .sampling import gen_sampling_module
from .tllm_utils import gen_trtllm_utils_module
from .utils import gen_logging_module, version_at_least
from .xqa import gen_xqa_module
from .compilation_context import CompilationContext


def gen_fa2(
    dtype_qo: torch.dtype,
    dtype_kv: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
) -> Iterator[JitSpec]:
    """Yield FA2 prefill/decode JIT specs for one dtype / head-dim configuration."""
    if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
        return
    if dtype_qo.itemsize == 1:
        return  # fp8 tensor cores not supported in fa2

    yield gen_single_prefill_module(
        backend="fa2",
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
        use_fp16_qk_reduction=False,
    )
    yield gen_batch_prefill_module(
        backend="fa2",
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        dtype_idx=torch.int32,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
        use_fp16_qk_reduction=False,
    )
    yield gen_single_decode_module(
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
    )
    yield gen_batch_decode_module(
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        dtype_idx=torch.int32,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
    )


def gen_fa3(
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
) -> Iterator[JitSpec]:
    """Yield the FA3 batch-prefill JIT spec for one dtype / head-dim configuration (SM90 callers only)."""
    if dtype_q != dtype_kv:
        return  # fa3 template does not support mixed precision
    if dtype_q.itemsize == 2:
        if dtype_q != dtype_o:
            return  # for fp16, dtype_o must be the same as dtype_q/dtype_kv

    if dtype_kv.itemsize == 1:
        if head_dim_qk == 192 or head_dim_qk == 64:
            return  # (192, 128) & (64, 64) not supported for fp8 yet.
    yield gen_batch_prefill_module(
        backend="fa3",
        dtype_q=dtype_q,
        dtype_kv=dtype_kv,
        dtype_o=dtype_o,
        dtype_idx=torch.int32,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
        use_fp16_qk_reduction=False,
    )


def gen_attention(
    f16_dtype_: List[torch.dtype],
    f8_dtype_: List[torch.dtype],
    fa2_head_dim_: List[Tuple[int, int]],
    fa3_head_dim_: List[Tuple[int, int]],
    use_sliding_window_: List[bool],
    use_logits_soft_cap_: List[bool],
    has_sm90: bool,
    has_sm100: bool,
    add_gemma: bool,
    add_oai_oss: bool,
) -> Iterator[JitSpec]:
    """Yield JIT specs for all attention modules (FA2, FA3, cutlass SM100 FMHA, trtllm-gen FMHA, MLA)."""
    head_dim_ckv = 512
    head_dim_kpe = 64

    # FA2 MHA / MQA / GQA
    for (
        (head_dim_qk, head_dim_vo),
        dtype_qo,
        dtype_kv,
        use_sliding_window,
        use_logits_soft_cap,
    ) in product(
        fa2_head_dim_,
        f16_dtype_,
        f16_dtype_ + f8_dtype_,
        use_sliding_window_,
        use_logits_soft_cap_,
    ):
        yield from gen_fa2(
            dtype_qo=dtype_qo,
            dtype_kv=dtype_kv,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            use_sliding_window=use_sliding_window,
            use_logits_soft_cap=use_logits_soft_cap,
        )

    # FA3 MHA / MQA / GQA
    if has_sm90:
        for (
            (head_dim_qk, head_dim_vo),
            dtype_qkv,
            dtype_o,
            use_sliding_window,
            use_logits_soft_cap,
        ) in product(
            fa3_head_dim_,
            f16_dtype_ + f8_dtype_,
            f16_dtype_,
            use_sliding_window_,
            use_logits_soft_cap_,
        ):
            yield from gen_fa3(
                dtype_q=dtype_qkv,
                dtype_kv=dtype_qkv,
                dtype_o=dtype_o,
                head_dim_qk=head_dim_qk,
                head_dim_vo=head_dim_vo,
                use_sliding_window=use_sliding_window,
                use_logits_soft_cap=use_logits_soft_cap,
            )

    # Gemma
    if add_gemma:
        for (
            dtype_qo,
            dtype_kv,
            (use_sliding_window, use_logits_soft_cap),
        ) in product(
            f16_dtype_,
            f16_dtype_ + f8_dtype_,
            [(True, True)],
        ):
            yield from gen_fa2(
                dtype_qo=dtype_qo,
                dtype_kv=dtype_kv,
                head_dim_qk=256,
                head_dim_vo=256,
                use_sliding_window=use_sliding_window,
                use_logits_soft_cap=use_logits_soft_cap,
            )
        if has_sm90:
            for (
                dtype_qkv,
                dtype_o,
                (use_sliding_window, use_logits_soft_cap),
            ) in product(
                f16_dtype_ + f8_dtype_,
                f16_dtype_,
                [(True, True)],
            ):
                yield from gen_fa3(
                    dtype_q=dtype_qkv,
                    dtype_kv=dtype_qkv,
                    dtype_o=dtype_o,
                    head_dim_qk=256,
                    head_dim_vo=256,
                    use_sliding_window=use_sliding_window,
                    use_logits_soft_cap=use_logits_soft_cap,
                )

    # OAI OSS
    if add_oai_oss:
        from .jit.attention import gen_batch_prefill_attention_sink_module

        for dtype in f16_dtype_:
            for backend in ["fa2", "fa3"]:
                for use_swa in [True, False]:
                    yield gen_batch_prefill_attention_sink_module(
                        backend=backend,
                        dtype_q=dtype,
                        dtype_kv=dtype,
                        dtype_o=dtype,
                        dtype_idx=torch.int32,
                        head_dim_qk=64,
                        head_dim_vo=64,
                        pos_encoding_mode=0,
                        use_sliding_window=use_swa,
                    )

    # fmha_cutlass_sm100a
    # NOTE: currently there's only one uri.
    if has_sm100:
        yield gen_fmha_cutlass_sm100a_module(
            dtype_q=torch.bfloat16,
            dtype_kv=torch.bfloat16,
            dtype_o=torch.bfloat16,
            dtype_idx=torch.int32,
            head_dim_qk=128,
            head_dim_vo=128,
            pos_encoding_mode=0,
            use_sliding_window=False,
            use_logits_soft_cap=False,
        )

    # trtllm_gen_fmha
    yield gen_trtllm_gen_fmha_module()

    # MLA
    # NOTE: fp8 kv not supported in MLA
    mla_backend_ = ["fa2"] + (["fa3"] if has_sm90 else [])
    for dtype_qo in f16_dtype_:
        for backend in mla_backend_:
            yield gen_batch_mla_module(
                backend=backend,
                dtype_q=dtype_qo,
                dtype_kv=dtype_qo,
                dtype_o=dtype_qo,
                dtype_idx=torch.int32,
                head_dim_ckv=head_dim_ckv,
                head_dim_kpe=head_dim_kpe,
                use_profiler=False,
            )

    # MLA SM100
    if has_sm100:
        yield gen_mla_module()


def gen_xqa(
    use_fp16_: List[bool],
    token_per_page_: List[int],
    head_size_: List[int],
    head_grp_size_: List[int],
    use_sliding_window_: List[bool],
    has_sm90: bool,
) -> Iterator[JitSpec]:
    """Generate XQA modules for various configurations."""
    if not has_sm90:
        return  # XQA requires SM90+

    for (
        use_fp16,
        token_per_page,
        head_size,
        head_grp_size,
        use_sliding_window,
    ) in product(
        use_fp16_,
        token_per_page_,
        head_size_,
        head_grp_size_,
        use_sliding_window_,
    ):
        # Skip invalid configurations
        if head_size % 16 != 0 or head_size > 256 or head_size < 16:
            continue
        if token_per_page not in [16, 32, 64, 128]:
            continue

        yield gen_xqa_module(
            use_fp16=use_fp16,
            token_per_page=token_per_page,
            head_size=head_size,
            head_grp_size=head_grp_size,
            use_sliding_window=use_sliding_window,
        )


def gen_all_modules(
    f16_dtype_: List[torch.dtype],
    f8_dtype_: List[torch.dtype],
    fa2_head_dim_: List[Tuple[int, int]],
    fa3_head_dim_: List[Tuple[int, int]],
    use_sliding_window_: List[bool],
    use_logits_soft_cap_: List[bool],
    has_sm90: bool,
    has_sm100: bool,
    add_comm: bool,
    add_gemma: bool,
    add_oai_oss: bool,
    add_moe: bool,
    add_act: bool,
    add_misc: bool,
    add_xqa: bool,
) -> List[JitSpec]:
    """Collect JIT specs for all requested module families, deduplicated by name."""
    jit_specs: List[JitSpec] = []

    jit_specs += list(
        gen_attention(
            f16_dtype_,
            f8_dtype_,
            fa2_head_dim_,
            fa3_head_dim_,
            use_sliding_window_,
            use_logits_soft_cap_,
            has_sm90,
            has_sm100,
            add_gemma,
            add_oai_oss,
        )
    )

    if add_act:
        for act_name in act_func_def_str:
            jit_specs.append(gen_act_and_mul_module(act_name))

    if add_moe:
        jit_specs.append(gen_gemm_module())
        if has_sm90:
            jit_specs.append(gen_gemm_sm90_module())
            jit_specs.append(gen_fp4_quantization_sm90_module())
            jit_specs.append(gen_cutlass_fused_moe_sm90_module())
        if has_sm100:
            jit_specs.append(gen_fp4_quantization_sm100_module())
            jit_specs.append(gen_cutlass_fused_moe_sm100_module())
            jit_specs.append(gen_gemm_sm100_module())
            jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
            jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
            jit_specs.append(gen_mxfp8_quantization_sm100_module())
            jit_specs.append(gen_trtllm_gen_gemm_module())
            jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())

    if add_comm:
        from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
        from .comm.nvshmem import gen_nvshmem_module
        from .comm.trtllm_alltoall import gen_comm_alltoall_module
        from .comm.trtllm_mnnvl_ar import gen_trtllm_mnnvl_comm_module

        jit_specs.append(gen_nvshmem_module())
        jit_specs.append(gen_comm_alltoall_module())
        if has_sm100:
            jit_specs.append(gen_trtllm_comm_module())
            jit_specs.append(gen_trtllm_mnnvl_comm_module())
        jit_specs.append(gen_vllm_comm_module())

    if add_misc:
        jit_specs += [
            gen_cascade_module(),
            gen_norm_module(),
            gen_page_module(),
            gen_quantization_module(),
            gen_rope_module(),
            gen_sampling_module(),
        ]
        if has_sm90:
            jit_specs.append(gen_trtllm_utils_module())

    if add_xqa:
        # Define XQA configurations to iterate over
        xqa_use_fp16_ = [True, False]  # fp16 and bf16
        xqa_token_per_page_ = [16, 32, 64, 128]
        xqa_head_size_ = [64, 128, 256]
        xqa_head_grp_size_ = [1, 2, 4, 8]  # Different group sizes for MQA/GQA
        jit_specs += list(
            gen_xqa(
                xqa_use_fp16_,
                xqa_token_per_page_,
                xqa_head_size_,
                xqa_head_grp_size_,
                use_sliding_window_,
                has_sm90,
            )
        )

    # dedup
    names = set()
    ret: List[JitSpec] = []
    for jit_spec in jit_specs:
        if jit_spec.name not in names:
            names.add(jit_spec.name)
            ret.append(jit_spec)
    return ret


def copy_built_kernels(
    jit_specs: List[JitSpec],
    out_dir: Path,
) -> None:
    """Copy the built .so files from the JIT cache into out_dir, one subdirectory per op."""
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=False)
    for jit_spec in jit_specs:
        src = jit_env.FLASHINFER_JIT_DIR / jit_spec.name / f"{jit_spec.name}.so"
        dst = out_dir / jit_spec.name / f"{jit_spec.name}.so"
        dst.parent.mkdir(exist_ok=False, parents=False)
        shutil.copy2(src, dst)


def parse_bool(s: str) -> bool:
    if s.lower() in ("true", "1"):
        return True
    elif s.lower() in ("false", "0"):
        return False
    else:
        raise ValueError(f"Invalid boolean value: {s}")


def parse_head_dim(head_dim: str) -> Tuple[int, int]:
    qo, kv = map(int, head_dim.split(","))
    return qo, kv


def main():
    parser = argparse.ArgumentParser(
        description="Ahead-of-Time (AOT) build all modules"
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        help="Output directory",
    )
    parser.add_argument(
        "--build-dir",
        type=Path,
        help="Build directory",
    )
    parser.add_argument(
        "--fa2-head-dim",
        nargs="*",
        help="FA2 head dim pair of qk and vo, separated by comma",
    )
    parser.add_argument(
        "--fa3-head-dim",
        nargs="*",
        help="FA3 head dim pair of qk and vo, separated by comma",
    )
    parser.add_argument(
        "--f16-dtype",
        nargs="*",
        choices=["float16", "bfloat16"],
        help="16-bit data type",
    )
    parser.add_argument(
        "--f8-dtype",
        nargs="*",
        choices=["float8_e4m3fn", "float8_e5m2"],
        help="8-bit data type",
    )
    parser.add_argument(
        "--use-sliding-window",
        nargs="*",
        help="Use sliding window attention",
    )
    parser.add_argument(
        "--use-logits-soft-cap",
        nargs="*",
        help="Use logits soft cap",
    )
    parser.add_argument(
        "--add-comm",
        type=parse_bool,
        help="Add communication kernels (trtllm_comm, vllm_comm)",
    )
    parser.add_argument(
        "--add-gemma",
        type=parse_bool,
        help="Add kernels for Gemma Model (head_dim=256, use_sliding_window, use_logits_soft_cap)",
    )
    parser.add_argument(
        "--add-oai-oss",
        type=parse_bool,
        help="Add kernels for OAI OSS Model (head_dim=64, use_sliding_window)",
    )
    parser.add_argument(
        "--add-moe",
        type=parse_bool,
        help="Add MoE kernels",
    )
    parser.add_argument(
        "--add-act",
        type=parse_bool,
        help="Add activation kernels",
    )
    parser.add_argument(
        "--add-misc",
        type=parse_bool,
        help="Add miscellaneous kernels",
    )
    parser.add_argument(
        "--add-xqa",
        type=parse_bool,
        help="Add XQA (Cross-Query Attention) kernels",
    )
    args = parser.parse_args()

    # Default values
    project_root = Path(__file__).resolve().parents[1]
    out_dir = project_root / "aot-ops"
    build_dir = project_root / "build" / "aot"
    fa2_head_dim_ = [
        (64, 64),
        (128, 128),
        (256, 256),
    ]
    fa3_head_dim_ = [
        (192, 128),
        (128, 128),
        (64, 64),
        (256, 256),
    ]
    f16_dtype_ = [
        torch.float16,
        torch.bfloat16,
    ]
    f8_dtype_ = [
        torch.float8_e4m3fn,
        # torch.float8_e5m2,
    ]
    use_sliding_window_ = [
        False,
        True,
    ]
    use_logits_soft_cap_ = [
        False,
        True,
    ]
    add_comm = True
    add_gemma = True
    add_oai_oss = True
    add_moe = True
    add_act = True
    add_misc = True
    add_xqa = True

    # Override
    if args.out_dir:
        out_dir = Path(args.out_dir)
    if args.build_dir:
        build_dir = Path(args.build_dir)
    if args.fa2_head_dim:
        fa2_head_dim_ = [parse_head_dim(dim) for dim in args.fa2_head_dim]
    if args.fa3_head_dim:
        fa3_head_dim_ = [parse_head_dim(dim) for dim in args.fa3_head_dim]
    if args.f16_dtype:
        f16_dtype_ = [getattr(torch, dtype) for dtype in args.f16_dtype]
    if args.f8_dtype:
        f8_dtype_ = [getattr(torch, dtype) for dtype in args.f8_dtype]
    if args.use_sliding_window:
        use_sliding_window_ = [parse_bool(s) for s in args.use_sliding_window]
    if args.use_logits_soft_cap:
        use_logits_soft_cap_ = [parse_bool(s) for s in args.use_logits_soft_cap]
    if args.add_comm is not None:
        add_comm = args.add_comm
    if args.add_gemma is not None:
        add_gemma = args.add_gemma
    if args.add_oai_oss is not None:
        add_oai_oss = args.add_oai_oss
    if args.add_moe is not None:
        add_moe = args.add_moe
    if args.add_act is not None:
        add_act = args.add_act
    if args.add_misc is not None:
        add_misc = args.add_misc
    if args.add_xqa is not None:
        add_xqa = args.add_xqa

    # Cuda Arch
    if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
        raise RuntimeError("Please explicitly set env var FLASHINFER_CUDA_ARCH_LIST.")
    compilation_context = CompilationContext()
    gencode_flags_list = compilation_context.get_nvcc_flags_list(
        supported_major_versions=None
    )

    def has_sm(compute: str, version: str) -> bool:
        if not any(compute in flag for flag in gencode_flags_list):
            return False
        if torch.version.cuda is None:
            return True
        return version_at_least(torch.version.cuda, version)

    has_sm90 = has_sm("compute_90", "12.3")
    has_sm100 = has_sm("compute_100", "12.8")

    # Update data dir
    jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
    jit_env.FLASHINFER_INCLUDE_DIR = project_root / "include"
    jit_env.CUTLASS_INCLUDE_DIRS = [
        project_root / "3rdparty" / "cutlass" / "include",
        project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include",
    ]
    jit_env.SPDLOG_INCLUDE_DIR = project_root / "3rdparty" / "spdlog" / "include"

    # Update workdir
    jit_env.FLASHINFER_WORKSPACE_DIR = build_dir
    jit_env.FLASHINFER_JIT_DIR = build_dir / "cached_ops"
    jit_env.FLASHINFER_GEN_SRC_DIR = build_dir / "generated"
    jit_env.FLASHINFER_JIT_DIR.mkdir(parents=True, exist_ok=True)
    jit_env.FLASHINFER_GEN_SRC_DIR.mkdir(parents=True, exist_ok=True)

    # Print summary
    print("AOT build summary:")
    print("  out_dir:", out_dir)
    print("  build_dir:", build_dir)
    print("  fa2_head_dim:", fa2_head_dim_)
    print("  fa3_head_dim:", fa3_head_dim_)
    print("  f16_dtype:", f16_dtype_)
    print("  f8_dtype:", f8_dtype_)
    print("  use_sliding_window:", use_sliding_window_)
    print("  use_logits_soft_cap:", use_logits_soft_cap_)
    print("  FLASHINFER_CUDA_ARCH_LIST:", os.environ["FLASHINFER_CUDA_ARCH_LIST"])
    print("  has_sm90:", has_sm90)
    print("  has_sm100:", has_sm100)
    print("  add_comm:", add_comm)
    print("  add_gemma:", add_gemma)
    print("  add_oai_oss:", add_oai_oss)
    print("  add_moe:", add_moe)
    print("  add_act:", add_act)
    print("  add_misc:", add_misc)
    print("  add_xqa:", add_xqa)

    # Generate JIT specs
    print("Generating JIT specs...")
    jit_specs = [gen_logging_module()]
    jit_specs += gen_all_modules(
        f16_dtype_,
        f8_dtype_,
        fa2_head_dim_,
        fa3_head_dim_,
        use_sliding_window_,
        use_logits_soft_cap_,
        has_sm90,
        has_sm100,
        add_comm,
        add_gemma,
        add_oai_oss,
        add_moe,
        add_act,
        add_misc,
        add_xqa,
    )
    print("Total ops:", len(jit_specs))

    # Build
    build_jit_specs(jit_specs, verbose=True, skip_prebuilt=False)

    # Copy built kernels
    copy_built_kernels(jit_specs, out_dir)
    print("AOT kernels saved to:", out_dir)


if __name__ == "__main__":
    main()
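
# Illustrative invocation (not part of the module): a minimal sketch of how this
# script is typically driven, assuming it is importable as `flashinfer.aot` and
# that the arch list follows the TORCH_CUDA_ARCH_LIST-style format; adjust both
# to your checkout and target GPUs. FLASHINFER_CUDA_ARCH_LIST must be set
# explicitly, and every flag shown is defined in the argparse setup above.
#
#   FLASHINFER_CUDA_ARCH_LIST="8.0 9.0a" \
#       python -m flashinfer.aot --out-dir aot-ops --f16-dtype bfloat16 --add-xqa false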