"""Tests comparing flashinfer.triton cascade merge kernels against the
flashinfer CUDA reference implementations."""

import pytest
import torch

import flashinfer
import flashinfer.triton
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
    """Triton merge_state must agree with the flashinfer reference kernel."""
    device = "cuda:0"
    # Two attention states to merge: fp16 values, fp32 log-sum-exp scores.
    v_a = torch.randn(seq_len, num_heads, head_dim).half().to(device)
    s_a = torch.randn(seq_len, num_heads, dtype=torch.float32).to(device)
    v_b = torch.randn(seq_len, num_heads, head_dim).half().to(device)
    s_b = torch.randn(seq_len, num_heads, dtype=torch.float32).to(device)

    v_out, s_out = flashinfer.triton.cascade.merge_state(v_a, s_a, v_b, s_b)
    v_ref, s_ref = flashinfer.merge_state(v_a, s_a, v_b, s_b)

    # Loose tolerance: fp16 accumulation differs slightly between kernels.
    assert torch.allclose(v_out, v_ref, atol=1e-2)
    assert torch.allclose(s_out, s_ref, atol=1e-2)
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state_in_place(seq_len, num_heads, head_dim):
    """In-place Triton merge must leave (v, s) equal to the reference result."""
    device = "cuda:0"
    # Draw the shared starting state on CPU, then give the reference and the
    # Triton kernel their own identical device copies to mutate.
    v = torch.randn(seq_len, num_heads, head_dim).half()
    v_ref = v.clone()
    v, v_ref = v.to(device), v_ref.to(device)
    s = torch.randn(seq_len, num_heads, dtype=torch.float32)
    s_ref = s.clone()
    s, s_ref = s.to(device), s_ref.to(device)

    # The "other" state is read-only for both kernels, so it is shared.
    v_other = torch.randn(seq_len, num_heads, head_dim).half().to(device)
    s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to(device)

    flashinfer.merge_state_in_place(v_ref, s_ref, v_other, s_other)
    flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other)

    assert torch.allclose(v, v_ref, atol=1e-2)
    assert torch.allclose(s, s_ref, atol=1e-2)
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("num_states", [100])
def test_merge_states(seq_len, num_states, num_heads, head_dim):
    """Batched Triton merge_states must agree with the flashinfer reference."""
    device = "cuda:0"
    # num_states attention states per position, merged down to one.
    v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to(device)
    s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to(device)

    v_ref, s_ref = flashinfer.merge_states(v, s)
    v_out, s_out = flashinfer.triton.cascade.merge_states(v, s)

    assert torch.allclose(v_out, v_ref, atol=1e-2)
    assert torch.allclose(s_out, s_ref, atol=1e-2)
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_variable_length_merge_states(seq_len, num_heads, head_dim):
    """Variable-length merge must match flashinfer.merge_states row by row."""
    device = "cuda:0"
    max_index_sets = 512
    # Each output row merges a random number (1..max_index_sets-1) of states.
    lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,))

    # CSR-style offsets: row i owns packed rows indptr[i]:indptr[i+1].
    indptr = [0]
    for length in lengths:
        indptr.append(indptr[-1] + length)

    total = indptr[-1]
    v = torch.randn(total, num_heads, head_dim).half().to(device)
    s = torch.randn(total, num_heads, dtype=torch.float32).to(device)
    indptr = torch.tensor(indptr, dtype=torch.int32).to(device)

    v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states(
        v, s, indptr
    )

    # Validate each row against the dense reference kernel applied to a
    # batch of one (unsqueeze to add the seq dim, squeeze it back off).
    for row in range(seq_len):
        begin, end = indptr[row], indptr[row + 1]
        v_ref, s_ref = flashinfer.merge_states(
            v[begin:end].unsqueeze(0), s[begin:end].unsqueeze(0)
        )
        v_ref = v_ref.squeeze(0)
        s_ref = s_ref.squeeze(0)
        assert v_merged[row].shape == v_ref.shape
        assert torch.allclose(v_merged[row], v_ref, atol=1e-2)
        assert torch.allclose(s_merged[row], s_ref, atol=1e-2)
|