# sglang 0.4.5.post1: python/sglang/srt/models/grok.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import functools
import json
import logging
import math
import os
import warnings
from typing import Iterable, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import dump_to_file
logger = logging.getLogger(__name__)
debug_tensor_dump_output_folder = None
debug_tensor_dump_inject = False
class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
across all ranks.
    Each expert's weights are sharded across all ranks, a fused MoE kernel
    performs the forward pass, and the outputs are all-reduced across ranks.
"""
def __init__(
self,
config: PretrainedConfig,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
reduce_results=True,
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
params_dtype=torch.float32,
quant_config=None,
)
self.router_logit_softcapping = getattr(
config, "router_logit_softcapping", 30.0
)
custom_routing_function = functools.partial(
fused_moe_router_shim, self.router_logit_softcapping
)
kwargs = {}
if global_server_args_dict["enable_ep_moe"]:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE
kwargs["reduce_results"] = reduce_results
kwargs["use_presharded_weights"] = use_presharded_weights
kwargs["inplace"] = inplace
kwargs["no_combine"] = no_combine
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
custom_routing_function=custom_routing_function,
activation="gelu",
**kwargs,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # TODO: assert that self.gate.quant_method is unquantized.
return self.experts(hidden_states, self.gate.weight)
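# Illustrative sketch (not used at runtime): the dense-math equivalent of the
# routing that `fused_moe_router_shim` is expected to perform for Grok1MoE,
# assuming the usual softcap-then-softmax-then-top-k order; the exact fused
# kernel may differ in details.
def _reference_softcapped_router(
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
    gate_weight: torch.Tensor,  # [num_experts, hidden_size], kept in fp32
    top_k: int,
    softcap: float = 30.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Router logits in fp32 for numerical stability (see the fp32 gate above).
    logits = hidden_states.float() @ gate_weight.float().t()
    # Soft-cap the logits: cap * tanh(x / cap) keeps them in (-cap, cap).
    logits = softcap * torch.tanh(logits / softcap)
    probs = torch.softmax(logits, dim=-1)
    # renormalize=False: the top-k weights are taken from the full softmax as-is.
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    return topk_weights, topk_ids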
class Grok1Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
load_presharded_attn: bool = False,
) -> None:
super().__init__()
self.config = config
self.layer_id = layer_id
self.hidden_size = hidden_size
attn_tp_rank = get_tensor_model_parallel_rank()
attn_tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than or equal to the TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
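        # Example: with total_num_kv_heads=8 and attn_tp_size=16, each rank holds
        # one KV head and every KV head is replicated on 16 / 8 = 2 ranks; with
        # attn_tp_size=4, each rank holds 8 / 4 = 2 distinct KV heads.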
self.head_dim = getattr(config, "head_dim", 128)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.load_presharded_attn = load_presharded_attn
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
load_presharded_attn=self.load_presharded_attn,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
use_presharded_weights=self.load_presharded_attn,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=logit_cap,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if debug_tensor_dump_output_folder:
dump_to_file(
debug_tensor_dump_output_folder,
f"attn_input_{self.layer_id}",
hidden_states,
)
if debug_tensor_dump_inject:
name = os.path.join(
debug_tensor_dump_output_folder,
f"jax_dump_attn_input_{self.layer_id}.npy",
)
logger.info(f"Load {name} from jax.")
x = np.load(name)
hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
hidden_states
)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
if debug_tensor_dump_output_folder:
num_tokens = q.shape[0]
num_heads_q = self.num_heads
head_dim = self.head_dim
num_heads_kv = k.numel() // (num_tokens * head_dim)
dump_to_file(
debug_tensor_dump_output_folder,
f"q_{self.layer_id}",
tensor_model_parallel_all_gather(
q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
).contiguous(),
)
dump_to_file(
debug_tensor_dump_output_folder,
f"k_{self.layer_id}",
tensor_model_parallel_all_gather(
k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
).contiguous(),
)
dump_to_file(
debug_tensor_dump_output_folder,
f"v_{self.layer_id}",
tensor_model_parallel_all_gather(
v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
).contiguous(),
)
attn_output = self.attn(q, k, v, forward_batch)
if debug_tensor_dump_output_folder:
dump_to_file(
debug_tensor_dump_output_folder,
f"attn_output_{self.layer_id}",
tensor_model_parallel_all_gather(
attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
dim=1,
).contiguous(),
)
output, _ = self.o_proj(attn_output)
return output
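# Illustrative sketch (not used at runtime): the effect of `logit_cap` in
# RadixAttention above, assuming the common tanh soft-capping applied to the
# scaled attention scores before softmax; the attention backend fuses this.
def _reference_softcapped_attention_scores(
    q: torch.Tensor,  # [num_heads, q_len, head_dim]
    k: torch.Tensor,  # [num_heads, kv_len, head_dim]
    scaling: float,
    logit_cap: float = 30.0,
) -> torch.Tensor:
    scores = torch.einsum("hqd,hkd->hqk", q, k) * scaling
    if logit_cap > 0:
        # Bound the logits to (-logit_cap, logit_cap) before softmax.
        scores = logit_cap * torch.tanh(scores / logit_cap)
    return torch.softmax(scores, dim=-1)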
class Grok1DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
load_presharded_moe: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.layer_id = layer_id
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
quant_config=quant_config,
reduce_results=False,
load_presharded_attn=load_presharded_attn,
)
self.block_sparse_moe = Grok1MoE(
config=config,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=True,
use_presharded_weights=load_presharded_moe,
inplace=True,
            no_combine=False,  # no_combine is only a hint to the MoE kernel to skip the top-k combine
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn = self.block_sparse_moe
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor] = None,
deferred_norm: Optional[RMSNorm] = None,
) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
# Self Attention
if deferred_norm is not None:
assert residual is not None
            # Here, hidden_states is the FFN output of the previous layer and
            # residual is the residual stream from after that layer's attention.
hidden_states, residual = fused_dual_residual_rmsnorm(
hidden_states,
residual,
deferred_norm.weight,
self.pre_attn_norm.weight,
deferred_norm.variance_epsilon,
)
else:
# here hidden_states is the residual
hidden_states, residual = (
fused_rmsnorm(
hidden_states,
self.pre_attn_norm.weight,
self.pre_attn_norm.variance_epsilon,
),
hidden_states,
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
if get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = fused_dual_residual_rmsnorm(
hidden_states,
residual,
self.post_attn_norm.weight,
self.pre_moe_norm.weight,
self.post_attn_norm.variance_epsilon,
)
# Fully Connected
hidden_states = self.ffn(hidden_states)
return hidden_states, residual, self.post_moe_norm # defer layernorm
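# Illustrative sketch (not used at runtime): the unfused math that
# `fused_dual_residual_rmsnorm` replaces in the decoder layer above, matching the
# unfused fallback in Grok1Model.forward's debug path (assuming a plain
# weight-multiplied RMSNorm; dtype handling of the fused kernel may differ).
def _reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * weight.float()).to(x.dtype)
def _reference_dual_residual_rmsnorm(
    x: torch.Tensor, residual: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    # new_residual = residual + rmsnorm_w1(x); output = rmsnorm_w2(new_residual)
    new_residual = residual + _reference_rmsnorm(x, w1, eps)
    return _reference_rmsnorm(new_residual, w2, eps), new_residual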
class Grok1Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
load_presharded_moe: bool = False,
load_presharded_embedding: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
use_presharded_weights=load_presharded_embedding,
)
self.layers = nn.ModuleList(
[
Grok1DecoderLayer(
config,
i,
quant_config=quant_config,
load_presharded_moe=load_presharded_moe,
load_presharded_attn=load_presharded_attn,
load_presharded_mlp=load_presharded_mlp,
)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
hidden_states.mul_(self.config.embedding_multiplier_scale)
else:
hidden_states = input_embeds
residual, deferred_norm = None, None
for i in range(len(self.layers)):
hidden_states, residual, deferred_norm = self.layers[i](
positions, hidden_states, forward_batch, residual, deferred_norm
)
if debug_tensor_dump_output_folder:
hidden_states = (
fused_rmsnorm(
hidden_states,
deferred_norm.weight,
deferred_norm.variance_epsilon,
)
+ residual
)
dump_to_file(
debug_tensor_dump_output_folder,
"last_hidden_before_norm",
hidden_states,
)
hidden_states = fused_rmsnorm(
hidden_states,
self.norm.weight,
self.norm.variance_epsilon,
)
dump_to_file(
debug_tensor_dump_output_folder,
"last_hidden_after_norm",
hidden_states,
)
else:
hidden_states, _ = fused_dual_residual_rmsnorm(
hidden_states,
residual,
deferred_norm.weight,
self.norm.weight,
deferred_norm.variance_epsilon,
)
return hidden_states
class Grok1ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
# Get presharded weights.
self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
self.load_presharded_moe = (
self.config.num_local_experts > 0
and get_tensor_model_parallel_world_size() > 1
)
self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
self.load_presharded_embedding = getattr(
config, "load_presharded_embedding", False
)
self.is_weights_presharded = (
self.load_presharded_mlp
or self.load_presharded_moe
or self.load_presharded_attn
or self.load_presharded_embedding
)
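        # When any weights are presharded per TP rank, monkey-patch the default
        # model loader so each rank only discovers its own shard files
        # (see _prepare_presharded_weights at the bottom of this file).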
if self.is_weights_presharded:
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
default_replicate_lm_head = False
self.replicate_lm_head = getattr(
config, "replicate_lm_head", default_replicate_lm_head
)
self.model = Grok1Model(
config,
quant_config=quant_config,
load_presharded_moe=self.load_presharded_moe,
load_presharded_embedding=self.load_presharded_embedding,
load_presharded_attn=self.load_presharded_attn,
load_presharded_mlp=self.load_presharded_mlp,
)
lm_head_params_dtype = None
if self.replicate_lm_head:
self.lm_head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
params_dtype=lm_head_params_dtype,
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
use_presharded_weights=self.load_presharded_embedding,
params_dtype=lm_head_params_dtype,
)
self.logits_processor = LogitsProcessor(config)
# Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
warnings.filterwarnings("ignore", category=FutureWarning)
if get_tensor_model_parallel_rank() == 0:
logger.info(
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if debug_tensor_dump_output_folder:
dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
num_experts: Optional[int] = None,
ignore_parent_name: bool = False,
    ) -> set[str]:
if num_experts is None:
num_experts = self.config.num_local_experts
stacked_params_mapping = []
stacked_params_mapping += [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
stacked_params_mapping += [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=num_experts,
)
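        # The mappings above rename checkpoint weights onto fused parameters, e.g.
        # (illustrative name) "...self_attn.q_proj.weight" is loaded into
        # "...self_attn.qkv_proj.weight" with shard_id "q", and per-expert
        # "w1"/"w3"/"w2" tensors are loaded into the fused MoE expert parameters.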
params_dict = dict(self.named_parameters())
all_names = set(params_dict.keys())
hit_names = set()
def load_weight_wrapper(
name: str, loaded_weight: torch.Tensor, *args, **kwargs
):
if ignore_parent_name:
name = name.split(".")[-1]
if name not in params_dict:
return
# Fuse constant multipliers into the weights
if "lm_head" in name:
loaded_weight = (
loaded_weight.to(torch.float32)
* self.config.output_multiplier_scale
)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, *args, **kwargs)
hit_names.add(name)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
load_weight_wrapper(name, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
load_weight_wrapper(
name,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name is None:
continue
load_weight_wrapper(name=name, loaded_weight=loaded_weight)
if len(hit_names) > 5:
missing = all_names - hit_names
missing_exclude_scales = {x for x in missing if "scale" not in x}
logger.info(
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
)
if len(missing_exclude_scales) > 0:
raise ValueError(
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
)
elif len(hit_names) == 0:
raise ValueError("load_weights failed because it did not hit any names.")
return hit_names
def get_num_params_analytical(self):
cfg = self.config
moe_intermediate_size = getattr(
cfg,
"moe_intermediate_size",
getattr(cfg, "intermediate_size", None),
)
num_experts = cfg.num_local_experts
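        # Rough count of the dominant weight matrices only (per layer: Q and
        # output projections, K/V projections, and per-expert FFN w1+w3 and w2),
        # plus input embeddings and the LM head; norms and router gates are ignored.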
wq = (
cfg.num_hidden_layers
* cfg.hidden_size
* cfg.num_attention_heads
* cfg.head_dim
)
wkv = (
cfg.num_hidden_layers
* cfg.hidden_size
* cfg.num_key_value_heads
* cfg.head_dim
* 2
)
out = (
cfg.num_hidden_layers
* cfg.hidden_size
* cfg.num_attention_heads
* cfg.head_dim
)
ffn1 = (
cfg.num_hidden_layers
* num_experts
* cfg.hidden_size
* moe_intermediate_size
* 2
)
ffn2 = (
cfg.num_hidden_layers
* num_experts
* cfg.hidden_size
* moe_intermediate_size
)
embed = cfg.hidden_size * cfg.vocab_size * 2
return wq + wkv + out + ffn1 + ffn2 + embed
def get_num_params_torch(self):
return (
sum(p.numel() for p in self.parameters())
* get_tensor_model_parallel_world_size()
)
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
def _prepare_presharded_weights(
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
) -> Tuple[str, list[str], bool]:
import glob
import os
if get_tensor_model_parallel_world_size() == 1:
return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
if not os.path.isdir(model_name_or_path):
from sglang.srt.model_loader.weight_utils import download_weights_from_hf
allow_patterns = ["*.safetensors", "*.bin"]
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
hf_folder = model_name_or_path
tp_rank = get_tensor_model_parallel_rank()
# The old format
allow_patterns = [f"*-{tp_rank:03d}.bin"]
# The new format
allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
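    # For example, on tp_rank 0 these patterns match shard files such as
    # "*-000.bin" (old format) and "*-TP-000.safetensors" / "*-TP-common.safetensors"
    # (new format).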
hf_weights_files = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if hf_weights_files[0].endswith("safetensors"):
use_safetensors = True
else:
use_safetensors = False
return hf_folder, hf_weights_files, use_safetensors
class Grok1ModelForCausalLM(Grok1ForCausalLM):
"""An alias for backward-compatbility."""
pass
EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]