# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/aot.py
import argparse
import os
import shutil
from itertools import product
from pathlib import Path
from typing import List, Tuple, Iterator

import torch
import torch.version

from .activation import act_func_def_str, gen_act_and_mul_module
from .fp8_quantization import gen_mxfp8_quantization_sm100_module
from .cascade import gen_cascade_module
from .fp4_quantization import (
    gen_fp4_quantization_sm100_module,
    gen_fp4_quantization_sm90_module,
)
from .fused_moe import (
    gen_cutlass_fused_moe_sm100_module,
    gen_cutlass_fused_moe_sm90_module,
    gen_trtllm_gen_fused_moe_sm100_module,
)
from .gemm import (
    gen_gemm_module,
    gen_gemm_sm90_module,
    gen_gemm_sm100_module,
    gen_gemm_sm100_module_cutlass_fp4,
    gen_gemm_sm100_module_cutlass_fp8,
    gen_trtllm_gen_gemm_module,
)
from .jit import JitSpec, build_jit_specs
from .jit import env as jit_env
from .jit import (
    gen_batch_decode_module,
    gen_batch_mla_module,
    gen_batch_prefill_module,
    gen_fmha_cutlass_sm100a_module,
    gen_single_decode_module,
    gen_single_prefill_module,
    gen_trtllm_gen_fmha_module,
)
from .mla import gen_mla_module
from .norm import gen_norm_module
from .page import gen_page_module
from .quantization import gen_quantization_module
from .rope import gen_rope_module
from .sampling import gen_sampling_module
from .tllm_utils import gen_trtllm_utils_module
from .utils import gen_logging_module, version_at_least
from .xqa import gen_xqa_module
from .compilation_context import CompilationContext


def gen_fa2(
    dtype_qo: torch.dtype,
    dtype_kv: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
) -> Iterator[JitSpec]:
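    """Yield FA2 single/batch prefill and decode JIT specs for one configuration.

    Unsupported combinations (same-width but mismatched Q/KV dtypes, fp8 Q/O)
    yield nothing.
    """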
    if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv:
        return  # skip same-width but mismatched dtypes (e.g. fp16 Q with bf16 KV)
    if dtype_qo.itemsize == 1:
        return  # fp8 tensor cores not supported in fa2
    yield gen_single_prefill_module(
        backend="fa2",
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
        use_fp16_qk_reduction=False,
    )
    yield gen_batch_prefill_module(
        backend="fa2",
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        dtype_idx=torch.int32,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
        use_fp16_qk_reduction=False,
    )
    yield gen_single_decode_module(
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
    )
    yield gen_batch_decode_module(
        dtype_q=dtype_qo,
        dtype_kv=dtype_kv,
        dtype_o=dtype_qo,
        dtype_idx=torch.int32,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
    )


def gen_fa3(
    dtype_q: torch.dtype,
    dtype_kv: torch.dtype,
    dtype_o: torch.dtype,
    head_dim_qk: int,
    head_dim_vo: int,
    use_sliding_window: bool,
    use_logits_soft_cap: bool,
) -> Iterator[JitSpec]:
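    """Yield the FA3 (SM90) batch prefill JIT spec for one configuration.

    Mixed Q/KV precision and a few fp8 head-dim combinations are unsupported
    and yield nothing.
    """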
    if dtype_q != dtype_kv:
        return  # the fa3 template does not support mixed precision
    if dtype_q.itemsize == 2:
        if dtype_q != dtype_o:
            return  # for fp16/bf16, dtype_o must match dtype_q/dtype_kv
    if dtype_kv.itemsize == 1:
        if head_dim_qk == 192 or head_dim_qk == 64:
            return  # (192, 128) & (64, 64) not supported for fp8 yet.
    yield gen_batch_prefill_module(
        backend="fa3",
        dtype_q=dtype_q,
        dtype_kv=dtype_kv,
        dtype_o=dtype_o,
        dtype_idx=torch.int32,
        head_dim_qk=head_dim_qk,
        head_dim_vo=head_dim_vo,
        pos_encoding_mode=0,
        use_sliding_window=use_sliding_window,
        use_logits_soft_cap=use_logits_soft_cap,
        use_fp16_qk_reduction=False,
    )


def gen_attention(
    f16_dtype_: List[torch.dtype],
    f8_dtype_: List[torch.dtype],
    fa2_head_dim_: List[Tuple[int, int]],
    fa3_head_dim_: List[Tuple[int, int]],
    use_sliding_window_: List[bool],
    use_logits_soft_cap_: List[bool],
    has_sm90: bool,
    has_sm100: bool,
    add_gemma: bool,
    add_oai_oss: bool,
) -> Iterator[JitSpec]:
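    """Yield JIT specs for all requested attention variants.

    Covers FA2 and FA3 MHA/MQA/GQA, optional Gemma (head_dim=256) and OAI OSS
    (attention-sink) kernels, the SM100 cutlass FMHA and trtllm-gen FMHA
    modules, and MLA modules.
    """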
    head_dim_ckv = 512
    head_dim_kpe = 64

    # FA2 MHA / MQA / GQA
    for (
        (head_dim_qk, head_dim_vo),
        dtype_qo,
        dtype_kv,
        use_sliding_window,
        use_logits_soft_cap,
    ) in product(
        fa2_head_dim_,
        f16_dtype_,
        f16_dtype_ + f8_dtype_,
        use_sliding_window_,
        use_logits_soft_cap_,
    ):
        yield from gen_fa2(
            dtype_qo=dtype_qo,
            dtype_kv=dtype_kv,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            use_sliding_window=use_sliding_window,
            use_logits_soft_cap=use_logits_soft_cap,
        )

    # FA3 MHA / MQA / GQA
    if has_sm90:
        for (
            (head_dim_qk, head_dim_vo),
            dtype_qkv,
            dtype_o,
            use_sliding_window,
            use_logits_soft_cap,
        ) in product(
            fa3_head_dim_,
            f16_dtype_ + f8_dtype_,
            f16_dtype_,
            use_sliding_window_,
            use_logits_soft_cap_,
        ):
            yield from gen_fa3(
                dtype_q=dtype_qkv,
                dtype_kv=dtype_qkv,
                dtype_o=dtype_o,
                head_dim_qk=head_dim_qk,
                head_dim_vo=head_dim_vo,
                use_sliding_window=use_sliding_window,
                use_logits_soft_cap=use_logits_soft_cap,
            )

    # Gemma
    if add_gemma:
        for (
            dtype_qo,
            dtype_kv,
            (use_sliding_window, use_logits_soft_cap),
        ) in product(
            f16_dtype_,
            f16_dtype_ + f8_dtype_,
            [(True, True)],
        ):
            yield from gen_fa2(
                dtype_qo=dtype_qo,
                dtype_kv=dtype_kv,
                head_dim_qk=256,
                head_dim_vo=256,
                use_sliding_window=use_sliding_window,
                use_logits_soft_cap=use_logits_soft_cap,
            )
        if has_sm90:
            for (
                dtype_qkv,
                dtype_o,
                (use_sliding_window, use_logits_soft_cap),
            ) in product(
                f16_dtype_ + f8_dtype_,
                f16_dtype_,
                [(True, True)],
            ):
                yield from gen_fa3(
                    dtype_q=dtype_qkv,
                    dtype_kv=dtype_qkv,
                    dtype_o=dtype_o,
                    head_dim_qk=256,
                    head_dim_vo=256,
                    use_sliding_window=use_sliding_window,
                    use_logits_soft_cap=use_logits_soft_cap,
                )

    # OAI OSS
    if add_oai_oss:
        from .jit.attention import gen_batch_prefill_attention_sink_module

        for dtype in f16_dtype_:
            for backend in ["fa2", "fa3"]:
                for use_swa in [True, False]:
                    yield gen_batch_prefill_attention_sink_module(
                        backend=backend,
                        dtype_q=dtype,
                        dtype_kv=dtype,
                        dtype_o=dtype,
                        dtype_idx=torch.int32,
                        head_dim_qk=64,
                        head_dim_vo=64,
                        pos_encoding_mode=0,
                        use_sliding_window=use_swa,
                    )

    # fmha_cutlass_sm100a
    # NOTE: currently there's only one uri.
    if has_sm100:
        yield gen_fmha_cutlass_sm100a_module(
            dtype_q=torch.bfloat16,
            dtype_kv=torch.bfloat16,
            dtype_o=torch.bfloat16,
            dtype_idx=torch.int32,
            head_dim_qk=128,
            head_dim_vo=128,
            pos_encoding_mode=0,
            use_sliding_window=False,
            use_logits_soft_cap=False,
        )

    # trtllm_gen_fmha
    yield gen_trtllm_gen_fmha_module()

    # MLA
    # NOTE: fp8 kv not supported in MLA
    mla_backend_ = ["fa2"] + (["fa3"] if has_sm90 else [])
    for dtype_qo in f16_dtype_:
        for backend in mla_backend_:
            yield gen_batch_mla_module(
                backend=backend,
                dtype_q=dtype_qo,
                dtype_kv=dtype_qo,
                dtype_o=dtype_qo,
                dtype_idx=torch.int32,
                head_dim_ckv=head_dim_ckv,
                head_dim_kpe=head_dim_kpe,
                use_profiler=False,
            )

    # MLA SM100
    if has_sm100:
        yield gen_mla_module()


def gen_xqa(
    use_fp16_: List[bool],
    token_per_page_: List[int],
    head_size_: List[int],
    head_grp_size_: List[int],
    use_sliding_window_: List[bool],
    has_sm90: bool,
) -> Iterator[JitSpec]:
"""Generate XQA modules for various configurations."""
if not has_sm90:
return # XQA requires SM90+
for (
use_fp16,
token_per_page,
head_size,
head_grp_size,
use_sliding_window,
) in product(
use_fp16_,
token_per_page_,
head_size_,
head_grp_size_,
use_sliding_window_,
):
# Skip invalid configurations
if head_size % 16 != 0 or head_size > 256 or head_size < 16:
continue
if token_per_page not in [16, 32, 64, 128]:
continue
yield gen_xqa_module(
use_fp16=use_fp16,
token_per_page=token_per_page,
head_size=head_size,
head_grp_size=head_grp_size,
use_sliding_window=use_sliding_window,
)


def gen_all_modules(
    f16_dtype_: List[torch.dtype],
    f8_dtype_: List[torch.dtype],
    fa2_head_dim_: List[Tuple[int, int]],
    fa3_head_dim_: List[Tuple[int, int]],
    use_sliding_window_: List[bool],
    use_logits_soft_cap_: List[bool],
    has_sm90: bool,
    has_sm100: bool,
    add_comm: bool,
    add_gemma: bool,
    add_oai_oss: bool,
    add_moe: bool,
    add_act: bool,
    add_misc: bool,
    add_xqa: bool,
) -> List[JitSpec]:
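    """Collect JIT specs for every requested module family, deduplicated by spec name."""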
    jit_specs: List[JitSpec] = []

    jit_specs += list(
        gen_attention(
            f16_dtype_,
            f8_dtype_,
            fa2_head_dim_,
            fa3_head_dim_,
            use_sliding_window_,
            use_logits_soft_cap_,
            has_sm90,
            has_sm100,
            add_gemma,
            add_oai_oss,
        )
    )

    if add_act:
        for act_name in act_func_def_str:
            jit_specs.append(gen_act_and_mul_module(act_name))

    if add_moe:
        jit_specs.append(gen_gemm_module())
        if has_sm90:
            jit_specs.append(gen_gemm_sm90_module())
            jit_specs.append(gen_fp4_quantization_sm90_module())
            jit_specs.append(gen_cutlass_fused_moe_sm90_module())
        if has_sm100:
            jit_specs.append(gen_fp4_quantization_sm100_module())
            jit_specs.append(gen_cutlass_fused_moe_sm100_module())
            jit_specs.append(gen_gemm_sm100_module())
            jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
            jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
            jit_specs.append(gen_mxfp8_quantization_sm100_module())
            jit_specs.append(gen_trtllm_gen_gemm_module())
            jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())

    if add_comm:
        from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
        from .comm.nvshmem import gen_nvshmem_module
        from .comm.trtllm_alltoall import gen_comm_alltoall_module
        from .comm.trtllm_mnnvl_ar import gen_trtllm_mnnvl_comm_module

        jit_specs.append(gen_nvshmem_module())
        jit_specs.append(gen_comm_alltoall_module())
        if has_sm100:
            jit_specs.append(gen_trtllm_comm_module())
            jit_specs.append(gen_trtllm_mnnvl_comm_module())
        jit_specs.append(gen_vllm_comm_module())

    if add_misc:
        jit_specs += [
            gen_cascade_module(),
            gen_norm_module(),
            gen_page_module(),
            gen_quantization_module(),
            gen_rope_module(),
            gen_sampling_module(),
        ]
        if has_sm90:
            jit_specs.append(gen_trtllm_utils_module())

    if add_xqa:
        # Define XQA configurations to iterate over
        xqa_use_fp16_ = [True, False]  # fp16 and bf16
        xqa_token_per_page_ = [16, 32, 64, 128]
        xqa_head_size_ = [64, 128, 256]
        xqa_head_grp_size_ = [1, 2, 4, 8]  # Different group sizes for MQA/GQA
        jit_specs += list(
            gen_xqa(
                xqa_use_fp16_,
                xqa_token_per_page_,
                xqa_head_size_,
                xqa_head_grp_size_,
                use_sliding_window_,
                has_sm90,
            )
        )

    # dedup
    names = set()
    ret: List[JitSpec] = []
    for jit_spec in jit_specs:
        if jit_spec.name not in names:
            names.add(jit_spec.name)
            ret.append(jit_spec)
    return ret


def copy_built_kernels(
    jit_specs: List[JitSpec],
    out_dir: Path,
) -> None:
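    """Copy each built ``<name>.so`` from the JIT cache into ``out_dir``, recreating ``out_dir`` from scratch."""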
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=False)
    for jit_spec in jit_specs:
        src = jit_env.FLASHINFER_JIT_DIR / jit_spec.name / f"{jit_spec.name}.so"
        dst = out_dir / jit_spec.name / f"{jit_spec.name}.so"
        dst.parent.mkdir(exist_ok=False, parents=False)
        shutil.copy2(src, dst)


def parse_bool(s: str) -> bool:
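    """Parse "true"/"1" or "false"/"0" (case-insensitive) into a bool."""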
    if s.lower() in ("true", "1"):
        return True
    elif s.lower() in ("false", "0"):
        return False
    else:
        raise ValueError(f"Invalid boolean value: {s}")


def parse_head_dim(head_dim: str) -> Tuple[int, int]:
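    """Parse a comma-separated head-dim pair such as "128,128" into a tuple of ints."""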
    qo, kv = map(int, head_dim.split(","))
    return qo, kv


def main():
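    """Parse CLI options, generate the requested JIT specs, build them, and copy the kernels to the output directory."""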
    parser = argparse.ArgumentParser(
        description="Ahead-of-Time (AOT) build all modules"
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        help="Output directory",
    )
    parser.add_argument(
        "--build-dir",
        type=Path,
        help="Build directory",
    )
    parser.add_argument(
        "--fa2-head-dim",
        nargs="*",
        help="FA2 head dim pair of qk and vo, separated by comma",
    )
    parser.add_argument(
        "--fa3-head-dim",
        nargs="*",
        help="FA3 head dim pair of qk and vo, separated by comma",
    )
    parser.add_argument(
        "--f16-dtype",
        nargs="*",
        choices=["float16", "bfloat16"],
        help="16-bit data type",
    )
    parser.add_argument(
        "--f8-dtype",
        nargs="*",
        choices=["float8_e4m3fn", "float8_e5m2"],
        help="8-bit data type",
    )
    parser.add_argument(
        "--use-sliding-window",
        nargs="*",
        help="Use sliding window attention",
    )
    parser.add_argument(
        "--use-logits-soft-cap",
        nargs="*",
        help="Use logits soft cap",
    )
    parser.add_argument(
        "--add-comm",
        type=parse_bool,
        help="Add communication kernels (trtllm_comm, vllm_comm)",
    )
    parser.add_argument(
        "--add-gemma",
        type=parse_bool,
        help="Add kernels for Gemma Model (head_dim=256, use_sliding_window, use_logits_soft_cap)",
    )
    parser.add_argument(
        "--add-oai-oss",
        type=parse_bool,
        help="Add kernels for OAI OSS Model (head_dim=64, use_sliding_window)",
    )
    parser.add_argument(
        "--add-moe",
        type=parse_bool,
        help="Add MoE kernels",
    )
    parser.add_argument(
        "--add-act",
        type=parse_bool,
        help="Add activation kernels",
    )
    parser.add_argument(
        "--add-misc",
        type=parse_bool,
        help="Add miscellaneous kernels",
    )
    parser.add_argument(
        "--add-xqa",
        type=parse_bool,
        help="Add XQA (Cross-Query Attention) kernels",
    )
    args = parser.parse_args()

    # Default values
    project_root = Path(__file__).resolve().parents[1]
    out_dir = project_root / "aot-ops"
    build_dir = project_root / "build" / "aot"
    fa2_head_dim_ = [
        (64, 64),
        (128, 128),
        (256, 256),
    ]
    fa3_head_dim_ = [
        (192, 128),
        (128, 128),
        (64, 64),
        (256, 256),
    ]
    f16_dtype_ = [
        torch.float16,
        torch.bfloat16,
    ]
    f8_dtype_ = [
        torch.float8_e4m3fn,
        # torch.float8_e5m2,
    ]
    use_sliding_window_ = [
        False,
        True,
    ]
    use_logits_soft_cap_ = [
        False,
        True,
    ]
    add_comm = True
    add_gemma = True
    add_oai_oss = True
    add_moe = True
    add_act = True
    add_misc = True
    add_xqa = True

    # Override
    if args.out_dir:
        out_dir = Path(args.out_dir)
    if args.build_dir:
        build_dir = Path(args.build_dir)
    if args.fa2_head_dim:
        fa2_head_dim_ = [parse_head_dim(dim) for dim in args.fa2_head_dim]
    if args.fa3_head_dim:
        fa3_head_dim_ = [parse_head_dim(dim) for dim in args.fa3_head_dim]
    if args.f16_dtype:
        f16_dtype_ = [getattr(torch, dtype) for dtype in args.f16_dtype]
    if args.f8_dtype:
        f8_dtype_ = [getattr(torch, dtype) for dtype in args.f8_dtype]
    if args.use_sliding_window:
        use_sliding_window_ = [parse_bool(s) for s in args.use_sliding_window]
    if args.use_logits_soft_cap:
        use_logits_soft_cap_ = [parse_bool(s) for s in args.use_logits_soft_cap]
    if args.add_comm is not None:
        add_comm = args.add_comm
    if args.add_gemma is not None:
        add_gemma = args.add_gemma
    if args.add_oai_oss is not None:
        add_oai_oss = args.add_oai_oss
    if args.add_moe is not None:
        add_moe = args.add_moe
    if args.add_act is not None:
        add_act = args.add_act
    if args.add_misc is not None:
        add_misc = args.add_misc
    if args.add_xqa is not None:
        add_xqa = args.add_xqa

    # Cuda Arch
    if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
        raise RuntimeError("Please explicitly set env var FLASHINFER_CUDA_ARCH_LIST.")
    compilation_context = CompilationContext()
    gencode_flags_list = compilation_context.get_nvcc_flags_list(
        supported_major_versions=None
    )
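
    # An arch is treated as available only when it appears in the requested
    # gencode flags and, if torch reports a CUDA version, that version meets
    # the minimum required to compile for the arch.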
    def has_sm(compute: str, version: str) -> bool:
        if not any(compute in flag for flag in gencode_flags_list):
            return False
        if torch.version.cuda is None:
            return True
        return version_at_least(torch.version.cuda, version)

    has_sm90 = has_sm("compute_90", "12.3")
    has_sm100 = has_sm("compute_100", "12.8")

    # Update data dir
    jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
    jit_env.FLASHINFER_INCLUDE_DIR = project_root / "include"
    jit_env.CUTLASS_INCLUDE_DIRS = [
        project_root / "3rdparty" / "cutlass" / "include",
        project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include",
    ]
    jit_env.SPDLOG_INCLUDE_DIR = project_root / "3rdparty" / "spdlog" / "include"

    # Update workdir
    jit_env.FLASHINFER_WORKSPACE_DIR = build_dir
    jit_env.FLASHINFER_JIT_DIR = build_dir / "cached_ops"
    jit_env.FLASHINFER_GEN_SRC_DIR = build_dir / "generated"
    jit_env.FLASHINFER_JIT_DIR.mkdir(parents=True, exist_ok=True)
    jit_env.FLASHINFER_GEN_SRC_DIR.mkdir(parents=True, exist_ok=True)

    # Print summary
    print("AOT build summary:")
    print(" out_dir:", out_dir)
    print(" build_dir:", build_dir)
    print(" fa2_head_dim:", fa2_head_dim_)
    print(" fa3_head_dim:", fa3_head_dim_)
    print(" f16_dtype:", f16_dtype_)
    print(" f8_dtype:", f8_dtype_)
    print(" use_sliding_window:", use_sliding_window_)
    print(" use_logits_soft_cap:", use_logits_soft_cap_)
    print(" FLASHINFER_CUDA_ARCH_LIST:", os.environ["FLASHINFER_CUDA_ARCH_LIST"])
    print(" has_sm90:", has_sm90)
    print(" has_sm100:", has_sm100)
    print(" add_comm:", add_comm)
    print(" add_gemma:", add_gemma)
    print(" add_oai_oss:", add_oai_oss)
    print(" add_moe:", add_moe)
    print(" add_act:", add_act)
    print(" add_misc:", add_misc)
    print(" add_xqa:", add_xqa)

    # Generate JIT specs
    print("Generating JIT specs...")
    jit_specs = [gen_logging_module()]
    jit_specs += gen_all_modules(
        f16_dtype_,
        f8_dtype_,
        fa2_head_dim_,
        fa3_head_dim_,
        use_sliding_window_,
        use_logits_soft_cap_,
        has_sm90,
        has_sm100,
        add_comm,
        add_gemma,
        add_oai_oss,
        add_moe,
        add_act,
        add_misc,
        add_xqa,
    )
    print("Total ops:", len(jit_specs))

    # Build
    build_jit_specs(jit_specs, verbose=True, skip_prebuilt=False)

    # Copy built kernels
    copy_built_kernels(jit_specs, out_dir)
    print("AOT kernels saved to:", out_dir)


if __name__ == "__main__":
    main()