# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/triton/cascade.py
# (153 lines, 3.8 KiB, Python)
from typing import Optional
import torch
from .kernels.cascade import (
merge_state_in_place_kernel,
merge_state_kernel,
merge_states_kernel,
variable_length_merge_states_kernel,
)
from .utils import check_device, check_dim, check_input, check_shape
def merge_state(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    """Merge two attention states computed over disjoint KV segments.

    Args:
        v_a: Attention output of segment A, shape ``(seq_len, num_heads, head_dim)``.
        s_a: Per-head scores (log-sum-exp) of segment A, shape ``(seq_len, num_heads)``.
        v_b: Attention output of segment B, same shape as ``v_a``.
        s_b: Per-head scores of segment B, same shape as ``s_a``.

    Returns:
        ``(v_merged, s_merged)`` — merged output with the same shape/dtype as
        ``v_a`` and float32 merged scores of shape ``(seq_len, num_heads)``.
    """
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    # s_a/s_b shapes already match via check_shape; compare against s_a for
    # consistency with the seq_len assert above.
    assert v_a.size(1) == s_a.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    # check_device guarantees all inputs share a device, so allocate outputs
    # there directly. Pin s_merged to float32 explicitly — relying on the
    # global default dtype would produce float64 under
    # torch.set_default_dtype(torch.float64), and allocating on CPU first
    # would add a needless host-to-device copy.
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty(
        (seq_len, num_heads), dtype=torch.float32, device=s_a.device
    )
    bdx = head_dim
    bdy = num_heads
    # One program per sequence position; plain-tuple grid matches the other
    # launchers in this module.
    merge_state_kernel[(seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )
    return v_merged, s_merged
def merge_state_in_place(
    v: torch.Tensor,
    s: torch.Tensor,
    v_other: torch.Tensor,
    s_other: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """Merge ``(v_other, s_other)`` into ``(v, s)``, mutating ``v`` and ``s``.

    Args:
        v: Attention output, shape ``(seq_len, num_heads, head_dim)``; updated in place.
        s: Per-head scores, shape ``(seq_len, num_heads)``, float32; updated in place.
        v_other: Second attention output, same shape as ``v``.
        s_other: Second scores tensor, same shape as ``s``, float32.
        mask: Optional 1-D per-position gate of length ``seq_len`` on the same
            device as ``v``; presumably selects which rows are merged — handled
            inside the kernel.
    """
    for t in (v, s, v_other, s_other):
        check_input(t)
    check_device([v, s, v_other, s_other])
    for ndim, t in ((3, v), (2, s), (3, v_other), (2, s_other)):
        check_dim(ndim, t)
    check_shape(v, v_other)
    check_shape(s, s_other)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    # The kernel accumulates into s/s_other directly, so both must be float32.
    assert s.dtype == torch.float32
    assert s_other.dtype == torch.float32
    if mask is not None:
        check_dim(1, mask)
        assert v.size(0) == mask.size(0)
        assert mask.device == v.device
    seq_len, num_heads, head_dim = v.size(0), v.size(1), v.size(2)
    # One program per sequence position; block dims cover head_dim x num_heads.
    merge_state_in_place_kernel[(seq_len,)](
        v, s, v_other, s_other, num_heads, head_dim, mask, bdx=head_dim, bdy=num_heads
    )
def merge_states(v: torch.Tensor, s: torch.Tensor):
    """Merge attention states across a fixed number of index sets per position.

    Args:
        v: Attention outputs, shape ``(seq_len, num_index_sets, num_heads, head_dim)``.
        s: Per-head scores, shape ``(seq_len, num_index_sets, num_heads)``;
            cast to float32 before the launch.

    Returns:
        ``(v_merged, s_merged)`` of shapes ``(seq_len, num_heads, head_dim)``
        and ``(seq_len, num_heads)``.
    """
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(4, v)
    check_dim(3, s)
    # Leading three dims of v and s must agree: (seq_len, num_index_sets, num_heads).
    for axis in range(3):
        assert v.size(axis) == s.size(axis)
    seq_len, num_index_sets, num_heads, head_dim = (
        v.size(0),
        v.size(1),
        v.size(2),
        v.size(3),
    )
    s = s.to(torch.float32)
    out_v = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    out_s = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)
    # One program per sequence position; block dims cover head_dim x num_heads.
    merge_states_kernel[(seq_len,)](
        v,
        s,
        out_v,
        out_s,
        num_index_sets,
        num_heads,
        head_dim,
        bdx=head_dim,
        bdy=num_heads,
    )
    return out_v, out_s
def variable_length_merge_states(
    v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor
):
    """Merge attention states where each position owns a variable-length slice.

    Args:
        v: Flattened attention outputs, shape ``(total, num_heads, head_dim)``.
        s: Flattened per-head scores, shape ``(total, num_heads)``; cast to
            float32 before the launch.
        indptr: 1-D CSR-style offsets of length ``seq_len + 1``; row ``i``
            merges entries ``indptr[i]:indptr[i+1]``. Cast to int32.

    Returns:
        ``(v_merged, s_merged)`` of shapes ``(seq_len, num_heads, head_dim)``
        and ``(seq_len, num_heads)``.
    """
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(3, v)
    check_dim(2, s)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    # Validate indptr like `mask` in merge_state_in_place: previously it was
    # passed to the kernel unchecked, so a CPU indptr alongside CUDA inputs
    # failed deep inside the launch instead of here.
    check_dim(1, indptr)
    assert indptr.device == v.device
    seq_len = indptr.size(0) - 1
    num_heads = v.size(1)
    head_dim = v.size(2)
    s = s.to(torch.float32)
    indptr = indptr.to(torch.int32)
    v_merged = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)
    bdx = head_dim
    bdy = num_heads
    # One program per output row; block dims cover head_dim x num_heads.
    variable_length_merge_states_kernel[(seq_len,)](
        v,
        s,
        indptr,
        v_merged,
        s_merged,
        num_heads,
        head_dim,
        bdx=bdx,
        bdy=bdy,
    )
    return v_merged, s_merged