153 lines
3.8 KiB
Python
153 lines
3.8 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from .kernels.cascade import (
|
|
merge_state_in_place_kernel,
|
|
merge_state_kernel,
|
|
merge_states_kernel,
|
|
variable_length_merge_states_kernel,
|
|
)
|
|
from .utils import check_device, check_dim, check_input, check_shape
|
|
|
|
|
|
def merge_state(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    """Merge the two states ``(v_a, s_a)`` and ``(v_b, s_b)`` into new tensors.

    This wrapper only validates inputs, allocates the outputs and launches
    ``merge_state_kernel`` with one program per index of dimension 0. The
    merge arithmetic lives in the kernel. Presumably ``s_*`` are log-sum-exp
    scores of partial attention outputs ``v_*``; confirm against the kernel.

    Args:
        v_a: 3-D tensor, indexed here as ``(seq_len, num_heads, head_dim)``.
        s_a: 2-D tensor ``(seq_len, num_heads)``; cast to float32 before launch.
        v_b: tensor with the same shape as ``v_a``.
        s_b: tensor with the same shape as ``s_a``; cast to float32 before launch.

    Returns:
        ``(v_merged, s_merged)``. ``v_merged`` has the shape and dtype of
        ``v_a``. ``s_merged`` is a float32 ``(seq_len, num_heads)`` tensor.
        Both are allocated on ``s_a``'s device.

    Raises:
        AssertionError: if the leading dimensions of ``v_a`` and ``s_a``
            disagree. The ``check_*`` helpers may raise on other invalid
            inputs.
    """
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    # Compare against s_a, matching the assert above. check_shape(s_a, s_b)
    # already guarantees that s_b has the same sizes.
    assert v_a.size(1) == s_a.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len, num_heads, head_dim = v_a.shape
    # Allocate both outputs directly on the target device. For s_merged, use an
    # explicit float32 dtype. This avoids creating a host tensor with the global
    # default dtype and then copying it over.
    v_merged = torch.empty_like(v_a, device=s_a.device)
    s_merged = torch.empty(
        (seq_len, num_heads), dtype=torch.float32, device=s_a.device
    )
    bdx = head_dim
    bdy = num_heads

    # 1-D launch grid with seq_len programs, the same form the other wrappers use.
    merge_state_kernel[(seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )

    return v_merged, s_merged
|
|
|
|
|
|
def merge_state_in_place(
    v: torch.Tensor,
    s: torch.Tensor,
    v_other: torch.Tensor,
    s_other: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """Merge the state ``(v_other, s_other)`` into ``(v, s)``; returns nothing.

    No output tensors are allocated. ``merge_state_in_place_kernel`` receives
    ``v`` and ``s`` directly, so per the function name the results are
    presumably written back into them. Confirm this in the kernel.

    Args:
        v: 3-D tensor, indexed here as ``(seq_len, num_heads, head_dim)``.
        s: 2-D float32 tensor ``(seq_len, num_heads)``.
        v_other: tensor with the same shape as ``v``.
        s_other: tensor with the same shape as ``s``; must be float32.
        mask: optional 1-D tensor with ``seq_len`` entries, on ``v``'s device.
            It is forwarded to the kernel unchanged, including when it is
            ``None``. Presumably it selects which rows are merged; confirm in
            the kernel.

    Raises:
        AssertionError: on size, dtype or mask-device mismatches. The
            ``check_*`` helpers may raise on other invalid inputs.
    """
    for tensor in (v, s, v_other, s_other):
        check_input(tensor)
    check_device([v, s, v_other, s_other])
    for ndim, tensor in ((3, v), (2, s), (3, v_other), (2, s_other)):
        check_dim(ndim, tensor)
    check_shape(v, v_other)
    check_shape(s, s_other)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    # Scores are not cast here, unlike the other wrappers: both must already be float32.
    assert s.dtype == torch.float32
    assert s_other.dtype == torch.float32
    if mask is not None:
        check_dim(1, mask)
        assert v.size(0) == mask.size(0)
        assert mask.device == v.device

    seq_len, num_heads, head_dim = v.shape[:3]
    # One program per index of dim 0. bdx/bdy are forwarded as head_dim/num_heads.
    merge_state_in_place_kernel[(seq_len,)](
        v,
        s,
        v_other,
        s_other,
        num_heads,
        head_dim,
        mask,
        bdx=head_dim,
        bdy=num_heads,
    )
|
|
|
|
|
|
def merge_states(v: torch.Tensor, s: torch.Tensor):
    """Merge ``num_index_sets`` states per index of dim 0 into a single state.

    This wrapper validates inputs, allocates the outputs and launches
    ``merge_states_kernel`` with one program per index of dimension 0. The
    merge arithmetic lives in the kernel.

    Args:
        v: 4-D tensor, indexed here as
            ``(seq_len, num_index_sets, num_heads, head_dim)``.
        s: 3-D tensor ``(seq_len, num_index_sets, num_heads)``; cast to
            float32 before launch.

    Returns:
        ``(v_merged, s_merged)``. ``v_merged`` has shape
        ``(seq_len, num_heads, head_dim)`` and ``v``'s dtype and device.
        ``s_merged`` is a float32 ``(seq_len, num_heads)`` tensor on ``s``'s
        device.
    """
    for tensor in (v, s):
        check_input(tensor)
    check_device([v, s])
    check_dim(4, v)
    check_dim(3, s)
    # The leading three axes of v and s must line up.
    for axis in range(3):
        assert v.size(axis) == s.size(axis)

    seq_len, num_index_sets, num_heads, head_dim = v.shape[:4]
    s = s.to(torch.float32)
    # new_empty keeps the source tensor's dtype and device.
    v_merged = v.new_empty((seq_len, num_heads, head_dim))
    s_merged = s.new_empty((seq_len, num_heads))

    merge_states_kernel[(seq_len,)](
        v,
        s,
        v_merged,
        s_merged,
        num_index_sets,
        num_heads,
        head_dim,
        bdx=head_dim,
        bdy=num_heads,
    )
    return v_merged, s_merged
|
|
|
|
|
|
def variable_length_merge_states(
    v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor
):
    """Merge groups of states whose boundaries are given by ``indptr``.

    One program is launched per output row, with ``indptr.size(0) - 1`` rows
    in total. Presumably row ``i`` merges input rows
    ``indptr[i]:indptr[i + 1]`` (CSR-style offsets); confirm this in
    ``variable_length_merge_states_kernel``.

    Args:
        v: 3-D tensor, indexed here as ``(rows, num_heads, head_dim)``.
        s: 2-D tensor ``(rows, num_heads)``; cast to float32 before launch.
        indptr: 1-D tensor with ``seq_len + 1`` entries. It is cast to int32
            and moved to ``v``'s device before launch, so a host-side
            ``indptr`` is accepted.

    Returns:
        ``(v_merged, s_merged)``. ``v_merged`` has shape
        ``(seq_len, num_heads, head_dim)`` and ``v``'s dtype.
        ``s_merged`` is a float32 ``(seq_len, num_heads)`` tensor.

    Raises:
        AssertionError: if the leading dims of ``v`` and ``s`` disagree or
            ``indptr`` is empty. The ``check_*`` helpers may raise on other
            invalid inputs.
    """
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(3, v)
    check_dim(2, s)
    check_dim(1, indptr)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    # With no entries, seq_len would be -1, and torch.empty below would fail
    # with an unclear "negative dimension" error.
    assert indptr.size(0) >= 1, "indptr must contain at least one offset"
    seq_len = indptr.size(0) - 1
    num_heads = v.size(1)
    head_dim = v.size(2)
    s = s.to(torch.float32)
    # check_device above does not cover indptr. Move it next to v so the kernel
    # never receives a host tensor. No stride is passed to the kernel, so it
    # presumably indexes indptr densely; make it contiguous to match.
    indptr = indptr.to(device=v.device, dtype=torch.int32).contiguous()
    v_merged = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

    bdx = head_dim
    bdy = num_heads
    variable_length_merge_states_kernel[(seq_len,)](
        v,
        s,
        indptr,
        v_merged,
        s_merged,
        num_heads,
        head_dim,
        bdx=bdx,
        bdy=bdy,
    )
    return v_merged, s_merged
|