689 lines
26 KiB
Python
689 lines
26 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
|
|
|
|
import math
|
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
|
from sglang.srt.layers.activation import SiluAndMul
|
|
from sglang.srt.layers.layernorm import RMSNorm
|
|
from sglang.srt.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
from sglang.srt.layers.rotary_embedding import get_rope
|
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|
from sglang.srt.utils import add_prefix, is_cuda_available
|
|
|
|
if is_cuda_available():
|
|
from sgl_kernel import bmm_fp8
|
|
|
|
|
|
class MiniCPM3MLP(nn.Module):
    """Gated feed-forward block used by every MiniCPM3 decoder layer.

    The gate and up projections live in a single merged column-parallel GEMM
    whose output is fed to ``SiluAndMul``. The result is then projected back
    to ``hidden_size`` by a row-parallel ``down_proj``. Only the ``"silu"``
    activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two equally sized output shards (gate, up) packed into one weight.
        shard_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            shard_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, ``SiluAndMul`` and the down projection."""
        # The parallel linear layers return a tuple; element 0 is the output.
        merged = self.gate_up_proj(x)[0]
        activated = self.act_fn(merged)
        return self.down_proj(activated)[0]
|
|
|
|
|
|
def input_to_float8(x, dtype=torch.float8_e4m3fn):
|
|
finfo = torch.finfo(dtype)
|
|
min_val, max_val = x.aminmax()
|
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
|
scale = finfo.max / amax
|
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
|
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
|
|
|
|
|
class MiniCPM3Attention(nn.Module):
    """MiniCPM3 latent attention computed in expanded (non-absorbed) form.

    Queries come from one of two paths:
    - a low-rank path (``q_a_proj`` -> ``q_a_layernorm`` -> ``q_b_proj``) when
      ``q_lora_rank`` is set;
    - a single ``q_proj`` otherwise.

    Keys and values are decoded per head from the compressed latent emitted by
    ``kv_a_proj_with_mqa``. The rotary part of the key is taken directly from
    that latent and broadcast to every local head.

    The attention backend is built with a fixed head size of
    ``padded_head_dim``. q/k/v are zero-padded up to that size before
    attention, and the output is cropped back to ``v_head_dim`` afterwards.
    """

    # Head size handed to RadixAttention; q/k/v are zero-padded up to it.
    # TODO support head_size 96
    padded_head_dim: int = 128

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        prefix: str = "",
    ) -> None:
        """Build the projections, the rotary embedding and the padded backend.

        Raises:
            ValueError: if ``qk_nope_head_dim + qk_rope_head_dim`` or
                ``v_head_dim`` is larger than ``padded_head_dim``. Padding
                with a negative amount would crop heads instead of padding
                them, silently corrupting attention, so such configs are
                rejected here.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        if (
            self.qk_head_dim > self.padded_head_dim
            or self.v_head_dim > self.padded_head_dim
        ):
            raise ValueError(
                f"MiniCPM3Attention pads heads to {self.padded_head_dim}, but got "
                f"qk_head_dim={self.qk_head_dim} and v_head_dim={self.v_head_dim}; "
                "head dims larger than the padded size are not supported."
            )
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        # Softmax scale uses the real (unpadded) q/k head dim.
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
            )

        # Emits [compressed kv latent | rotary key part] for every token.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Decodes the normalized latent into per-head [k_nope | v].
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = RadixAttention(
            self.num_local_heads,
            self.padded_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and apply the output projection."""
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis so the rotary key part broadcasts over heads.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        original_shapes = [q_pe.shape, k_pe.shape]
        q_pe, k_pe = self.rotary_emb(
            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
        )
        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
        # Write the rotated part back into q; assemble k as [k_nope | k_pe].
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe
        # Zero-pad every head to the backend head size; __init__ guarantees the
        # pad amounts below are non-negative.
        pad_dim = self.padded_head_dim
        q = torch.nn.functional.pad(q, [0, pad_dim - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * pad_dim
        )
        k = torch.nn.functional.pad(k, [0, pad_dim - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * pad_dim
        )
        v = torch.nn.functional.pad(v, [0, pad_dim - self.v_head_dim], value=0).view(
            -1, self.num_local_heads * pad_dim
        )
        attn_output = self.attn(q, k, v, forward_batch)
        # Drop the value padding before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, pad_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class MiniCPM3AttentionMLA(nn.Module):
    """MiniCPM3 latent attention in weight-absorbed (MLA) form.

    Keys and values are not decoded per head:
    - The key handed to attention is the normalized compressed latent followed
      by the rotated positional part. It is shared by all heads
      (``num_kv_heads=1``).
    - The value is the normalized latent alone (``v_head_dim=kv_lora_rank``).
    - The query's non-positional part is projected into latent space with
      ``w_kc`` before attention.
    - The attention output is mapped back to ``v_head_dim`` per head with
      ``w_vc``.

    ``w_kc``, ``w_vc`` and ``w_scale`` are ``None`` after construction. They
    must be assigned before ``forward`` runs, because forward reads
    ``self.w_kc.dtype`` and ``self.w_vc.dtype``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        # Softmax scale is based on the original q/k head dim, not the latent
        # size used for the attention call.
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Query projection: low-rank path when q_lora_rank is set, else direct.
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
            )

        # Emits [compressed kv latent | rotary key part] for every token.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # forward() does not call kv_b_proj; its weight is consumed to build
        # w_kc / w_vc. NOTE(review): presumably this happens in the model's
        # weight-loading step — confirm against the caller.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        # Attention runs in latent space: head size is latent + rotary part, a
        # single KV head is shared by all query heads, and values are the
        # kv_lora_rank-wide latent.
        self.attn = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            prefix=add_prefix("attn", prefix),
        )

        # Absorbed projection weights (and optional FP8 scale); must be filled
        # in after weight loading, before the first forward().
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Compute absorbed-weight attention and apply the output projection."""
        q_len = hidden_states.shape[0]
        # Query buffer in latent space: [absorbed q_nope | rotated q_pe] per head.
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Batched per-head projection of q_nope into latent space
        # (heads moved to the batch axis via transpose(0, 1)).
        if self.w_kc.dtype == torch.float8_e4m3fn:
            # FP8 path: quantize activations per tensor, then use bmm_fp8
            # (imported at module top only when CUDA is available). The output
            # dtype is fixed to bfloat16 here.
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        # Value = normalized latent, with a singleton head axis (one KV head).
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache: the writes below modify
        # latent_cache in place, replacing the raw latent with its normalized
        # form (and, further down, the rotary part with its rotated form).
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        k_pe = k_input[..., self.kv_lora_rank :]

        original_shapes = [q_pe.shape, k_pe.shape]
        q_pe, k_pe = self.rotary_emb(
            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
        )
        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Map the latent-space attention output back to v_head_dim per head.
        if self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        # Back to token-major layout and merge heads for o_proj.
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output
|
|
|
|
|
|
class MiniCPM3DecoderLayer(nn.Module):
    """One MiniCPM3 transformer block: pre-norm attention followed by a pre-norm MLP.

    Each sub-layer output is multiplied by
    ``config.scale_depth / sqrt(config.num_hidden_layers)`` before being added
    back to the residual stream.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Both attention variants share one constructor signature, so pick the
        # class once instead of duplicating the argument list per branch.
        attn_cls = (
            MiniCPM3Attention
            if global_server_args_dict["disable_mla"]
            else MiniCPM3AttentionMLA
        )
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=self.hidden_size // config.num_attention_heads,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = MiniCPM3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        # Depth-dependent factor applied to both sub-layer outputs; constant
        # for the layer, so compute it once here instead of twice per forward.
        self.residual_scale = config.scale_depth / math.sqrt(config.num_hidden_layers)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and MLP with scaled residual connections.

        The incoming ``residual`` is ignored because the residual add happens
        inside this layer. The second return value is always ``None``.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        hidden_states = residual + hidden_states * self.residual_scale

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * self.residual_scale

        return hidden_states, None
|
|
|
|
|
|
class MiniCPM3Model(nn.Module):
    """MiniCPM3 decoder stack: token embedding, decoder layers and final RMSNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.layers = nn.ModuleList(
            [
                MiniCPM3DecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{i}", prefix),
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed (or accept) inputs, run every decoder layer, and normalize.

        Embeddings looked up here are multiplied by ``config.scale_emb``.
        Caller-provided ``input_embeds`` are used unchanged.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
        else:
            hidden_states = input_embeds
        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        return self.norm(hidden_states)
|
|
|
|
|
|
class MiniCPM3ForCausalLM(nn.Module):
    """MiniCPM3 decoder with an LM head, logits processing and weight loading.

    Final hidden states are divided by ``hidden_size / dim_model_base`` before
    the LM head. With ``tie_word_embeddings`` the input embedding table doubles
    as the LM head. ``load_weights`` also builds the absorbed ``w_kc``/``w_vc``
    tensors used by the MLA attention path.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Defaults to 0 when the config has no experts; only used to build
        # expert_params_mapping in load_weights.
        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPM3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        # A separate output projection exists only for untied embeddings;
        # otherwise forward() reuses model.embed_tokens.
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                prefix=add_prefix("lm_head", prefix),
            )

        # Divisor applied to the final hidden states in forward().
        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder and return the logits processor's output.

        Caller-supplied ``input_embeds`` are multiplied by ``config.scale_emb``
        here, because MiniCPM3Model only applies that factor to embeddings it
        looks up itself.
        """
        if input_embeds is not None:
            input_embeds = input_embeds * self.config.scale_emb
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head = self.model.embed_tokens
        else:
            lm_head = self.lm_head
        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each (name, tensor) pair is routed in this order:
        1. stacked (merged) parameters, e.g. gate/up into gate_up_proj;
        2. otherwise expert parameters;
        3. otherwise a direct parameter load.

        Rotary caches, and ``lm_head.weight`` when embeddings are tied, are
        skipped. After loading, when MLA is enabled, each layer's ``kv_b_proj``
        weight is split into absorbed ``w_kc``/``w_vc`` tensors and the
        ``kv_b_proj`` module is deleted.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Empty when num_experts == 0.
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["w1", "w3"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            )
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            # for/else: the else branch runs only if no stacked mapping hit `break`.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE(review): this `continue` resumes the *inner* loop with the
                # already-rewritten name rather than skipping the weight. The
                # bias is still dropped by the final else-branch check below.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param, loaded_weight, weight_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if not global_server_args_dict["disable_mla"]:
            # Build the absorbed projections used by MiniCPM3AttentionMLA.
            for layer_id in range(self.config.num_hidden_layers):
                self_attn = self.model.layers[layer_id].self_attn
                # Regroup kv_b_proj rows per head, then split each head into
                # its k_nope rows and its v rows.
                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                # Same logical shape as w_kc but with the last two dims stored
                # transposed in memory. NOTE(review): presumably the layout the
                # bmm/bmm_fp8 kernels want — confirm.
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                # Quantized checkpoints may expose a weight scale used by bmm_fp8.
                if hasattr(self_attn.kv_b_proj, "weight_scale"):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                # The MLA forward path no longer needs kv_b_proj; free it.
                del self_attn.kv_b_proj
|
|
|
|
|
|
# Model class exported by this module. NOTE(review): presumably discovered by
# the model registry through this name — confirm against the loader.
EntryClass = MiniCPM3ForCausalLM
|