import itertools
import unittest

import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from utils import precision

from sglang.test.test_utils import CustomTestCase

torch.manual_seed(1234)


class TestMLA(CustomTestCase):
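    """Correctness tests for torch.ops.sgl_kernel.decode_attention_cpu, checked
    against a per-request scaled_dot_product_attention reference (MLA-style
    grouped decode attention).
    """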

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        key: torch.Tensor,
        loc: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
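        """Reference decode attention: write `key` into the KV cache at `loc`,
        then run torch SDPA per request over that request's cached tokens
        (one query token per sequence) and store the result in `output`.
        """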
        # set kv cache
        k_cache[loc] = key

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
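        """Run the CPU kernel and the SDPA reference on identical inputs and
        compare the attention outputs and the updated KV caches.
        """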
        dtype = torch.bfloat16

        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype)

        # k_buffer and v_buffer represent all previous tokens
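        # (v_buffer is a narrowed view into k_buffer, so V shares storage with the
        # first D_V dims of K, i.e. the packed MLA-style KV cache layout.)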
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
        v_buffer = k_buffer.narrow(2, 0, D_V)

        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = key.narrow(2, 0, D_V)
        # make sure no duplicates in loc
        loc = torch.randperm(total_tokens)[:B].to(torch.int64)

        k_buffer2 = k_buffer.clone()
        v_buffer2 = k_buffer2.narrow(2, 0, D_V)

        # o has q's batch and head dims, but the value head dim D_V
        o = torch.zeros(B, H_Q, D_V, dtype=dtype)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)

        req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
        b_req_idx = torch.arange(B).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len).to(torch.int64)

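        # Per-split workspace for the kernel; the trailing D_V + 1 slot presumably
        # carries the log-sum-exp used to combine the num_kv_splits partial results.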
        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
        )

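        # Kernel under test: updates k_buffer2/v_buffer2 in place at loc and
        # writes the attention output into o.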
        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer2,
            v_buffer2,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

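        # SDPA reference: writes key into k_buffer at loc and fills o_grouped.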
        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            key,
            loc,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

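        # The cosine check guards against gross mismatch; the elementwise checks
        # require agreement within bf16 tolerance, including the updated KV caches.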
        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        atol = rtol = precision[q.dtype]
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)

    def test_grouped_decode_attention(self):
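        # Each tuple is (B, H_Q, H_KV, D, D_V, seq_len); D=576 / D_V=512 match the
        # MLA head layout exercised by this test.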
        configs = [
            (1, 22, 1, 576, 512, 8 * 111),
            (4, 22, 1, 576, 512, 8 * 128),
            (40, 22, 1, 576, 512, 8 * 133),
        ]

        for B, H_Q, H_KV, D, D_V, seqlen in configs:
            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)


if __name__ == "__main__":
    unittest.main()