160 lines
4.6 KiB
Python
160 lines
4.6 KiB
Python
import triton # type: ignore[import]
|
|
import triton.language as tl # type: ignore[import]
|
|
|
|
|
|
@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    """Combine two ``(o, m, d)`` states into a single un-normalized state.

    Both sides are rescaled onto the shared exponent ``max(m, other_m)``
    with a base-2 exponential factor and then summed.

    Args:
        o: accumulated value of the first state.
        m: base-2 exponent offset of the first state.
        d: accumulated denominator of the first state.
        other_o: accumulated value of the second state.
        other_m: base-2 exponent offset of the second state.
        other_d: accumulated denominator of the second state.

    Returns:
        Tuple ``(o, m, d)`` of the merged state, where ``m`` is the larger of
        the two input exponents. ``o`` is not divided by ``d`` here.
    """
    shared_max = tl.maximum(m, other_m)
    # Factors that move each side from its own exponent onto the shared one.
    self_factor = tl.exp2(m - shared_max)
    other_factor = tl.exp2(other_m - shared_max)
    merged_d = d * self_factor + other_d * other_factor
    merged_o = o * self_factor + other_o * other_factor
    return merged_o, shared_max, merged_d
|
|
|
|
|
|
@triton.jit
def state_normalize(o, m, d):
    """Divide the accumulated value ``o`` by the denominator ``d``.

    Returns:
        Tuple ``(o / d, m, d)``; ``m`` and ``d`` are passed through unchanged.
    """
    normalized_o = o / d
    return normalized_o, m, d
|
|
|
|
|
|
@triton.jit
def state_get_lse(o, m, d):
    """Return ``m + log2(d)`` for the state ``(o, m, d)``.

    ``o`` is part of the signature but is not read.
    """
    log2_d = tl.log2(d)
    return m + log2_d
|
|
|
|
|
|
@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge two ``(v, s)`` inputs element-wise and write the normalized result.

    Each program instance handles the index ``pos`` given by grid axis 0 and
    loops over ``bdy`` heads and ``bdx`` elements per head. For every head the
    two inputs are weighted by ``exp2(s - max(s_a, s_b))``, summed and divided
    by the sum of the two weights.

    Args:
        v_a_ptr: pointer to the first value input, addressed as
            ``(pos * num_heads + head_idx) * head_dim + tx``.
        s_a_ptr: pointer to the first per-head scalar, addressed as
            ``pos * num_heads + head_idx``.
        v_b_ptr: pointer to the second value input (same addressing as
            ``v_a_ptr``).
        s_b_ptr: pointer to the second per-head scalar (same addressing as
            ``s_a_ptr``).
        v_merged_ptr: output pointer for the merged values (same addressing
            as ``v_a_ptr``).
        s_merged_ptr: optional output pointer; when truthy it receives
            ``log2(w_a + w_b) + max(s_a, s_b)`` per head.
        num_heads: row stride multiplier used in the address computation.
        head_dim: element stride of one head used in the address computation.
        bdx: compile-time number of elements processed per head
            (presumably equal to ``head_dim`` -- confirm against the launcher).
        bdy: compile-time number of heads processed
            (presumably equal to ``num_heads`` -- confirm against the launcher).
    """
    pos = tl.program_id(axis=0)
    # Heads are the outer loop so that everything that only depends on the
    # head (the two scalar loads, the weights and the merged scalar store) is
    # done once per head instead of once per element.
    for head_idx in tl.range(bdy):
        row = pos * num_heads + head_idx
        s_a_val = tl.load(s_a_ptr + row)
        s_b_val = tl.load(s_b_ptr + row)

        # Rescale both sides onto the shared exponent before combining.
        s_max = tl.maximum(s_a_val, s_b_val)
        w_a = tl.exp2(s_a_val - s_max)
        w_b = tl.exp2(s_b_val - s_max)
        d = w_a + w_b

        for tx in tl.range(bdx):
            offset = row * head_dim + tx
            v_a = tl.load(v_a_ptr + offset)
            v_b = tl.load(v_b_ptr + offset)
            v_merged = (v_a * w_a + v_b * w_b) / d
            tl.store(v_merged_ptr + offset, v_merged)

        if s_merged_ptr:
            tl.store(s_merged_ptr + row, tl.log2(d) + s_max)
|
|
|
|
|
|
@triton.jit
def merge_state_in_place_kernel(
    v_ptr,
    s_ptr,
    v_other_ptr,
    s_other_ptr,
    num_heads,
    head_dim,
    mask_ptr,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge ``(v_other, s_other)`` into ``(v, s)``, overwriting ``v`` and ``s``.

    Each program instance handles the index ``pos`` given by grid axis 0. If a
    mask pointer is supplied and ``mask[pos] == 0`` the program returns without
    touching any data.

    Args:
        v_ptr: pointer to the values that are read and overwritten, addressed
            as ``(pos * num_heads + head_idx) * head_dim + tx``.
        s_ptr: pointer to the per-head scalar that is read and overwritten,
            addressed as ``pos * num_heads + head_idx``.
        v_other_ptr: pointer to the values merged in (same addressing as
            ``v_ptr``).
        s_other_ptr: pointer to the per-head scalar merged in (same addressing
            as ``s_ptr``).
        num_heads: row stride multiplier used in the address computation.
        head_dim: element stride of one head used in the address computation.
        mask_ptr: optional pointer to one mask entry per ``pos``.
        bdx: compile-time number of elements processed per head.
        bdy: compile-time number of heads processed.
    """
    pos = tl.program_id(axis=0)
    # The two conditions stay nested so the mask load is only reached when a
    # mask pointer was actually supplied.
    if mask_ptr:
        if tl.load(mask_ptr + pos) == 0:
            return

    for head_idx in tl.range(bdy):
        row = pos * num_heads + head_idx
        s_self = tl.load(s_ptr + row)
        s_other = tl.load(s_other_ptr + row)

        # Weights of both sides relative to the larger exponent.
        s_max = tl.maximum(s_self, s_other)
        w_self = tl.exp2(s_self - s_max)
        w_other = tl.exp2(s_other - s_max)
        w_sum = w_self + w_other
        scale = w_self / w_sum
        other_scale = w_other / w_sum

        for tx in tl.range(bdx):
            offset = row * head_dim + tx
            v_self = tl.load(v_ptr + offset)
            v_other = tl.load(v_other_ptr + offset)
            tl.store(v_ptr + offset, scale * v_self + other_scale * v_other)

        # The scalar is overwritten only after all values of this head were
        # blended with the weights derived from the old scalar.
        if s_ptr:
            tl.store(s_ptr + row, tl.log2(w_sum) + s_max)
|
|
|
|
|
|
@triton.jit
def merge_states_kernel(
    v_ptr,
    s_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_index_sets,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge ``num_index_sets`` ``(v, s)`` entries per position into one.

    Each program instance handles the index ``pos`` given by grid axis 0. For
    every ``(tx, head_idx)`` pair it folds the entries one by one into a
    running ``(o, m, d)`` state with ``state_merge``, normalizes the result
    with ``state_normalize`` and stores it.

    Args:
        v_ptr: pointer to the input values, addressed as
            ``((pos * num_index_sets + i) * num_heads + head_idx) * head_dim + tx``.
        s_ptr: pointer to the per-head input scalars, addressed as
            ``(pos * num_index_sets + i) * num_heads + head_idx``.
        v_merged_ptr: output pointer for merged values, addressed as
            ``(pos * num_heads + head_idx) * head_dim + tx``.
        s_merged_ptr: optional output pointer; when truthy it receives
            ``state_get_lse`` of the final state at ``pos * num_heads + head_idx``.
        num_index_sets: number of entries merged per position.
        num_heads: row stride multiplier used in the address computation.
        head_dim: element stride of one head used in the address computation.
        bdx: compile-time number of elements processed per head.
        bdy: compile-time number of heads processed.
    """
    pos = tl.program_id(axis=0)

    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Start state of the fold: zero value, exponent -5e4, denominator 1.
            acc_o = 0.0
            acc_m = -5e4
            acc_d = 1.0
            for set_idx in tl.range(num_index_sets):
                src_row = (pos * num_index_sets + set_idx) * num_heads + head_idx
                s = tl.load(s_ptr + src_row)
                v = tl.load(v_ptr + src_row * head_dim + tx)
                acc_o, acc_m, acc_d = state_merge(acc_o, acc_m, acc_d, v, s, 1)
            acc_o, acc_m, acc_d = state_normalize(acc_o, acc_m, acc_d)

            dst_row = pos * num_heads + head_idx
            tl.store(v_merged_ptr + dst_row * head_dim + tx, acc_o)
            if s_merged_ptr:
                tl.store(s_merged_ptr + dst_row, state_get_lse(acc_o, acc_m, acc_d))
|
|
|
|
|
|
@triton.jit
def variable_length_merge_states_kernel(
    v_ptr,
    s_ptr,
    indptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    """Merge a variable number of ``(v, s)`` rows per position into one.

    Each program instance handles the index ``pos`` given by grid axis 0 and
    merges the input rows ``indptr[pos] .. indptr[pos + 1] - 1``. For every
    ``(tx, head_idx)`` pair the rows are folded one by one into a running
    ``(o, m, d)`` state with ``state_merge``, normalized with
    ``state_normalize`` and stored.

    Args:
        v_ptr: pointer to the input values, addressed as
            ``(row * num_heads + head_idx) * head_dim + tx``.
        s_ptr: pointer to the per-head input scalars, addressed as
            ``row * num_heads + head_idx``.
        indptr: pointer to the row boundaries; entries ``pos`` and ``pos + 1``
            delimit the rows merged by this program.
        v_merged_ptr: output pointer for merged values, addressed as
            ``(pos * num_heads + head_idx) * head_dim + tx``.
        s_merged_ptr: optional output pointer; when truthy it receives
            ``state_get_lse`` of the final state at ``pos * num_heads + head_idx``.
        num_heads: row stride multiplier used in the address computation.
        head_dim: element stride of one head used in the address computation.
        bdx: compile-time number of elements processed per head.
        bdy: compile-time number of heads processed.
    """
    pos = tl.program_id(axis=0)
    # The row range only depends on pos, so read it once per program instead
    # of once per (tx, head_idx) pair.
    row_begin = tl.load(indptr + pos)
    row_end = tl.load(indptr + pos + 1)

    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Start state of the fold: zero value, exponent -5e4, denominator 1.
            o, m, d = 0.0, -5e4, 1.0
            for row in tl.range(row_begin, row_end):
                src_row = row * num_heads + head_idx
                s = tl.load(s_ptr + src_row)
                v = tl.load(v_ptr + src_row * head_dim + tx)
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)

            dst_row = pos * num_heads + head_idx
            tl.store(v_merged_ptr + dst_row * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(s_merged_ptr + dst_row, state_get_lse(o, m, d))
|