# sglang.0.4.8.post1/sglang/test/srt/cpu/test_extend.py
#
# 190 lines
# 6.9 KiB
# Python

import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.test.test_utils import CustomTestCase
# Seed torch's global RNG at import time so the randomly generated test
# inputs below are reproducible from run to run.
torch.manual_seed(1234)
class TestExtendAttention(CustomTestCase):
    """Compare ``torch.ops.sgl_kernel.extend_attention_cpu`` with a torch SDPA reference."""

    def _run_sdpa_forward_extend(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        extend_prefix_lens: torch.Tensor,
        extend_seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ) -> torch.Tensor:
        """Compute reference extend attention one request at a time.

        For every request the extend queries are placed at the tail of a
        query buffer of length ``seq_lens[i]``, attention is evaluated with
        ``scaled_dot_product_attention`` against the request's full K/V
        (gathered from the caches), and only the tail rows are copied into
        ``output``.

        Args:
            query: Extend queries laid out as
                ``[num_tokens, num_heads, head_size]``, requests concatenated
                along dim 0.
            output: Destination tensor, written in place along dim 0 using
                the same token offsets as ``query``.
            k_cache: Key cache, indexed along dim 0 by cache slot.
            v_cache: Value cache, indexed along dim 0 by cache slot.
            req_to_token: 2-D table; row ``req_pool_indices[i]`` holds the
                cache slots of request ``i``'s tokens.
            req_pool_indices: Row of ``req_to_token`` to use for each request.
            seq_lens: Total (prefix + extend) length of each request.
            extend_prefix_lens: Prefix length of each request.
            extend_seq_lens: Number of extend (query) tokens of each request.
            scaling: Forwarded to SDPA as ``scale``.
            enable_gqa: Forwarded to SDPA as ``enable_gqa``.
            causal: Forwarded to SDPA as ``is_causal``.

        Returns:
            The ``output`` tensor that was passed in.
        """
        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
        assert seq_lens.shape[0] == extend_seq_lens.shape[0]

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q = 0
        for seq_idx in range(seq_lens.shape[0]):
            extend_seq_len_q = extend_seq_lens[seq_idx]
            prefill_seq_len_q = extend_prefix_lens[seq_idx]
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + extend_seq_len_q

            per_req_query = query[:, start_q:end_q, :]
            # Pad the query up to the full KV length. The leading prefix rows
            # stay uninitialized on purpose: only the rows from
            # ``prefill_seq_len_q`` onward are read back below.
            per_req_query_redundant = torch.empty(
                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
                dtype=per_req_query.dtype,
                device=per_req_query.device,
            )
            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query

            # Get key and value from cache. per_req_tokens contains the kv
            # cache index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            # SDPA runs on a batch of one; afterwards drop the batch dim and
            # move the token dim back to the front.
            per_req_out_redundant = (
                scaled_dot_product_attention(
                    per_req_query_redundant.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out_redundant[
                prefill_seq_len_q:, :, :
            ]
            start_q = end_q
        return output

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
        """Run one randomized comparison of the CPU kernel against the reference.

        Args:
            B: Number of requests in the batch.
            N_CTX: Length bound; prefix and extend lengths are each drawn
                from ``[1, N_CTX // 2)``.
            H_Q: Number of query heads.
            H_KV: Number of heads of ``k_extend`` / ``v_extend``.
            D: Head size of queries and keys.
            DV: Head size of values and of the output.
            mla: When true, all prefix lengths are set to zero and the K/V
                buffers are built with a single head instead of ``H_KV``.
        """
        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        if mla:
            b_seq_len_prefix.zero_()
        b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

        b_req_idx = torch.arange(B, dtype=torch.int32)
        req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
        # Start offset of every request in the full buffers and in the
        # packed extend tensors (exclusive prefix sums of the lengths).
        b_start_loc = torch.zeros((B,), dtype=torch.int32)
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        # Request i owns the consecutive buffer slots
        # [b_start_loc[i], b_start_loc[i] + b_seq_len[i]).
        for i in range(B):
            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()

        H_BUF = 1 if mla else H_KV
        k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
        v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
        v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)

        # The extend K/V are copies of the tail (non-prefix) part of each
        # request's span in the buffers; the queries are fresh random data.
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.randn(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype
            )

        # k_extend, v_extend, k_buffer and v_buffer support non-contiguous
        # tensors: store them head-major but keep the token-major view so
        # that path is exercised.
        k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
        v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)

        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0

        # handle index type
        b_req_idx = b_req_idx.to(torch.int64)
        b_seq_len = b_seq_len.to(torch.int64)

        enable_gqa = H_Q != H_KV
        o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        self._run_sdpa_forward_extend(
            q_extend,
            o_ref,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_prefix,
            b_seq_len_extend,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
            causal=True,
        )

        # The kernel is expected to write its result into o_extend.
        o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        torch.ops.sgl_kernel.extend_attention_cpu(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_extend,
            b_start_loc_extend,
            max_len_extend,
            sm_scale,
            logit_cap,
        )

        # assert_close takes (actual, expected): the kernel output is the
        # value under test and the SDPA result is the reference.
        torch.testing.assert_close(o_extend, o_ref, atol=1e-2, rtol=1e-2)

    def test_extend_attention(self):
        """Exercise several batch/head configurations, with and without ``mla``."""
        for is_mla in [True, False]:
            self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
            self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
            self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()