# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""

import math
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda_available

if is_cuda_available():
    from sgl_kernel import bmm_fp8

# Per-head size that the non-MLA attention path (MiniCPM3Attention) zero-pads
# q/k/v up to before calling RadixAttention, which it configures with this
# head size (see the TODO about head_size 96 in that class).
_PADDED_HEAD_DIM = 128


class MiniCPM3MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU gating, down projection.

    The checkpoint's ``gate_proj`` and ``up_proj`` weights are loaded as shards
    0 and 1 of ``gate_up_proj`` (see ``MiniCPM3ForCausalLM.load_weights``);
    ``SiluAndMul`` combines the two halves of its output.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def input_to_float8(x, dtype=torch.float8_e4m3fn):
    """Quantize ``x`` to ``dtype`` with a single per-tensor scale.

    The scale maps the largest absolute value in ``x`` (floored at 1e-12 so an
    all-zero tensor does not divide by zero) onto ``torch.finfo(dtype).max``;
    scaled values are clamped into the representable range before the cast.

    Returns:
        ``(x_quantized, inv_scale)``: the contiguous quantized tensor and the
        float32 reciprocal of the scale, i.e. the dequantization factor.
    """
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


class MiniCPM3Attention(nn.Module):
    """Latent-KV attention computed in expanded (non-absorbed) form.

    Used when MLA is disabled. The normalized KV latent is expanded by
    ``kv_b_proj`` into per-head ``k_nope`` and ``v``; the rotary part of the
    key is shared across heads. q/k/v are then zero-padded per head to
    ``_PADDED_HEAD_DIM`` for RadixAttention, and the padding is sliced off the
    attention output before ``o_proj``.

    Raises:
        ValueError: if ``qk_nope_head_dim + qk_rope_head_dim`` or
            ``v_head_dim`` exceeds ``_PADDED_HEAD_DIM`` (padding cannot fit).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        # forward() pads heads with F.pad(x, [0, _PADDED_HEAD_DIM - dim]). A
        # negative pad amount makes F.pad crop instead, which would silently
        # drop features, so reject such configs up front.
        if max(self.qk_head_dim, self.v_head_dim) > _PADDED_HEAD_DIM:
            raise ValueError(
                f"MiniCPM3Attention pads attention heads to {_PADDED_HEAD_DIM}, "
                f"but got qk_head_dim={self.qk_head_dim} and "
                f"v_head_dim={self.v_head_dim}."
            )
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            # Low-rank query: hidden -> q_lora_rank -> RMSNorm -> all heads.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
            )

        # Produces [kv latent (kv_lora_rank) | shared rotary key (qk_rope_head_dim)].
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Expands the normalized latent into per-head [k_nope | v].
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        # TODO support head_size 96
        self.attn = RadixAttention(
            self.num_local_heads,
            _PADDED_HEAD_DIM,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Single rotary key slice (head axis of size 1), broadcast to all heads below.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        original_shapes = [q_pe.shape, k_pe.shape]
        q_pe, k_pe = self.rotary_emb(
            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
        )
        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe
        # Zero-pad every head to the size RadixAttention was built with.
        q = torch.nn.functional.pad(
            q, [0, _PADDED_HEAD_DIM - self.qk_head_dim], value=0
        ).view(-1, self.num_local_heads * _PADDED_HEAD_DIM)
        k = torch.nn.functional.pad(
            k, [0, _PADDED_HEAD_DIM - self.qk_head_dim], value=0
        ).view(-1, self.num_local_heads * _PADDED_HEAD_DIM)
        v = torch.nn.functional.pad(
            v, [0, _PADDED_HEAD_DIM - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * _PADDED_HEAD_DIM)
        attn_output = self.attn(q, k, v, forward_batch)
        # Drop the padded tail of each head before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, _PADDED_HEAD_DIM)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


class MiniCPM3AttentionMLA(nn.Module):
    """Latent-KV attention with ``kv_b_proj`` absorbed into query and output.

    Attention runs directly in the ``kv_lora_rank`` latent space: RadixAttention
    is built with one KV head of size ``kv_lora_rank + qk_rope_head_dim`` and a
    value head of size ``kv_lora_rank``. ``w_kc`` maps the no-rope query part
    into the latent space and ``w_vc`` maps the latent attention output back to
    ``v_head_dim`` per head. Both (and ``w_scale`` for fp8 weights) are filled
    in by ``MiniCPM3ForCausalLM.load_weights``, which then deletes
    ``kv_b_proj``; ``forward`` must not be called before that.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            # Low-rank query: hidden -> q_lora_rank -> RMSNorm -> all heads.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
            )

        # Produces [kv latent (kv_lora_rank) | shared rotary key (qk_rope_head_dim)].
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Loaded from the checkpoint, then split into w_kc / w_vc and deleted
        # by MiniCPM3ForCausalLM.load_weights.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            prefix=add_prefix("attn", prefix),
        )
        # Set by MiniCPM3ForCausalLM.load_weights.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        q_len = hidden_states.shape[0]
        # Per head: [query in latent space (kv_lora_rank) | rotary query part].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb the key half of kv_b_proj into the query (batched over heads).
        # NOTE(review): the fp8 path always produces bfloat16 output regardless
        # of the model dtype — confirm this is intended for fp16 deployments.
        if self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache: overwrite its latent part with the
        # normalized latent, keeping the rotary part for RoPE below.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        k_pe = k_input[..., self.kv_lora_rank :]

        original_shapes = [q_pe.shape, k_pe.shape]
        q_pe, k_pe = self.rotary_emb(
            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
        )
        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Absorb the value half of kv_b_proj: latent output -> v_head_dim per head.
        if self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output


class MiniCPM3DecoderLayer(nn.Module):
    """Pre-norm transformer block with depth-scaled residual branches.

    Each branch output is multiplied by
    ``scale_depth / sqrt(num_hidden_layers)`` before being added to the
    residual. The attention implementation is MiniCPM3AttentionMLA unless
    ``global_server_args_dict["disable_mla"]`` is set.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        # Both attention variants take identical constructor arguments.
        attn_cls = (
            MiniCPM3Attention
            if global_server_args_dict["disable_mla"]
            else MiniCPM3AttentionMLA
        )
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=self.hidden_size // config.num_attention_heads,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = MiniCPM3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP sub-blocks.

        The incoming ``residual`` is ignored and ``None`` is returned in its
        place: the residual add is done here rather than fused into the norm.
        """
        branch_scale = self.config.scale_depth / math.sqrt(
            self.config.num_hidden_layers
        )

        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        hidden_states = residual + hidden_states * branch_scale

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * branch_scale

        return hidden_states, None


class MiniCPM3Model(nn.Module):
    """Embedding, stack of MiniCPM3DecoderLayer, and final RMSNorm.

    Token embeddings are multiplied by ``config.scale_emb``; pre-computed
    ``input_embeds`` are used as given (MiniCPM3ForCausalLM scales them).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.layers = nn.ModuleList(
            [
                MiniCPM3DecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{i}", prefix),
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
        else:
            hidden_states = input_embeds
        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MiniCPM3ForCausalLM(nn.Module):
    """MiniCPM3 decoder with an LM head (tied to the embeddings if configured).

    Final hidden states are divided by ``hidden_size / dim_model_base`` before
    the logits are computed.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPM3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                prefix=add_prefix("lm_head", prefix),
            )

        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        # Apply the same embedding scale MiniCPM3Model uses for token embeddings.
        if input_embeds is not None:
            input_embeds = input_embeds * self.config.scale_emb
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head = self.model.embed_tokens
        else:
            lm_head = self.lm_head
        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the model's parameters.

        ``gate_proj``/``up_proj`` go into shards of the fused ``gate_up_proj``.
        When MLA is enabled, each layer's ``kv_b_proj`` is afterwards split into
        the ``w_kc``/``w_vc`` matrices used by MiniCPM3AttentionMLA and removed.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Empty unless the config declares num_experts > 0.
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["w1", "w3"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            )
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param, loaded_weight, weight_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if not global_server_args_dict["disable_mla"]:
            for layer in self.model.layers:
                self_attn = layer.self_attn
                # Assumes kv_b_proj.weight is (out_features, in_features), i.e.
                # (local_heads * (nope + v), kv_lora_rank); split per head into
                # w_kc (heads, nope, rank) and w_vc (heads, v, rank).
                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                # Same logical shape as w_kc but with the last two dims stored
                # transposed in memory (presumably the layout bmm_fp8 expects).
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                # (heads, kv_lora_rank, v_head_dim).
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if hasattr(self_attn.kv_b_proj, "weight_scale"):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                del self_attn.kv_b_proj


EntryClass = MiniCPM3ForCausalLM