import unittest

import sgl_kernel
import torch
from utils import (
    convert_weight,
    native_w8a8_per_token_matmul,
    per_token_quant_int8,
    precision,
)

from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
from sglang.test.test_utils import CustomTestCase

convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight

torch.manual_seed(1234)

# constants
kv_lora_rank = 512
qk_head_dim = 192
qk_nope_head_dim = 128
qk_rope_head_dim = 64
rotary_dim = qk_rope_head_dim
num_heads = 22
q_lora_rank = 1536
hidden_size = 7168
B = 1
eps = 1e-6


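# Reference RMSNorm: computed in float32 and cast back to the input dtype.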
def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    return (x * weight).to(orig_dtype)


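# Reference rotary embedding: look up cos/sin at `pos` and rotate the rope
# slices of q and k in float32.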
def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    orig_dtype = q_pe.dtype
    q_pe = q_pe.float()
    k_pe = k_pe.float()
    cos_sin_cache = cos_sin_cache.float()

    query_rot = q_pe[..., :rotary_dim]
    key_rot = k_pe[..., :rotary_dim]
    cos_sin = cos_sin_cache[pos]
    cos, sin = cos_sin.chunk(2, dim=-1)
    query_rot = _apply_rotary_emb(query_rot, cos, sin, False)
    key_rot = _apply_rotary_emb(key_rot, cos, sin, False)
    return query_rot.to(orig_dtype), key_rot.to(orig_dtype)


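# Pure-PyTorch bf16 reference for qkv_proj_with_rope: q_a/q_b projections with
# RMSNorm, absorption of the nope part of q into w_kc via bmm, kv_a projection
# into the latent cache, and RoPE on the rope slices of q and k.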
def native_torch(
    q_input,
    hidden_states,
    q_a_proj_weight,
    norm_weight1,
    q_b_proj_weight,
    w_kc,
    kv_a_proj_weight,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    q = torch.matmul(hidden_states, q_a_proj_weight.t())
    q = layernorm(q, norm_weight1)
    q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim)

    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)

    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]

    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe

    return q_input, k_input, v_input


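# Same reference computation as native_torch, but every projection uses int8
# W8A8 matmuls: activations are quantized per token on the fly, while the
# weights arrive pre-quantized with per-row scales (w*_q / w*_s).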
def native_torch_int8(
    q_input,
    hidden_states,
    w1_q,
    w1_s,
    norm_weight1,
    w2_q,
    w2_s,
    w_kc,
    w3_q,
    w3_s,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    a_q, a_s = per_token_quant_int8(hidden_states)
    q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16)
    q = layernorm(q, norm_weight1)

    a_q, a_s = per_token_quant_int8(q)
    q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view(
        -1, num_heads, qk_head_dim
    )

    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)

    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    a_q, a_s = per_token_quant_int8(hidden_states)
    latent_cache = native_w8a8_per_token_matmul(
        a_q, w3_q, a_s, w3_s, None, torch.bfloat16
    )
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]

    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe

    return q_input, k_input, v_input


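# Each test builds random weights in the shapes above, computes a native
# reference, and checks both qkv_proj_with_rope and its fused-weight variant
# against it.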
class TestQKVProjWithROPE(CustomTestCase):
    def test_bf16_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc.transpose(1, 2),
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        qa_packed = convert_weight_packed(q_a_proj_weight)
        qb_packed = convert_weight_packed(q_b_proj_weight)
        kva_packed = convert_weight_packed(kv_a_proj_weight)
        wkc_packed = convert_weight_packed(w_kc)
        fused_weight_packed = convert_weight_packed(fused_weight)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            qa_packed,
            qb_packed,
            kva_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            None,
            True,
            None,
        )
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            qb_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

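    # int8 path: weights are quantized per output row ahead of time and the
    # kernel receives the packed int8 weights together with their scales.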
    def test_int8_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
        w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
        w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)
        q_ref, k_ref, v_ref = native_torch_int8(
            q_input,
            hidden_states,
            w1_q,
            w1_s,
            norm_weight1,
            w2_q,
            w2_s,
            w_kc.transpose(1, 2),
            w3_q,
            w3_s,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        w1_q_packed = convert_weight_packed(w1_q)
        w2_q_packed = convert_weight_packed(w2_q)
        w3_q_packed = convert_weight_packed(w3_q)
        wkc_packed = convert_weight_packed(w_kc)
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            w1_q_packed,
            w2_q_packed,
            w3_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            w1_s,
            w2_s,
            w3_s,
            True,
            None,
        )
        fused_weight = torch.cat([w1_q, w3_q], dim=0)
        fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
        w_fused_q_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            w_fused_q_packed,
            w2_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            fused_weight_s,
            w2_s,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

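    # fp8 path: convert_weight block-quantizes each weight on a 128x128 scale
    # grid and also returns a dequantized bf16 copy, which feeds the native
    # reference so both sides see the same effective weights.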
    def test_fp8_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        scale_block_size_N = 128
        scale_block_size_K = 128
        fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
            convert_weight(
                q_a_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
            convert_weight(
                q_b_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        (
            fp8_kv_a_proj_with_mqa_weight,
            kv_a_proj_with_mqa_weight_scale_inv,
            kv_a_proj_with_mqa_weight_dq,
        ) = convert_weight(
            kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16
        )
        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight_dq,
            norm_weight1,
            q_b_proj_weight_dq,
            w_kc.transpose(1, 2),
            kv_a_proj_with_mqa_weight_dq,
            norm_weight2,
            pos,
            cos_sin_cache,
        )
        fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
        fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
        fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
            fp8_kv_a_proj_with_mqa_weight
        )
        w_kc = convert_weight_packed(w_kc)
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            fp8_q_a_proj_weight_packed,
            fp8_q_b_proj_weight_packed,
            fp8_kv_a_proj_with_mqa_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            q_a_proj_weight_scale_inv.float(),
            q_b_proj_weight_scale_inv.float(),
            kv_a_proj_with_mqa_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
        )

        fused_weight = torch.cat(
            [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
        )
        fused_weight_s = torch.cat(
            [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
        )
        fused_weight_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            fp8_q_b_proj_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            fused_weight_s.float(),
            q_b_proj_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        # Due to the change in multiplication order, the error is amplified.
        # In the model, with fewer layers, this doesn't cause issues, but in
        # tests with more layers, we need to enlarge the tolerance to pass the tests.
        torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)


if __name__ == "__main__":
    unittest.main()