1512 lines
57 KiB
Python
1512 lines
57 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
# Adapted from:
|
|
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
|
|
"""Inference-only DeepseekV2 model."""
|
|
|
|
import os
|
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from sglang.srt.distributed import (
|
|
get_tensor_model_parallel_world_size,
|
|
parallel_state,
|
|
tensor_model_parallel_all_reduce,
|
|
)
|
|
from sglang.srt.layers.activation import SiluAndMul
|
|
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
|
decode_attention_fwd_grouped_rope,
|
|
)
|
|
from sglang.srt.layers.dp_attention import (
|
|
dp_gather_partial,
|
|
dp_scatter,
|
|
get_attention_dp_size,
|
|
get_attention_tp_rank,
|
|
get_attention_tp_size,
|
|
tp_all_gather,
|
|
tp_reduce_scatter,
|
|
)
|
|
from sglang.srt.layers.layernorm import RMSNorm
|
|
from sglang.srt.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
from sglang.srt.layers.moe.topk import select_experts
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.quantization.fp8_utils import (
|
|
block_quant_to_tensor_quant,
|
|
input_to_float8,
|
|
normalize_e4m3fn_to_e4m3fnuz,
|
|
)
|
|
from sglang.srt.layers.quantization.int8_utils import (
|
|
block_dequant as int8_block_dequant,
|
|
)
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
|
|
|
_is_hip = is_hip()
_is_cuda = is_cuda()

# Platform-specific kernels: CUDA builds take awq_dequantize / bmm_fp8 from
# sgl_kernel; every other platform imports vLLM's custom ops instead.
if not _is_cuda:
    from vllm import _custom_ops as ops
else:
    from sgl_kernel import awq_dequantize, bmm_fp8

# Module-level recorder instance shared by this file.
expert_distribution_recorder = ExpertDistributionRecorder()
|
|
|
|
|
|
class DeepseekV2MLP(nn.Module):
    """Dense gated feed-forward block used for dense layers and shared experts.

    The input goes through a merged column-parallel ``gate_up_proj`` (two
    ``intermediate_size`` outputs), then ``SiluAndMul``, then a row-parallel
    ``down_proj`` back to ``hidden_size``. Only the ``"silu"`` activation is
    accepted.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name; anything other than ``"silu"`` raises.
        quant_config: Optional quantization config forwarded to both linears.
        reduce_results: Forwarded to ``down_proj``'s ``reduce_results``.
        prefix: Parameter-name prefix for weight loading.
        tp_rank: Optional tensor-parallel rank override for both linears.
        tp_size: Optional tensor-parallel size override for both linears.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Both projections share the same quantization and TP overrides.
        shared_linear_kwargs = dict(
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            prefix=add_prefix("gate_up_proj", prefix),
            **shared_linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **shared_linear_kwargs,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiluAndMul, then the down projection."""
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class MoEGate(nn.Module):
    """Router producing one logit per routed expert for every token.

    Holds a ``(n_routed_experts, hidden_size)`` weight. When
    ``config.topk_method == "noaux_tc"`` it also holds a per-expert
    ``e_score_correction_bias`` parameter; otherwise that attribute is ``None``.
    The bias is not applied in ``forward``; it is only exposed for consumers
    elsewhere.

    Args:
        config: Model config providing ``n_routed_experts``, ``hidden_size``
            and ``topk_method``.
        prefix: Parameter-name prefix (accepted for API symmetry; unused here).
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        num_experts = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))
        has_correction_bias = config.topk_method == "noaux_tc"
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty((num_experts)))
            if has_correction_bias
            else None
        )

    def forward(self, hidden_states):
        """Return router logits ``hidden_states @ weight.T`` (no bias added)."""
        return F.linear(hidden_states, self.weight, None)
|
|
|
|
|
|
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 mixture-of-experts block: routed experts plus optional shared experts.

    The routed-expert implementation is chosen from ``global_server_args_dict``:
    ``DeepEPMoE`` when ``enable_deepep_moe`` is set, else ``EPMoE`` when
    ``enable_ep_moe`` is set, else ``FusedMoE``. With DeepEP enabled the
    forward pass goes through ``forward_deepep`` (explicit dispatch/combine);
    otherwise it goes through ``forward_normal`` (followed by a TP all-reduce).

    Args:
        config: Model config (``n_routed_experts``, ``num_experts_per_tok``,
            ``moe_intermediate_size``, ``n_shared_experts``, ...).
        quant_config: Optional quantization config for experts.
        prefix: Parameter-name prefix for weight loading.

    Raises:
        ValueError: If the TP world size exceeds the number of routed experts,
            or if ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
        if enable_deepep_moe:
            MoEImpl = DeepEPMoE
        elif global_server_args_dict["enable_ep_moe"]:
            MoEImpl = EPMoE
        else:
            MoEImpl = FusedMoE

        self.experts = MoEImpl(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            prefix=add_prefix("experts", prefix),
        )

        if config.n_shared_experts is not None:
            # With DeepEP, TP is disabled for the shared experts
            # (tp_rank=0, tp_size=1); otherwise the MLP's default TP applies.
            shared_tp_kwargs = (
                {"tp_rank": 0, "tp_size": 1} if enable_deepep_moe else {}
            )
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **shared_tp_kwargs,
            )

        if enable_deepep_moe:
            # Routing parameters needed to call select_experts directly in
            # forward_deepep (the non-DeepEP path lets the experts route).
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                async_finish=True,  # TODO
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
    ) -> torch.Tensor:
        """Dispatch to ``forward_normal`` or ``forward_deepep`` per server args."""
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_mode)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Routed experts (scaled) + shared experts, all-reduced across TP ranks."""
        # Must be initialised: without shared experts it stays None rather than
        # being read unbound below.
        shared_output = None
        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
    ) -> torch.Tensor:
        """DeepEP path: route locally, dispatch tokens, run experts, combine.

        Idle / empty batches skip routing and dispatch zero-row top-k tensors
        so every rank still participates in dispatch/combine.
        """
        shared_output = None
        topk_idx = torch.full(
            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
        )
        topk_weights = torch.empty(
            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
        )
        if (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        ):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            if self.n_shared_experts is not None:
                shared_output = self.shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
            )
        # NOTE(review): with tp_size == 1 no dispatch happens, and
        # recv_hidden_states / reorder_topk_ids / seg_indptr below are unbound.
        # DeepEP presumably always runs with tp_size > 1 — confirm.
        if self.tp_size > 1:
            recv_hidden_states, reorder_topk_ids, seg_indptr = (
                self.deepep_dispatcher.dispatch(
                    hidden_states,
                    topk_idx,
                    topk_weights,
                    self.num_experts,
                    forward_mode,
                )
            )
        final_hidden_states = (
            self.experts(
                hidden_states=recv_hidden_states,
                reorder_topk_ids=reorder_topk_ids,
                seg_indptr=seg_indptr,
                forward_mode=forward_mode,
            )
            * self.routed_scaling_factor
        )
        if self.tp_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                final_hidden_states, forward_mode
            )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states
|
|
|
|
|
|
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude scale for a RoPE ``scale`` factor.

    For ``scale > 1`` this is ``0.1 * mscale * ln(scale) + 1``; for any
    ``scale <= 1`` (no context extension) it is exactly ``1.0``.
    """
    import math

    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0
|
|
|
|
|
|
class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2 attention without latent absorption (used when MLA is disabled).

    K/V are fully expanded through ``kv_b_proj``. Q and K heads (width
    ``qk_head_dim``) and V heads (width ``v_head_dim``) are zero-padded to
    ``_PADDED_HEAD_DIM`` so a single ``RadixAttention`` head size can be used.
    The padding is sliced off the output before ``o_proj``.

    All head-sharded projections (``q_b_proj``/``q_proj``, ``kv_b_proj``,
    ``o_proj``) use the attention-TP rank/size, matching ``num_local_heads``.

    Raises:
        ValueError: If ``qk_head_dim`` or ``v_head_dim`` exceeds
            ``_PADDED_HEAD_DIM``. A negative pad would otherwise silently crop
            the heads.
    """

    # Width every q/k/v head is zero-padded to before RadixAttention.
    # TODO, support head_size 192
    _PADDED_HEAD_DIM = 256

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        if max(self.qk_head_dim, self.v_head_dim) > self._PADDED_HEAD_DIM:
            raise ValueError(
                f"qk_head_dim ({self.qk_head_dim}) and v_head_dim "
                f"({self.v_head_dim}) must not exceed {self._PADDED_HEAD_DIM}."
            )

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            # Shard over the attention-TP group so the output splits into
            # exactly num_local_heads heads (as q_proj / o_proj already do).
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            reduce_results=reduce_results,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        # Guard: rope_scaling may be None (no YaRN); indexing it would raise.
        # NOTE: this mutates the caller's dict in place, as before.
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = RadixAttention(
            self.num_local_heads,
            self._PADDED_HEAD_DIM,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` and apply ``o_proj``.

        An empty batch is returned unchanged. This is only allowed when
        ``o_proj`` does not all-reduce, because skipping a collective would
        hang the other ranks.
        """
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # The rope part of K is a single head (from the latent), broadcast into
        # every local head below.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Zero-pad every head to the fixed attention head width.
        padded = self._PADDED_HEAD_DIM
        q = F.pad(q, [0, padded - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * padded
        )
        k = F.pad(k, [0, padded - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * padded
        )
        v = F.pad(v, [0, padded - self.v_head_dim], value=0).view(
            -1, self.num_local_heads * padded
        )
        attn_output = self.attn(q, k, v, forward_batch)
        # Drop the padding columns again before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, padded)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class DeepseekV2AttentionMLA(nn.Module):
    """DeepSeek-V2 multi-head latent attention (MLA).

    ``no_absorb`` picks one of two strategies per batch:

    * ``forward_normal``: expand the latent KV through ``kv_b_proj`` and run
      regular multi-head attention (``attn_mha``). Only the compressed latent
      (normalised ``kv_a`` plus rotated ``k_pe``) is written to the KV cache.
    * ``forward_absorb`` / ``forward_absorb_fused_mla_rope``: multiply the
      query's no-rope part by ``w_kc`` and the attention output by ``w_vc``, and
      run single-KV-head attention (``attn_mqa``) directly over the latent
      cache.

    ``w_kc``, ``w_vc`` and ``w_scale`` are ``None`` after construction. They are
    presumably filled in at weight-loading time (outside this class — confirm),
    and the absorbed paths require them.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        # NOTE: mutates the caller's rope_scaling dict in place.
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            # Without rope scaling, force the rotary embedding onto its native
            # (pure-PyTorch) forward.
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Absorbed path: one shared KV head of width kv_lora_rank + rope dim,
        # values are the first kv_lora_rank columns of the latent.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Non-absorbed path: ordinary per-head attention.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
        """Return True when this batch should use ``forward_normal``.

        This holds only for plain extend batches (not target-verify or
        draft-extend) with no cached prefix. For FlashInfer MLA it additionally
        requires that ragged prefill is not disabled. The ``fa3`` backend
        always absorbs.
        """
        if self.enable_flashinfer_mla:
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            return (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )
        elif self.attention_backend == "fa3":
            # Flash Attention: Keep absorbing for all extend/decode
            return False
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            return (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run attention for ``hidden_states`` using the strategy for this batch.

        An empty batch is returned unchanged. This is only allowed when
        ``o_proj`` does not all-reduce, because skipping a collective would
        hang the other ranks.
        """
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.no_absorb(forward_batch):
            return self.forward_normal(positions, hidden_states, forward_batch)
        # The fused decode kernel is ROCm-only and opt-in via env var.
        if (
            _is_hip
            and self.rocm_fused_decode_mla
            and forward_batch.forward_mode.is_decode()
        ):
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        return self.forward_absorb(positions, hidden_states, forward_batch)

    # ------------------------------------------------------------------ helpers

    def _compute_q(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project to queries viewed as (tokens, num_local_heads, qk_head_dim)."""
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0]
        else:
            q = self.q_proj(hidden_states)[0]
        return q.view(-1, self.num_local_heads, self.qk_head_dim)

    def _bmm_absorbed_weight(
        self, x: torch.Tensor, weight: torch.Tensor
    ) -> torch.Tensor:
        """Return ``bmm(x.transpose(0, 1), weight)``, handling fp8 weight dtypes.

        The result is head-major (``x``'s first two dims swapped). For
        ``float8_e4m3fnuz`` weights, the weight is upcast to bfloat16 and
        multiplied by ``w_scale``. For ``float8_e4m3fn`` weights, ``x`` is
        quantized and ``bmm_fp8`` is used with a bfloat16 output.
        """
        if weight.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            return torch.bmm(
                x.to(torch.bfloat16).transpose(0, 1),
                weight.to(torch.bfloat16) * self.w_scale,
            )
        if weight.dtype == torch.float8_e4m3fn:
            x_val, x_scale = input_to_float8(x.transpose(0, 1), torch.float8_e4m3fn)
            return bmm_fp8(x_val, weight, x_scale, self.w_scale, torch.bfloat16)
        return torch.bmm(x.transpose(0, 1), weight)

    def _prepare_absorbed_q(self, hidden_states: torch.Tensor):
        """Build the MQA query buffer with its ``[:kv_lora_rank]`` part filled.

        Returns ``(q_input, q_pe, q)``. ``q_input`` has shape
        (tokens, num_local_heads, kv_lora_rank + qk_rope_head_dim); its rope
        slice is still uninitialised. ``q_pe`` is the un-rotated rope part of
        ``q``, and ``q`` is the full projected query.
        """
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        q = self._compute_q(hidden_states)
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_input[..., : self.kv_lora_rank] = self._bmm_absorbed_weight(
            q_nope, self.w_kc
        ).transpose(0, 1)
        return q_input, q_pe, q

    def _latent_kv_inputs(self, hidden_states: torch.Tensor):
        """Return ``(k_input, v_input)`` built from the compressed latent.

        ``k_input`` is (tokens, 1, kv_lora_rank + qk_rope_head_dim), with its
        first ``kv_lora_rank`` columns overwritten by the layer-normed latent.
        Its rope columns are still un-rotated. ``v_input`` is that normed latent,
        shaped (tokens, 1, kv_lora_rank).
        """
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        return k_input, v_input

    def _project_absorbed_output(self, attn_output: torch.Tensor) -> torch.Tensor:
        """Map latent-space attention output through ``w_vc`` and ``o_proj``."""
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
        attn_bmm_output = self._bmm_absorbed_weight(attn_output, self.w_vc)
        output, _ = self.o_proj(attn_bmm_output.transpose(0, 1).flatten(1, 2))
        return output

    # ------------------------------------------------------------ forward paths

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Non-absorbed path: expand K/V per head and run ``attn_mha``.

        The compressed latent is saved to the KV cache explicitly, and
        ``attn_mha`` is told not to save K/V itself.
        """
        q = self._compute_q(hidden_states)
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Latent cache = [normalised kv_a | rotated k_pe].
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Absorbed path: single-KV-head attention over the latent via ``attn_mqa``."""
        q_input, q_pe, _ = self._prepare_absorbed_q(hidden_states)
        k_input, v_input = self._latent_kv_inputs(hidden_states)

        k_pe = k_input[..., self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        return self._project_absorbed_output(attn_output)

    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Absorbed decode path using ``decode_attention_fwd_grouped_rope``.

        When ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` is ``"1"`` (the default),
        RoPE is applied inside the kernel. The kernel writes the rotated
        ``k_pe`` to ``k_pe_output``, which is then written back into the KV
        cache. Otherwise RoPE is applied up front and the kernel runs with
        ``use_rope=False``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input, q_pe, q = self._prepare_absorbed_q(hidden_states)
        k_input, _ = self._latent_kv_inputs(hidden_states)

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            # Kernel fills this with the rotated k_pe.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
        # Un-rotated when fusing (the kernel rotates it), rotated otherwise.
        q_input[..., self.kv_lora_rank :] = q_pe

        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # Scratch buffer for split-KV partial results when the backend did
            # not provide one.
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Re-save the cache entry now that k_pe has been rotated.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        return self._project_absorbed_output(attn_output)
|
|
|
|
|
|
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepseekV2 transformer block: self-attention followed by an MLP.

    The attention is ``DeepseekV2AttentionMLA`` unless the server runs with
    ``disable_mla``, in which case ``DeepseekV2Attention`` is used. The MLP is
    ``DeepseekV2MoE`` for NextN layers and for layers selected by the
    ``first_k_dense_replace`` / ``moe_layer_freq`` config, and a dense
    ``DeepseekV2MLP`` otherwise.

    ``forward`` returns a ``(hidden_states, residual)`` pair that is fed to the
    next layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
    ) -> None:

        def is_sparse_layer(layer_idx: int) -> bool:
            # MoE layer iff routed experts exist, the layer is past the first
            # `first_k_dense_replace` dense layers, and it falls on the
            # `moe_layer_freq` stride.
            return (
                config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0
            )

        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        # Both attention implementations take exactly the same constructor
        # arguments; only the class differs.
        attn_cls = (
            DeepseekV2Attention
            if global_server_args_dict["disable_mla"]
            else DeepseekV2AttentionMLA
        )
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
        )

        if is_nextn or is_sparse_layer(layer_id):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = True
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = False

        # When the previous layer ran through `forward_deepep`, its outputs are
        # left split across the attention-TP ranks (see the reduce-scatter in
        # `forward_deepep`), so this layer must gather them first.
        self.input_is_scattered = (
            is_sparse_layer(layer_id - 1)
            and global_server_args_dict["enable_deepep_moe"]
        )
        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer; DeepEP path for sparse layers when DeepEP MoE is on.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
            return self.forward_deepep(
                positions, hidden_states, forward_batch, residual
            )
        else:
            return self.forward_normal(
                positions, hidden_states, forward_batch, residual
            )

    def _attn_tp_all_gather(
        self,
        local_tensor: torch.Tensor,
        forward_batch: ForwardBatch,
        into_shared_buffer: bool = True,
    ) -> torch.Tensor:
        """All-gather ``local_tensor`` across the attention-TP group.

        The output has ``forward_batch.input_ids.shape[0]`` rows; each rank
        contributes one ``tensor_split`` chunk. When ``into_shared_buffer`` is
        True the output is a view into ``forward_batch.gathered_buffer``;
        otherwise a fresh tensor of the same shape/dtype/device is allocated,
        which is required whenever another tensor already lives in that
        buffer slice.
        """
        num_tokens = forward_batch.input_ids.shape[0]
        out = forward_batch.gathered_buffer[:num_tokens]
        if not into_shared_buffer:
            out = torch.empty_like(out)
        tp_all_gather(list(out.tensor_split(self.attn_tp_size)), local_tensor)
        return out

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard (non-DeepEP) forward: attention, TP/DP collectives, MLP.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        if self.attn_tp_size != 1 and self.input_is_scattered:
            # The residual must NOT share `gathered_buffer` with the hidden
            # states: gathering both into the same slice would overwrite the
            # hidden states with the residual.
            hidden_states = self._attn_tp_all_gather(hidden_states, forward_batch)
            residual = self._attn_tp_all_gather(
                residual, forward_batch, into_shared_buffer=False
            )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                # Add the residual on one rank only so that it is counted once
                # in the cross-rank combination below.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return hidden_states, residual

    def forward_deepep(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """DeepEP forward for sparse layers.

        After attention the activations are reduce-scattered across the
        attention-TP group and stay scattered through the MoE; only the last
        layer gathers them back.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        if self.attn_tp_size != 1 and self.input_is_scattered:
            # Attention needs all tokens; the residual stays scattered because
            # it is only combined again after the reduce-scatter below.
            hidden_states = self._attn_tp_all_gather(hidden_states, forward_batch)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        if self.attn_tp_size != 1:
            if self.input_is_scattered:
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                if hidden_states.shape[0] != 0:
                    hidden_states, residual = self.post_attention_layernorm(
                        hidden_states, residual
                    )
            else:
                # Add the (unscattered) residual on one rank only so it is
                # counted once by the reduce-scatter; the scattered result
                # then becomes the new residual.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                residual = hidden_states
                if hidden_states.shape[0] != 0:
                    hidden_states = self.post_attention_layernorm(hidden_states)
        else:
            if hidden_states.shape[0] != 0:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        if self.is_last_layer and self.attn_tp_size != 1:
            # Give the residual its own buffer; gathering it into
            # `gathered_buffer` too would clobber the gathered hidden states.
            hidden_states = self._attn_tp_all_gather(hidden_states, forward_batch)
            residual = self._attn_tp_all_gather(
                residual, forward_batch, into_shared_buffer=False
            )

        return hidden_states, residual
|
|
|
|
|
|
class DeepseekV2Model(nn.Module):
    """DeepseekV2 backbone: token embedding, decoder layers, final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size

        # With DP attention every rank keeps the full embedding table.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_attention_dp_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed (unless ``input_embeds`` is given) and run all decoder layers.

        Returns:
            The final hidden states. The closing RMSNorm (fused with the last
            residual add) is skipped when the batch is in idle mode.
        """
        hidden_states = (
            self.embed_tokens(input_ids) if input_embeds is None else input_embeds
        )

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            # Tell the expert-distribution recorder which layer is running.
            expert_distribution_recorder.set_current_layer(layer_idx)
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class DeepseekV2ForCausalLM(nn.Module):
    """Causal-LM wrapper: ``DeepseekV2Model`` + LM head + logits processor."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return the logits processor's output."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy ``(name, tensor)`` checkpoint entries into this model.

        ``gate_proj``/``up_proj`` are merged into ``gate_up_proj``, routed-expert
        weights go through the MoE implementation's expert mapping, and
        everything else uses the parameter's own (or the default) loader.
        Unless MLA is disabled, the per-layer absorbed matrices are derived
        from ``kv_b_proj`` afterwards.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = (
            DeepEPMoE
            if global_server_args_dict["enable_deepep_moe"]
            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
        )
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        # Loop-invariant: read once rather than once per checkpoint tensor.
        num_nextn_layers = getattr(self.config, "num_nextn_predict_layers", 0)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # TODO(HandH1998): Modify it when nextn is supported.
            # Skip checkpoint layers numbered at or past num_hidden_layers
            # (the NextN layers), which this model does not instantiate.
            if num_nextn_layers > 0 and name.startswith("model.layers"):
                name_list = name.split(".")
                if (
                    len(name_list) >= 3
                    and int(name_list[2]) >= self.config.num_hidden_layers
                ):
                    continue
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if not global_server_args_dict["disable_mla"]:
            self._post_load_mla_kv_b_proj()

    def _post_load_mla_kv_b_proj(self) -> None:
        """Split each layer's ``kv_b_proj`` into ``w_kc`` / ``w_vc`` for MLA.

        AWQ weights are dequantized; block-quantized fp8 weights are
        requantized to a per-tensor scale (stored in ``w_scale``); int8
        weights are dequantized to bfloat16.
        """
        for layer_id in range(self.config.num_hidden_layers):
            self_attn = self.model.layers[layer_id].self_attn
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = ops.awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_hip:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    w, scale = block_quant_to_tensor_quant(
                        weight, weight_scale, weight_block_size
                    )
                    self_attn.w_scale = scale
            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )
            # Rows of w are grouped per head as [qk_nope_head_dim | v_head_dim].
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if (
                hasattr(self_attn.kv_b_proj, "weight_scale")
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if _is_hip:
                    # NOTE(review): `*=` is in place, and w_scale is the same
                    # tensor as kv_b_proj.weight_scale here, so this also
                    # doubles kv_b_proj's own scale — confirm intended.
                    self_attn.w_scale *= 2.0

    def get_embed_and_head(self):
        """Return ``(embedding weight, lm_head weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors."""
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
|
|
|
|
|
|
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV3 causal LM; reuses the DeepseekV2 implementation unchanged."""
|
|
|
|
|
|
# Model classes this module exposes through the `EntryClass` convention
# (presumably picked up by the model registry — confirm).
EntryClass = [
    DeepseekV2ForCausalLM,
    DeepseekV3ForCausalLM,
]
|