from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

import flashinfer


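# Weighted mean absolute percentage error, used below to compare outputs:
#   wmape = sum(|preds - target|) / sum(|target|)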
def wmape(target: torch.Tensor, preds: torch.Tensor):
    sum_abs_error = (preds - target).abs().sum().detach().item()
    sum_scale = target.abs().sum().detach().item()
    return sum_abs_error / sum_scale


# rope_reference provides precompute_freqs_cis() and apply_rotary_emb(), used below.
from rope_reference import *


class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)


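# Reference (non-absorbed) MLA decode. The notation below roughly follows the
# DeepSeek-V2 paper:
#   c_Q  = W^DQ h                                 (query down-projection, rank 1536)
#   q    = [W^UQ c_Q ; RoPE(W^QR c_Q)]            (per-head dim 128 + 64 = 192)
#   c_KV = W^DKV h                                (KV down-projection, rank 512; cached)
#   k    = [W^UK c_KV ; RoPE(k_pe)],  v = W^UV c_KV
#   out  = W^O concat_heads(softmax(q k^T / sqrt(192)) v)
# Here c_KV (compressed_kv_normed_cache) and k_pe (k_pe_cache) are assumed to be
# already cached, so only the up-projections and attention itself are computed.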
class DeepseekV2AttentionVanilla(nn.Module):
    def __init__(self):
        super().__init__()

        self.hidden_size = 5120
        self.num_heads = 128

        self.q_lora_rank = 1536
        self.qk_rope_head_dim = 64
        self.kv_lora_rank = 512
        self.v_head_dim = 128
        self.qk_nope_head_dim = 128
        self.q_head_dim = 192  # 192 = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.rope_theta = 10000

        # W^DQ ~ [5120, 1536]
        self.q_a_proj = nn.Linear(
            self.hidden_size,
            self.q_lora_rank,
            bias=False,
        )
        torch.nn.init.normal_(self.q_a_proj.weight)

        self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)

        # W^UQ & W^QR ~ [1536, 128*(128+64)]
        self.q_b_proj = nn.Linear(
            self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
        )
        torch.nn.init.normal_(self.q_b_proj.weight)

        # We don't need these modules, since we already have the cached k_pe and
        # compressed_kv tensors.
        # self.kv_a_proj_with_mqa = nn.Linear(  # W^DKV & W^KR ~ [5120, 512+64]
        #     self.hidden_size,
        #     self.kv_lora_rank + self.qk_rope_head_dim,
        #     bias=False,
        # )
        # self.kv_a_layernorm = DeepseekV2RMSNorm(self.kv_lora_rank)

        # W^UK & W^UV ~ [512, 128*(128+128)]
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )
        torch.nn.init.normal_(self.kv_b_proj.weight)

        # W^O ~ [128*128, 5120]
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
        )
        torch.nn.init.normal_(self.o_proj.weight)

        self.softmax_scale = self.q_head_dim ** (-0.5)

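    # Decode step for a single new token per sequence. The caches are assumed to hold,
    # per sequence, kv_len entries of the compressed KV (c_KV ~ [kv_len, 512], already
    # RMS-normalized) and of the decoupled RoPE key (k_pe ~ [kv_len, 64], pre-rotation).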
    def run_decode(
        self,
        hidden_states: torch.Tensor,
        compressed_kv_normed_cache: torch.Tensor,
        k_pe_cache: torch.Tensor,
    ) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.size()
        if q_len != 1:
            raise ValueError(
                f"Only support decode, but got hidden_states[{hidden_states.size()}]"
            )

        ckv_bsz, kv_len, ckv_dim = compressed_kv_normed_cache.size()
        if ckv_bsz != bsz or ckv_dim != self.kv_lora_rank:
            raise ValueError(
                f"Unexpected shape: compressed_kv_normed_cache[{compressed_kv_normed_cache.size()}]"
            )

        kpe_bsz, kpe_len, kpe_dim = k_pe_cache.size()
        if kpe_bsz != bsz or kpe_dim != self.qk_rope_head_dim or kv_len != kpe_len:
            raise ValueError(f"Unexpected shape: k_pe_cache[{k_pe_cache.size()}]")

        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # q_nope ~ [bsz, num_heads, q_len, 128]   q_pe ~ [bsz, num_heads, q_len, 64]
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        k_pe = k_pe_cache.view(bsz, kv_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(compressed_kv_normed_cache)
            .view(bsz, kv_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        # k_nope ~ [bsz, num_heads, kv_len, 128]   value_states ~ [bsz, num_heads, kv_len, 128]
        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        if k_nope.size() != (bsz, self.num_heads, kv_len, self.qk_nope_head_dim):
            raise ValueError(f"k_nope[{k_nope.size()}]")
        if value_states.size() != (bsz, self.num_heads, kv_len, self.v_head_dim):
            raise ValueError(f"value_states[{value_states.size()}]")

        # Apply RoPE to the decoupled parts. The single query is broadcast over all
        # kv_len positions and only the last one is kept, i.e. the new token is rotated
        # with position kv_len - 1.
        freqs_cis = precompute_freqs_cis(
            self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
        ).to(q_pe.device)
        q_pe, k_pe = apply_rotary_emb(
            q_pe.transpose(1, 2).repeat(1, kv_len, 1, 1),
            k_pe.transpose(1, 2),
            freqs_cis,
        )
        q_pe = q_pe[:, -1:, :, :].transpose(1, 2)
        k_pe = k_pe.transpose(1, 2)

        # Concat q_nope and q_pe to produce a new Q tensor with head_dim = 192
        query_states = q.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        # Concat k_nope and k_pe to produce a new K tensor with head_dim = 192
        key_states = k_pe.new_empty(bsz, self.num_heads, kv_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
        )

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        )

        output = self.o_proj(attn_output)

        return output


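# Decode with "matrix absorption": because q_nope^T k_nope
#   = (c_Q W_UQ) (c_KV W_UK)^T = c_Q (W_UQ W_UK^T) c_KV^T,
# the per-head product W_UQ W_UK^T can be precomputed (W_UQ_UK below), so attention
# scores are taken directly against the 512-dim compressed KV cache instead of
# materializing per-head 128-dim keys. Likewise, attention is applied to the compressed
# cache and W_UV is folded into the output projection (W_UV_O below), so values are
# never expanded either. The decoupled RoPE part (q_pe / k_pe, 64-dim) is handled
# separately and added to the scores before softmax.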
class DeepseekV2AttentionMatAbsorbDecode(nn.Module):
    def __init__(self, mla_vanilla: DeepseekV2AttentionVanilla):
        super().__init__()

        self.hidden_size = mla_vanilla.hidden_size  # 5120
        self.num_heads = mla_vanilla.num_heads  # 128

        self.q_lora_rank = mla_vanilla.q_lora_rank  # 1536
        self.qk_rope_head_dim = mla_vanilla.qk_rope_head_dim  # 64
        self.kv_lora_rank = mla_vanilla.kv_lora_rank  # 512
        self.v_head_dim = mla_vanilla.v_head_dim  # 128
        self.qk_nope_head_dim = mla_vanilla.qk_nope_head_dim  # 128
        self.q_head_dim = (
            mla_vanilla.q_head_dim
        )  # qk_nope_head_dim + qk_rope_head_dim # 128+64=192

        self.softmax_scale = mla_vanilla.softmax_scale

        self.rope_theta = mla_vanilla.rope_theta

        # W^DQ ~ [5120, 1536]
        self.W_DQ = mla_vanilla.q_a_proj.weight.transpose(0, 1)

        self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)

        # W_UQ ~ [1536, 128, 128]   W_QR ~ [1536, 128, 64]
        W_UQ, W_QR = torch.split(
            mla_vanilla.q_b_proj.weight.t().view(
                self.q_lora_rank, self.num_heads, self.q_head_dim
            ),
            [self.qk_nope_head_dim, self.qk_rope_head_dim],
            -1,
        )
        # W_QR ~ [1536, 128*64]
        self.W_QR = W_QR.reshape(
            self.q_lora_rank, self.num_heads * self.qk_rope_head_dim
        )

        # W_UK ~ [512, 128, 128]   W_UV ~ [512, 128, 128]
        W_UK, W_UV = torch.split(
            mla_vanilla.kv_b_proj.weight.t().view(
                self.kv_lora_rank,
                self.num_heads,
                self.qk_nope_head_dim + self.v_head_dim,
            ),
            [self.qk_nope_head_dim, self.v_head_dim],
            -1,
        )

        # Now we merge W_UQ and W_UK (absorb W_UK into W_UQ)
        # q~q_lora_rank  n~num_heads  d~qk_nope_head_dim  l~kv_lora_rank
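        # Equivalently, per head n: W_UQ_UK[:, n, :] = W_UQ[:, n, :] @ W_UK[:, n, :].T,
        # i.e. a [1536, 128] x [128, 512] matmul giving [1536, 512] for each head.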
        self.W_UQ_UK = torch.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten(
            start_dim=1
        )  # [1536, 65536]

        # W_O ~ [5120, 128, 128]
        W_O = mla_vanilla.o_proj.weight.view(
            self.hidden_size, self.num_heads, self.v_head_dim
        )

        # Merge W_UV and W_O (absorb W_UV into W_O)
        # l~kv_lora_rank  n~num_heads  d~v_head_dim  h~hidden_size
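        # Equivalently, per head n: W_UV_O[n] = W_UV[:, n, :] @ W_O[:, n, :].T,
        # i.e. a [512, 128] x [128, 5120] matmul giving [512, 5120] for each head,
        # stacked over heads into [128*512, 5120].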
        self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV, W_O).flatten(
            start_dim=0, end_dim=1
        )  # [65536, 5120]

    def run_proof_of_concept(
        self,
        hidden_states: torch.Tensor,
        compressed_kv_normed_cache: torch.Tensor,
        k_pe_cache: torch.Tensor,
        use_flashinfer_kernel: bool,
        convert_float16: bool,
    ) -> torch.Tensor:
        # hidden_states ~ [bsz, hidden_size] (already squeezed to a single decode token).
        bsz = hidden_states.size(0)
        kv_len = compressed_kv_normed_cache.size(1)

        c_Q = torch.matmul(hidden_states, self.W_DQ)
        # c_Q ~ [bsz, q_lora_rank=1536]
        c_Q = self.q_a_layernorm(c_Q)

        q_pe = torch.matmul(c_Q, self.W_QR)  # W_QR ~ [1536, 128*64]
        # q_pe ~ [bsz, 128, 64]
        q_pe = q_pe.reshape(bsz, self.num_heads, self.qk_rope_head_dim)

        q_nope = torch.matmul(c_Q, self.W_UQ_UK)  # W_UQ_UK ~ [1536, 128*512]
        # q_nope ~ [bsz, 128, 512]
        q_nope = q_nope.reshape(bsz, self.num_heads, self.kv_lora_rank)

        q_kv_dtype = torch.float16
        if convert_float16:
            q_nope = q_nope.to(q_kv_dtype)
            q_pe = q_pe.to(q_kv_dtype)
            compressed_kv_normed_cache = compressed_kv_normed_cache.to(q_kv_dtype)
            k_pe_cache = k_pe_cache.to(q_kv_dtype)

        if not use_flashinfer_kernel:
            # Pure-PyTorch reference path: apply RoPE, then compute attention directly
            # against the compressed KV cache.
            freqs_cis = precompute_freqs_cis(
                self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
            ).to(k_pe_cache.device)
            q_pe, k_pe_cache = apply_rotary_emb(
                q_pe.unsqueeze(1).repeat(1, kv_len, 1, 1),
                k_pe_cache.unsqueeze(2),
                freqs_cis,
            )
            q_pe = q_pe[:, -1:, :, :].squeeze(1)
            k_pe_cache = k_pe_cache.squeeze(2)

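            # Attention scores decompose into the decoupled-RoPE part and the absorbed
            # "nope" part: score = (q_pe . k_pe + q_nope . c_KV) * softmax_scale.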
            # attn_weights_pe ~ [bsz, 128, kv_len]
            attn_weights_pe = torch.matmul(
                q_pe,  # [bsz, num_heads, qk_rope_head_dim]
                k_pe_cache.transpose(1, 2),  # [bsz, qk_rope_head_dim, kv_len]
            )
            # attn_weights_nope ~ [bsz, 128, kv_len]
            attn_weights_nope = torch.matmul(
                q_nope,  # [bsz, 128, 512]
                compressed_kv_normed_cache.transpose(1, 2),  # [bsz, 512, kv_len]
            )

            attn_weights = (attn_weights_pe + attn_weights_nope) * self.softmax_scale

            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(q_nope.dtype)

            # attn_output ~ [bsz, 128, 512]
            attn_output = torch.matmul(
                attn_weights,  # [bsz, 128, kv_len]
                compressed_kv_normed_cache,  # [bsz, kv_len, 512]
            )

        else:
            print("Now use MLA decode kernel!\n")
            if kv_len % page_size != 0:
                raise ValueError(
                    "For simplicity, kv_len should be a multiple of page_size."
                )
            freqs_cis = precompute_freqs_cis(
                self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
            ).to(k_pe_cache.device)
            q_pe, k_pe_cache = apply_rotary_emb(
                q_pe.unsqueeze(1).repeat(1, kv_len, 1, 1),
                k_pe_cache.unsqueeze(2),
                freqs_cis,
            )
            q_pe = q_pe[:, -1:, :, :].squeeze(1).contiguous()
            k_pe_cache = k_pe_cache.squeeze(2)

            # page_size and dev_id come from module scope in this proof of concept.
            num_pages_per_seq = kv_len // page_size
            total_num_pages = num_pages_per_seq * bsz

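            # Page table for the paged KV cache. With the values used in __main__
            # (bsz=6, kv_len=640, page_size=16): 40 pages per sequence, 240 pages total,
            # kv_indptr = [0, 40, 80, ..., 240], kv_indices = [0, 1, ..., 239], and every
            # last page is full, so kv_last_page_len = 16 for each sequence.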
            kv_indptr = torch.arange(0, bsz + 1).to(dev_id).int() * num_pages_per_seq
            kv_indices = torch.arange(0, total_num_pages).to(dev_id).int()
            kv_last_page_len = torch.full((bsz,), page_size, dtype=torch.int32).to(
                dev_id
            )

            paged_ckv_cache = compressed_kv_normed_cache.reshape(
                total_num_pages, page_size, self.kv_lora_rank
            )
            paged_kpe_cache = k_pe_cache.reshape(
                total_num_pages, page_size, self.qk_rope_head_dim
            )

            workspace_buffer = torch.empty(64 * 1024 * 1024, dtype=torch.int8).to(
                dev_id
            )
            wrapper = flashinfer.BatchDecodeMlaWithPagedKVCacheWrapper(
                workspace_buffer,
                use_cuda_graph=True,
                use_tensor_cores=True,
                paged_kv_indptr_buffer=kv_indptr,
                paged_kv_indices_buffer=kv_indices,
                paged_kv_last_page_len_buffer=kv_last_page_len,
            )
            wrapper.plan(
                kv_indptr,
                kv_indices,
                kv_last_page_len,
                num_qo_heads=self.num_heads,
                head_dim_compressed_kv=self.kv_lora_rank,
                page_size=page_size,
                sm_scale=self.softmax_scale,
                rope_theta=self.rope_theta,
                data_type=q_kv_dtype,
                q_data_type=q_kv_dtype,
            )

            attn_output = wrapper.run(q_nope, q_pe, paged_ckv_cache, paged_kpe_cache)

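            # Standard CUDA-graph capture pattern: warm up the kernel a few times on a
            # side stream, capture a single replayable launch, then replay it. The
            # wrapper was created with use_cuda_graph=True so the paged-KV metadata
            # buffers passed to its constructor can stay at fixed addresses across
            # replays.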
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    o, lse = wrapper.run(
                        q_nope, q_pe, paged_ckv_cache, paged_kpe_cache, return_lse=True
                    )
            torch.cuda.current_stream().wait_stream(s)

            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                attn_output = wrapper.run(
                    q_nope, q_pe, paged_ckv_cache, paged_kpe_cache
                )
            g.replay()

        # output ~ [bsz, 5120]
        output = torch.matmul(
            attn_output.to(self.W_UV_O.dtype).reshape(
                bsz, self.num_heads * self.kv_lora_rank
            ),
            self.W_UV_O,
        )  # W_UV_O ~ [65536, 5120]

        return output


if __name__ == "__main__":
|
|
dev_id = 0
|
|
|
|
torch.manual_seed(666)
|
|
torch.set_grad_enabled(False)
|
|
|
|
mla_vanilla = DeepseekV2AttentionVanilla().cuda(device=dev_id)
|
|
|
|
bsz = 6
|
|
kv_len = 640
|
|
page_size = 16
|
|
|
|
hidden_states = torch.randn([bsz, 1, mla_vanilla.hidden_size]).to(dev_id)
|
|
compressed_kv_normed_cache = torch.randn(
|
|
[bsz, kv_len, mla_vanilla.kv_lora_rank]
|
|
).to(dev_id)
|
|
k_pe_cache = torch.randn([bsz, kv_len, mla_vanilla.qk_rope_head_dim]).to(dev_id)
|
|
|
|
    output_vanilla = mla_vanilla.run_decode(
        hidden_states, compressed_kv_normed_cache, k_pe_cache
    )

    mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode(mla_vanilla).cuda(device=dev_id)
    output_mat_absorbed_use_torch_f32 = mla_mat_absorb.run_proof_of_concept(
        hidden_states.squeeze(1),
        compressed_kv_normed_cache,
        k_pe_cache,
        use_flashinfer_kernel=False,
        convert_float16=False,
    )
    output_mat_absorbed_use_torch_f16 = mla_mat_absorb.run_proof_of_concept(
        hidden_states.squeeze(1),
        compressed_kv_normed_cache,
        k_pe_cache,
        use_flashinfer_kernel=False,
        convert_float16=True,
    )
    output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(
        hidden_states.squeeze(1),
        compressed_kv_normed_cache,
        k_pe_cache,
        use_flashinfer_kernel=True,
        convert_float16=True,
    )

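    # Compare each matrix-absorbed variant against the vanilla reference over the
    # flattened outputs, using cosine similarity, WMAPE, and MSE.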
    cos_use_torch_f32 = F.cosine_similarity(
        output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1), dim=0
    )
    print(f"cos_use_torch_f32 = {cos_use_torch_f32}")
    assert cos_use_torch_f32 > 0.99

    wmape_use_torch_f32 = wmape(
        output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1)
    )
    print(f"wmape_use_torch_f32 = {wmape_use_torch_f32}")
    assert wmape_use_torch_f32 < 0.02

    mse_use_torch_f32 = F.mse_loss(
        output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1)
    )
    print(f"mse_use_torch_f32 = {mse_use_torch_f32}\n")

    cos_use_torch_f16 = F.cosine_similarity(
        output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1), dim=0
    )
    print(f"cos_use_torch_f16 = {cos_use_torch_f16}")
    assert cos_use_torch_f16 > 0.99

    wmape_use_torch_f16 = wmape(
        output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1)
    )
    print(f"wmape_use_torch_f16 = {wmape_use_torch_f16}")
    assert wmape_use_torch_f16 < 0.03

    mse_use_torch_f16 = F.mse_loss(
        output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1)
    )
    print(f"mse_use_torch_f16 = {mse_use_torch_f16}\n")

    cos_use_flashinfer = F.cosine_similarity(
        output_vanilla.reshape(-1),
        output_mat_absorbed_use_flashinfer.reshape(-1),
        dim=0,
    )
    print(f"cos_use_flashinfer = {cos_use_flashinfer}")
    assert cos_use_flashinfer > 0.99

    wmape_use_flashinfer = wmape(
        output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1)
    )
    print(f"wmape_use_flashinfer = {wmape_use_flashinfer}")
    assert wmape_use_flashinfer < 0.02

    mse_use_flashinfer = F.mse_loss(
        output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1)
    )
    print(f"mse_use_flashinfer = {mse_use_flashinfer}")