sglang_v0.5.2/flashinfer_0.3.1/flashinfer/triton/kernels/cascade.py

160 lines
4.6 KiB
Python

import triton # type: ignore[import]
import triton.language as tl # type: ignore[import]
@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    """Combine two partial softmax states into a single state.

    A state is the triple ``(o, m, d)``: an un-normalized accumulator ``o``,
    a running maximum ``m`` and a denominator ``d``. Exponentials are taken
    in base 2. Both states are rescaled onto the shared maximum
    ``max(m, other_m)`` and then summed.

    Returns:
        The merged ``(o, m, d)`` triple.
    """
    new_m = tl.maximum(m, other_m)
    # Factors that move each side onto the common maximum.
    self_scale = tl.exp2(m - new_m)
    other_scale = tl.exp2(other_m - new_m)
    merged_d = d * self_scale + other_d * other_scale
    merged_o = o * self_scale + other_o * other_scale
    return merged_o, new_m, merged_d
@triton.jit
def state_normalize(o, m, d):
    """Divide a softmax state's accumulator by its denominator.

    ``m`` and ``d`` are returned unchanged, so the triple can still be passed
    to ``state_get_lse`` afterwards.
    """
    return o / d, m, d
@triton.jit
def state_get_lse(o, m, d):
    """Return the base-2 log-sum-exp ``m + log2(d)`` of a softmax state.

    ``o`` has no effect on the result. It is accepted only so the signature
    matches the other ``state_*`` helpers.
    """
    log_denominator = tl.log2(d)
    return m + log_denominator
@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge two partial attention states element-wise.

    Program instance ``pos`` (axis 0) handles each of ``bdy`` heads as follows:

    - It reads the scalars ``s_a``/``s_b`` at ``pos * num_heads + head_idx``.
    - It reads ``bdx`` values of ``v_a``/``v_b`` starting at
      ``(pos * num_heads + head_idx) * head_dim``.

    The ``s`` values are treated as base-2 log-sum-exp (``exp2``/``log2``)::

        s_max = max(s_a, s_b)
        d     = 2**(s_a - s_max) + 2**(s_b - s_max)
        v     = (v_a * 2**(s_a - s_max) + v_b * 2**(s_b - s_max)) / d

    ``v`` is stored at the same offsets in ``v_merged_ptr``. When
    ``s_merged_ptr`` is set, ``log2(d) + s_max`` is written there.

    NOTE(review): callers presumably launch with ``bdx == head_dim`` and
    ``bdy == num_heads``. The kernel itself does not check this.
    """
    pos = tl.program_id(axis=0)
    # Heads outer, elements inner (the same nesting as
    # merge_state_in_place_kernel). This way the per-head scalars are loaded,
    # exponentiated and the merged LSE stored once per head, not once per
    # element.
    for head_idx in tl.range(bdy):
        state_offset = pos * num_heads + head_idx
        s_a_val = tl.load(s_a_ptr + state_offset)
        s_b_val = tl.load(s_b_ptr + state_offset)
        s_max = tl.maximum(s_a_val, s_b_val)
        scale_a = tl.exp2(s_a_val - s_max)
        scale_b = tl.exp2(s_b_val - s_max)
        d = scale_a + scale_b
        row_offset = state_offset * head_dim
        for tx in tl.range(bdx):
            offsets = row_offset + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)
            # Same arithmetic as state_merge followed by state_normalize.
            v_merged = (v_a * scale_a + v_b * scale_b) / d
            tl.store(v_merged_ptr + offsets, v_merged)
        if s_merged_ptr:
            tl.store(s_merged_ptr + state_offset, tl.log2(d) + s_max)
@triton.jit
def merge_state_in_place_kernel(
    v_ptr,
    s_ptr,
    v_other_ptr,
    s_other_ptr,
    num_heads,
    head_dim,
    mask_ptr,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge ``(v_other, s_other)`` into ``(v, s)``, overwriting ``v`` and ``s``.

    If ``mask_ptr`` is given and ``mask[pos] == 0``, program instance ``pos``
    (axis 0) writes nothing. Otherwise, for each of ``bdy`` heads:

    - The two ``s`` scalars are treated as base-2 log-sum-exp values.
    - They are turned into normalized weights ``2**(s - s_max) / total``.
    - Each of the ``bdx`` values of ``v`` is replaced by
      ``w_self * v + w_other * v_other``.
    - If ``s_ptr`` is set, ``s`` is replaced by ``log2(total) + s_max``.
    """
    pos = tl.program_id(axis=0)
    if mask_ptr:
        # A masked-out position returns before any store, so its state is kept.
        if tl.load(mask_ptr + pos) == 0:
            return

    for head_idx in tl.range(bdy):
        state_offset = pos * num_heads + head_idx
        lse_self = tl.load(s_ptr + state_offset)
        lse_peer = tl.load(s_other_ptr + state_offset)
        lse_max = tl.maximum(lse_self, lse_peer)
        exp_self = tl.exp2(lse_self - lse_max)
        exp_peer = tl.exp2(lse_peer - lse_max)
        total = exp_self + exp_peer
        weight_self = exp_self / total
        weight_peer = exp_peer / total
        row_offset = state_offset * head_dim
        for tx in tl.range(bdx):
            offset = row_offset + tx
            v_self = tl.load(v_ptr + offset)
            v_peer = tl.load(v_other_ptr + offset)
            tl.store(v_ptr + offset, weight_self * v_self + weight_peer * v_peer)
        if s_ptr:
            tl.store(s_ptr + state_offset, tl.log2(total) + lse_max)
@triton.jit
def merge_states_kernel(
    v_ptr,
    s_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_index_sets,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge ``num_index_sets`` partial states per position into one state.

    Program instance ``pos`` (axis 0) visits each index set ``i`` and head
    ``h``:

    - It reads the scalar ``s[(pos * num_index_sets + i) * num_heads + h]``.
    - It reads the values of ``v`` at
      ``((pos * num_index_sets + i) * num_heads + h) * head_dim + tx`` for
      ``tx < bdx``.

    These are folded with ``state_merge``, normalized with
    ``state_normalize``, and written to
    ``v_merged[(pos * num_heads + h) * head_dim + tx]``. When
    ``s_merged_ptr`` is set, the merged base-2 log-sum-exp is written to
    ``s_merged[pos * num_heads + h]``.
    """
    pos = tl.program_id(axis=0)
    # Index of this position's first index set.
    set_base = pos * num_index_sets
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Empty starting state, with -5e4 standing in for -inf. On the
            # first merge the phantom d = 1 is scaled by exp2(-5e4 - m_max),
            # which is negligible when the real LSE values are far above
            # -5e4. If nothing is merged, the output is 0 and the LSE is -5e4.
            o, m, d = 0.0, -5e4, 1.0
            for set_idx in tl.range(num_index_sets):
                state_offset = (set_base + set_idx) * num_heads + head_idx
                s = tl.load(s_ptr + state_offset)
                v = tl.load(v_ptr + state_offset * head_dim + tx)
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)
            out_offset = pos * num_heads + head_idx
            tl.store(v_merged_ptr + out_offset * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(s_merged_ptr + out_offset, state_get_lse(o, m, d))
@triton.jit
def variable_length_merge_states_kernel(
    v_ptr,
    s_ptr,
    indptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge a variable number of partial states per position into one state.

    Program instance ``pos`` (axis 0) merges the states whose flat indices
    ``j`` lie in ``[indptr[pos], indptr[pos + 1])``. For each head ``h``:

    - It reads ``s[j * num_heads + h]``.
    - It reads ``v[(j * num_heads + h) * head_dim + tx]`` for ``tx < bdx``.

    These are folded with ``state_merge``, normalized with
    ``state_normalize``, and written to
    ``v_merged[(pos * num_heads + h) * head_dim + tx]``. When
    ``s_merged_ptr`` is set, the merged base-2 log-sum-exp is written to
    ``s_merged[pos * num_heads + h]``.

    An empty range yields an output of 0 and an LSE of -5e4.
    """
    pos = tl.program_id(axis=0)
    # The state range depends only on ``pos``. Load it once instead of once
    # per (tx, head) pair; the stores in between would otherwise force a
    # reload every time.
    start = tl.load(indptr + pos)
    end = tl.load(indptr + pos + 1)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Empty starting state, with -5e4 standing in for -inf. The
            # phantom d = 1 is scaled by exp2(-5e4 - m_max) on the first
            # merge, which is negligible when the real LSE values are far
            # above -5e4.
            o, m, d = 0.0, -5e4, 1.0
            for state_idx in tl.range(start, end):
                state_offset = state_idx * num_heads + head_idx
                s = tl.load(s_ptr + state_offset)
                v = tl.load(v_ptr + state_offset * head_dim + tx)
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)
            out_offset = pos * num_heads + head_idx
            tl.store(v_merged_ptr + out_offset * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(s_merged_ptr + out_offset, state_get_lse(o, m, d))