# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/utils.py
"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
import math
import os
from enum import Enum
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
import torch
import torch.version
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version
from .jit import gen_jit_spec, env as jit_env
IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"
class PosEncodingMode(Enum):
NONE = 0
ROPE_LLAMA = 1
ALIBI = 2
class MaskMode(Enum):
NON_CAUSAL = 0
CAUSAL = 1
CUSTOM = 2
MULTIITEMSCORING = 3
class TensorLayout(Enum):
NHD = 0
HND = 1
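# log2(e) = 1 / ln(2); commonly used to rewrite exp(x) as exp2(x * log2e) in kernels.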
log2e = 1.44269504088896340736
def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
if x.ndim not in [4, 5]:
raise ValueError("x must be 4D or 5D")
if x.ndim == 4:
# page_size == 1
if kv_layout == "NHD":
# (num_pages, 2, num_heads, head_dim) -> (num_pages, 2, page_size=1, num_heads, head_dim)
            # expand to 5D on the 3rd last dimension
return x.unsqueeze(-3)
elif kv_layout == "HND":
# (num_pages, 2, num_heads, head_dim) -> (num_pages, 2, num_heads, page_size=1, head_dim)
# expand to 5D on the 2nd last dimension
return x.unsqueeze(-2)
else:
raise KeyError("Invalid kv_layout {}".format(kv_layout))
return x
def _expand_4d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
if x.ndim not in [3, 4]:
raise ValueError("x must be 3D or 4D")
if x.ndim == 3:
# page_size == 1
if kv_layout == "NHD":
# (num_pages, num_heads, head_dim) -> (num_pages, page_size=1, num_heads, head_dim)
            # expand to 4D on the 3rd last dimension
return x.unsqueeze(-3)
elif kv_layout == "HND":
# (num_pages, num_heads, head_dim) -> (num_pages, num_heads, page_size=1, head_dim)
            # expand to 4D on the 2nd last dimension
return x.unsqueeze(-2)
else:
raise KeyError("Invalid kv_layout {}".format(kv_layout))
return x
def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
    # The following code is equivalent to 1 << (x - 1).bit_length(),
    # but avoids bit_length() so it can be used with torch.compile.
    # It correctly handles 64-bit numbers, which should be enough for now.
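    # e.g. next_positive_power_of_2(5) == 8, next_positive_power_of_2(8) == 8,
    # and next_positive_power_of_2(0) == 1.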
n = x - 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
return n + 1
def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int:
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = num_tokens * top_k // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
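    # e.g. num_tokens=256, top_k=2, num_experts=8 -> 64 tokens per expert,
    # which is already a power of 2 and within the [8, 64] cap.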
return tile_tokens_dim
def _check_pos_encoding_mode(pos_encoding_mode: str) -> None:
if not hasattr(PosEncodingMode, pos_encoding_mode):
raise KeyError("Invalid pos_encoding_mode {}".format(pos_encoding_mode))
def _check_kv_layout(kv_layout: str) -> None:
if not hasattr(TensorLayout, kv_layout):
raise KeyError("Invalid kv_layout {}".format(kv_layout))
def is_float8(x: torch.Tensor) -> bool:
return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
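# Build an exclusive-prefix-sum "indptr" from per-request lengths,
# e.g. lengths [2, 3, 1] -> indptr [0, 2, 5, 6].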
def get_indptr(x: torch.Tensor) -> torch.Tensor:
x = x.to(torch.int64)
ret = torch.zeros(x.shape[0] + 1, dtype=x.dtype, device=x.device)
ret[1:] = x.cumsum(0)
return ret
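# Accept either a combined paged KV cache (split along dim 1 into K and V) or a
# (k_cache, v_cache) tuple, and return the (k, v) pair in the paged 4D layout.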
def _unpack_paged_kv_cache(
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
kv_layout: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(paged_kv_cache, tuple):
paged_k_cache, paged_v_cache = paged_kv_cache
return (
_expand_4d(paged_k_cache, kv_layout),
_expand_4d(paged_v_cache, kv_layout),
)
elif torch.is_tensor(paged_kv_cache):
# NOTE(Zihao): split on the second dimension
paged_kv_cache = _expand_5d(paged_kv_cache, kv_layout)
paged_k_cache, paged_v_cache = paged_kv_cache.unbind(dim=1)
return paged_k_cache, paged_v_cache
else:
raise KeyError(
"Unrecognized paged_kv_cache type {}, expect a single tensor or a tuple of tensor.".format(
type(paged_kv_cache)
)
)
def get_alibi_slopes(n_heads: int) -> torch.Tensor:
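    # ALiBi slopes (Press et al., "Train Short, Test Long"): a geometric
    # sequence with ratio 2^(-8/n) for the largest power-of-two head count
    # n <= n_heads, plus interpolated slopes (ratio 2^(-4/n), odd exponents)
    # for any remaining heads.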
n = 2 ** math.floor(math.log2(n_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
if n < n_heads:
m_hat_0 = 2.0 ** (-4.0 / n)
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
m = torch.cat([m, m_hat])
return m.float()
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
key = (name, device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf
# find the least power of 2 that is greater than or equal to x
def _ceil_pow2(x: int) -> int:
return 1 << (x - 1).bit_length()
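# Range buffers are cached per device at power-of-two lengths and sliced on use,
# so the number of distinct cached "range_*" buffers stays small.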
def _get_range_buf(seq_len: int, device: torch.device) -> torch.Tensor:
seq_len_pow2 = _ceil_pow2(seq_len)
key = (f"range_{seq_len_pow2}", device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.arange(seq_len_pow2, device=device, dtype=torch.int32)
_cache_buf[key] = buf
return buf[:seq_len]
def _get_cache_alibi_slopes_buf(
num_qo_heads: int, device: torch.device
) -> torch.Tensor:
key = (f"alibi_slopes_{num_qo_heads}", device)
buf = _cache_buf.get(key)
if buf is None:
buf = get_alibi_slopes(num_qo_heads).to(device)
_cache_buf[key] = buf
return buf
def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
if isinstance(dtype, str):
return getattr(torch, dtype)
elif isinstance(dtype, torch.dtype):
return dtype
else:
raise TypeError(
"dtype must be a string or torch.dtype, got {}".format(type(dtype))
)
@functools.cache
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
if device.type != "cuda":
raise ValueError("device must be a cuda device")
return torch.cuda.get_device_capability(device.index)
def _check_cached_qkv_data_type(
q: torch.Tensor, k: torch.Tensor, dtype_q: torch.dtype, dtype_kv: torch.dtype
) -> None:
if q.dtype != dtype_q:
raise ValueError(
f"The dtype of q {q.dtype} does not match the q_data_type {dtype_q} specified in plan function."
)
if k.dtype != dtype_kv:
raise ValueError(
f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function."
)
if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
def register_custom_op(
name: str,
fn: Optional[Callable] = None,
/,
*,
mutates_args: Union[str, Iterable[str]],
device_types: Optional[Union[str, Sequence[str]]] = None,
schema: Optional[str] = None,
) -> Callable:
return lambda x: x
def register_fake_op(
name: str,
fn: Optional[Callable] = None,
) -> Callable:
return lambda x: x
else:
def register_custom_op(
name: str,
fn: Optional[Callable] = None,
/,
*,
mutates_args: Union[str, Iterable[str]],
device_types: Optional[Union[str, Sequence[str]]] = None,
schema: Optional[str] = None,
) -> Callable:
# NOTE(Zihao): torch.library.custom_op has significant overhead as mentioned in the following link
# https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674
# return torch.library.custom_op(
# name,
# fn,
# mutates_args=mutates_args,
# device_types=device_types,
# schema=schema,
# )
return lambda x: x
def register_fake_op(
name: str,
fn: Optional[Callable] = None,
) -> Callable:
# return torch.library.register_fake(name, fn)
return lambda x: x
def determine_gemm_backend(device: torch.device) -> str:
major, _ = get_compute_capability(device)
    if major == 9 and version_at_least(torch.version.cuda, "12.3"):
return "sm90"
else:
return "sm80"
def is_fa3_backend_supported(
pos_encoding_mode: int,
use_fp16_qk_reductions: bool,
use_custom_mask: bool,
dtype_q: torch.dtype,
dtype_kv: torch.dtype,
) -> bool:
"""
Check if the FA3 backend is supported based on the given parameters.
NOTE(Zihao): this function is a workaround for the lack of support for certain features in
our FA3 backend, and will be removed once the backend is fully supported.
Parameters
----------
pos_encoding_mode : int
The positional encoding mode.
use_fp16_qk_reductions : bool
Whether FP16 QK reductions are allowed.
use_custom_mask : bool
Whether a custom mask is used.
dtype_q : torch.dtype
The data type of the query tensor.
dtype_kv : torch.dtype
The data type of the key-value tensor.
Returns
-------
bool
True if the FA3 backend is supported, False otherwise.
"""
if use_custom_mask:
return False
if pos_encoding_mode != PosEncodingMode.NONE.value:
return False
if use_fp16_qk_reductions:
return False
return True
def is_cutlass_backend_supported(
pos_encoding_mode: int,
use_fp16_qk_reductions: bool,
use_custom_mask: bool,
dtype_q: torch.dtype,
dtype_kv: torch.dtype,
) -> bool:
"""
Check if the cutlass backend is supported based on the given parameters.
Parameters
----------
pos_encoding_mode : int
The positional encoding mode.
use_fp16_qk_reductions : bool
Whether FP16 QK reductions are allowed.
use_custom_mask : bool
Whether a custom mask is used.
dtype_q : torch.dtype
The data type of the query tensor.
dtype_kv : torch.dtype
The data type of the key-value tensor.
Returns
-------
bool
True if the cutlass backend is supported, False otherwise.
"""
if use_custom_mask:
return False
if pos_encoding_mode != PosEncodingMode.NONE.value:
return False
if use_fp16_qk_reductions:
return False
if dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]:
return False
if dtype_kv in [torch.float8_e4m3fn, torch.float8_e5m2]:
return False
return True
def determine_attention_backend(
device: torch.device,
pos_encoding_mode: int,
use_fp16_qk_reductions: bool,
use_custom_mask: bool,
dtype_q: torch.dtype,
dtype_kv: torch.dtype,
) -> str:
"""
Determine the appropriate attention backend based on the device and parameters.
Parameters
----------
device : torch.device
The device to be used.
pos_encoding_mode : int
The positional encoding mode.
use_fp16_qk_reductions : bool
Whether FP16 QK reductions are allowed.
use_custom_mask : bool
Whether a custom mask is used.
dtype_q : torch.dtype
The data type of the query tensor.
dtype_kv : torch.dtype
The data type of the key-value tensor.
Returns
-------
str
The name of the attention backend to be used.
"""
if is_sm90a_supported(device) and is_fa3_backend_supported(
pos_encoding_mode,
use_fp16_qk_reductions,
use_custom_mask,
dtype_q,
dtype_kv,
):
return "fa3"
else:
return "fa2"
def version_at_least(version: str, base_version: str) -> bool:
from packaging import version as pkg_version
return pkg_version.parse(version) >= pkg_version.parse(base_version)
def has_cuda_cudart() -> bool:
"""
Check if cuda.cudart module is available (cuda-python <= 12.9).
Returns:
True if cuda.cudart exists, False otherwise
"""
import importlib.util
return importlib.util.find_spec("cuda.cudart") is not None
def is_sm90a_supported(device: torch.device) -> bool:
major, _ = get_compute_capability(device)
return major == 9 and version_at_least(torch.version.cuda, "12.3")
def is_sm100a_supported(device: torch.device) -> bool:
major, _ = get_compute_capability(device)
return major == 10 and version_at_least(torch.version.cuda, "12.8")
def is_sm110a_supported(device: torch.device) -> bool:
major, _ = get_compute_capability(device)
return major == 11 and version_at_least(torch.version.cuda, "13.0")
def is_sm120a_supported(device: torch.device) -> bool:
major, minor = get_compute_capability(device)
return major == 12 and minor == 0 and version_at_least(torch.version.cuda, "12.8")
def is_sm121a_supported(device: torch.device) -> bool:
major, minor = get_compute_capability(device)
return major == 12 and minor == 1 and version_at_least(torch.version.cuda, "12.9")
def determine_mla_backend(device: torch.device) -> str:
return "fa3" if is_sm90a_supported(device) else "fa2"
def check_shape_dtype_device(
x: torch.Tensor,
expected_shape: Optional[Sequence[int]],
expected_dtype: Optional[torch.dtype],
expected_device: Optional[torch.device],
name: str,
) -> None:
if expected_shape and x.shape != torch.Size(expected_shape):
raise ValueError(
f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
)
if expected_dtype and x.dtype != expected_dtype:
raise ValueError(
f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
)
if expected_device and x.device != expected_device:
raise ValueError(
f"Invalid device of {name}: expected {expected_device}, got {x.device}"
)
def gen_logging_module():
return gen_jit_spec(
"logging",
[
jit_env.FLASHINFER_CSRC_DIR / "logging.cc",
],
extra_include_paths=[
jit_env.SPDLOG_INCLUDE_DIR,
jit_env.FLASHINFER_INCLUDE_DIR,
],
)
@functools.cache
def get_logging_module():
return gen_logging_module().build_and_load()
class LogLevel(Enum):
TRACE = 0
DEBUG = 1
INFO = 2
WARN = 3
ERROR = 4
CRITICAL = 5
log_level_map = {
"trace": LogLevel.TRACE,
"debug": LogLevel.DEBUG,
"info": LogLevel.INFO,
"warn": LogLevel.WARN,
"error": LogLevel.ERROR,
"critical": LogLevel.CRITICAL,
}
def set_log_level(lvl_str: str) -> None:
get_logging_module().set_log_level(log_level_map[lvl_str].value)
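# PDL (Programmatic Dependent Launch) requires compute capability 9.0 (Hopper) or newer.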
def device_support_pdl(device: torch.device) -> bool:
if device.type != "cuda":
return False
major, _ = get_compute_capability(device)
return major >= 9
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
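    Example:
        ceil_div(7, 4) == 2 and ceil_div(8, 4) == 2.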
"""
return (x + y - 1) // y
def round_up(x: int, y: int) -> int:
"""Round up x to the nearest multiple of y"""
return ceil_div(x, y) * y
def get_device_sm_count(device: torch.device) -> int:
return torch.cuda.get_device_properties(device).multi_processor_count
class FP4Tensor:
"""Wrapper class for FP4 tensors.
Since PyTorch doesn't natively support FP4, this wrapper contains:
    - data: uint8 tensor storing the compressed FP4 data; the size of the innermost dimension is ceil(original_dim / 2) since each uint8 stores 2 FP4 values
- scale: float8_e4m3fn tensor storing the scale factors
"""
def __init__(
self,
data: torch.Tensor,
scale: torch.Tensor,
scale_start_index: int = 0,
original_shape: Optional[Tuple[int, ...]] = None,
):
"""Initialize FP4Tensor.
Parameters
----------
data : torch.Tensor
uint8 tensor storing the compressed FP4 data
scale : torch.Tensor
float8_e4m3fn tensor storing the scale factors
scale_start_index : int
The start token index of the scale factors. This is needed when two kernels (like prefill and decode kernels) are reusing the same scale factor tensor with different offsets.
original_shape : Optional[Tuple[int, ...]]
The original shape before compression.
"""
if data.dtype != torch.uint8:
raise ValueError(f"data must be uint8 tensor, got {data.dtype}")
# Validate scale factor tensor and scale start index
if scale.dtype != torch.float8_e4m3fn:
raise ValueError(f"scale must be float8_e4m3fn tensor, got {scale.dtype}")
if scale.shape[0] % 128 != 0:
raise ValueError(
f"scale.shape[0] must be a multiple of 128, got {scale.shape[0]}"
)
if scale_start_index < 0 or scale_start_index >= scale.shape[0]:
raise ValueError(
f"scale start index must be in the range [0, scale.shape[0]). "
f"scale_start_index={scale_start_index}, scale.shape[0]={scale.shape[0]}"
)
if scale_start_index + data.shape[0] > scale.shape[0]:
raise ValueError(
f"scale start index + data.shape[0] must not exceed scale.shape[0]. "
f"scale_start_index={scale_start_index}, data.shape[0]={data.shape[0]}, scale.shape[0]={scale.shape[0]}"
)
# Validate shape relationship if original_shape is provided
if original_shape is not None:
if data.shape[:-1] != original_shape[:-1]:
raise ValueError(
f"data and original_shape must have the same dimensions except the last one. "
f"data.shape={data.shape}, original_shape={original_shape}"
)
# Check the last dimension relationship: data_dim = ceil(original_dim / 2)
expected_data_dim = math.ceil(original_shape[-1] / 2)
if data.shape[-1] != expected_data_dim:
raise ValueError(
f"data last dimension must be ceil(original_shape[-1] / 2). "
f"data.shape[-1]={data.shape[-1]}, original_shape[-1]={original_shape[-1]}, "
f"expected={expected_data_dim}"
)
self.data = data
self.scale = scale
self.scale_start_index = scale_start_index
self.original_shape = original_shape
self.dtype = "nvfp4"
# yapf: disable
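# Within each shuffle block, source row r maps to destination row
# (r % 2) * 8 + r // 2 for 16-row blocks, and (r % 4) * 8 + r // 4 for 32-row blocks.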
srcToDstBlk16RowMap = [
0, 8,
1, 9,
2, 10,
3, 11,
4, 12,
5, 13,
6, 14,
7, 15
]
srcToDstBlk32RowMap = [
0, 8, 16, 24,
1, 9, 17, 25,
2, 10, 18, 26,
3, 11, 19, 27,
4, 12, 20, 28,
5, 13, 21, 29,
6, 14, 22, 30,
7, 15, 23, 31
]
# yapf: enable
def get_shuffle_block_size(epilogue_tile_m: int) -> int:
shuffle_block_size = 16
if epilogue_tile_m % 128 == 0:
shuffle_block_size = 32
return shuffle_block_size
def get_shuffle_matrix_a_row_indices(
input_tensor: torch.Tensor, epilogue_tile_m: int
) -> torch.Tensor:
"""
Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32.
- We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes).
- Instead, we purely reorder rows in a standard PyTorch shape [M, K].
"""
assert input_tensor.dim() == 2, (
f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"
)
# M, K from the input
M, K = input_tensor.shape
# Choose block size 16 or 32
shuffle_block_size = get_shuffle_block_size(epilogue_tile_m)
row_map = srcToDstBlk16RowMap if shuffle_block_size == 16 else srcToDstBlk32RowMap
assert M % shuffle_block_size == 0, (
f"input_tensor.shape[0] must be multiples of {shuffle_block_size}"
)
# row_indices[new_row] = old_row
# so row_indices is an array of size M telling us from which old_row
# the new_row should be taken.
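    # Applying it as `input_tensor[row_indices]` therefore yields the shuffled matrix.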
row_indices = torch.empty(M, dtype=torch.long)
for old_row in range(M):
block_idx = old_row // shuffle_block_size
row_in_block = old_row % shuffle_block_size
mapped_row_in_block = row_map[row_in_block]
new_row = block_idx * shuffle_block_size + mapped_row_in_block
row_indices[new_row] = old_row
return row_indices
def get_shuffle_matrix_sf_a_row_indices(
input_tensor: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16
) -> torch.Tensor:
assert input_tensor.dtype == torch.uint8
assert num_elts_per_sf == 16
assert input_tensor.dim() == 2, (
f"input_tensor should be a 2D tensor, not {input_tensor.dim()}"
)
# M, K from the input
M, K = input_tensor.shape
assert M % 128 == 0
assert K % 4 == 0
row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
return row_indices