sglang0.4.5.post1/python/sglang/srt/models/torch_native_llama.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""
Inference-only LLaMA model compatible with HuggingFace weights.
This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
Here is a quick example to enable TP:
```python
from sglang.srt.model_parallel import tensor_parallel
device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
tensor_parallel(model, device_mesh)
```
An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
You can run it with the following command:
```bash
$ python3 -m sglang.bench_one_batch --correct \
--model meta-llama/Meta-Llama-3-8B \
--json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
--tensor-parallel-size 2 \
--disable-cuda-graph
```
We will enable CUDA Graph support soon.
"""
import types
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import LlamaConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
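
# NOTE: The TP world size and rank are captured once at import time, so the
# tensor-parallel process group must already be initialized before this
# module is imported.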
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()


def gate_up_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int,
):
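    """Load the gate_proj or up_proj shard of a checkpoint weight into the
    fused gate_up_proj weight, keeping only this rank's TP partition.

    `loaded_shard_id` is 0 for gate_proj and 1 for up_proj; the fused weight
    is laid out as [gate, up] along dim 0.
    """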
# shard_id: (shard_offset, shard_size)
gate_up_offsets = {}
current_shard_offset = 0
for i, output_size in enumerate(self.output_sizes):
        # Each output partition shrinks by tp_size when TP is enabled.
output_size = output_size // tp_size
gate_up_offsets[i] = (current_shard_offset, output_size)
current_shard_offset += output_size
# Re-size the param to the size after TP
if current_shard_offset != param.shape[0]:
        # Cloning the narrowed view lets the original, full-size tensor be freed.
        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
# Now load gate or up
assert loaded_shard_id < len(self.output_sizes)
param_data = param.data
shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)


class LlamaMLP(nn.Module):
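    # TP sharding plan for the tensor_parallel() helper shown in the module
    # docstring: shard the fused gate_up_proj column-wise (keeping the gate/up
    # split intact) and down_proj row-wise.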
_tp_plan = {
"gate_up_proj": "Colwise_Sharded",
"down_proj": "Rowwise",
}
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = torch.nn.Linear(
hidden_size,
intermediate_size * 2,
bias=False,
)
self.gate_up_proj.output_sizes = [intermediate_size] * 2
self.gate_up_proj.weight_loader = types.MethodType(
gate_up_proj_weight_loader, self.gate_up_proj
)
self.gate_up_proj.weight.weight_loader = self.gate_up_proj.weight_loader
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x


def qkv_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str,
):
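    """Load the q_proj, k_proj, or v_proj shard of a checkpoint weight into the
    fused qkv_proj weight, keeping only this rank's TP partition.

    The fused weight is laid out as [q, k, v] along dim 0, with head counts
    already divided by the TP world size.
    """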
num_heads = self.num_heads // tp_size
num_kv_heads = self.num_kv_heads // tp_size
# shard_id: (shard_offset, shard_size)
qkv_offsets = {
"q": (0, num_heads * self.head_size),
"k": (num_heads * self.head_size, num_kv_heads * self.head_size),
"v": (
(num_heads + num_kv_heads) * self.head_size,
num_kv_heads * self.head_size,
),
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
# Re-size the param to the size after TP
if total_size != param.shape[0]:
        # Cloning the narrowed view lets the original, full-size tensor be freed.
        param.data = param.data.narrow(0, 0, total_size).clone()
# Now load q, k or v
shard_offset, shard_size = qkv_offsets[loaded_shard_id]
param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)


class LlamaAttention(nn.Module):
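    # TP sharding plan for the tensor_parallel() helper shown in the module
    # docstring: shard the fused qkv_proj column-wise (keeping the q/k/v split
    # intact) and o_proj row-wise.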
_tp_plan = {
"qkv_proj": "Colwise_Sharded",
"o_proj": "Rowwise",
}
def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_is_neox_style: bool = True,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
            # The number of KV heads is at least the TP size, so we partition
            # the KV heads across the tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = torch.nn.Linear(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
bias=False,
)
self.qkv_proj.head_size = self.head_dim
self.qkv_proj.num_heads = self.total_num_heads
self.qkv_proj.num_kv_heads = self.total_num_kv_heads
self.qkv_proj.weight_loader = types.MethodType(
qkv_proj_weight_loader, self.qkv_proj
)
self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
self.qkv_proj.weight.output_dim = 0
self.o_proj = torch.nn.Linear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=rope_is_neox_style,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, forward_batch)
output = self.o_proj(attn_output)
return output


class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
rope_is_neox_style=rope_is_neox_style,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual


class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(
config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class TorchNativeLlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
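        # Advertise that this model can be wrapped with torch-native tensor
        # parallelism (see the module docstring).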
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# turning off autotune for fp8dq since it doesn't give speedup and
# increases compile time significantly
torch._inductor.config.max_autotune_gemm_backends = "ATEN"
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def get_hidden_dim(self, module_name):
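        """Return the (input_dim, output_dim) of the named projection module.

        For kv_proj, the output dim is scaled down by the GQA group size
        (num_attention_heads // num_key_value_heads).
        """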
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return self.config.hidden_size, self.config.hidden_size
elif module_name in ["kv_proj"]:
return self.config.hidden_size, self.config.hidden_size // (
self.config.num_attention_heads // self.config.num_key_value_heads
)
elif module_name == "gate_up_proj":
return self.config.hidden_size, self.config.intermediate_size
elif module_name == "down_proj":
return self.config.intermediate_size, self.config.hidden_size
else:
raise NotImplementedError()
def get_module_name(self, name):
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return params_mapping.get(name, name)
def get_module_name_from_weight_name(self, name):
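        """Map a checkpoint weight name to (module name, number of shards
        stacked into that module), e.g. "...q_proj.weight" -> ("...qkv_proj", 3).
        """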
stacked_params_mapping = [
# (param_name, shard_name, shard_id, num_shard)
("qkv_proj", "q_proj", "q", 3),
("qkv_proj", "k_proj", "k", 3),
("qkv_proj", "v_proj", "v", 3),
("gate_up_proj", "gate_proj", 0, 2),
("gate_up_proj", "up_proj", 1, 2),
]
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
num_shard,
)
return name[: -len(".weight")], 1
def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)
def load_weights_to_module(
self,
fqn: str,
weights: Iterable[Tuple[str, torch.Tensor]],
):
"""Load weights onto submodule pointed by path `fqn`."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
module = self.get_submodule(fqn)
params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") or name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
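                # for/else: no stacked-parameter mapping matched, so load the
                # weight directly onto its own parameter.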
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") or name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
):
"""Load weights onto the full model."""
self.load_weights_to_module("", weights)


class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
pass
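

# SGLang's model registry discovers model implementations through the
# module-level EntryClass attribute.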
EntryClass = [TorchNativeLlamaForCausalLM, TorchNativePhi3ForCausalLM]