# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
import math
from enum import IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
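    """Attention architecture of the model: MLA (Multi-head Latent Attention,
    used by DeepSeek-V2/V3-style models) or standard multi-head attention (MHA)."""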
    MLA = auto()
    MHA = auto()


class ModelConfig:
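    """Model configuration derived from a Hugging Face checkpoint.

    Wraps the HF config and exposes the fields SGLang needs at runtime:
    dtype, context length, attention architecture (MHA vs. MLA), head counts,
    and quantization settings.

    Example (an illustrative sketch only; the model path is a placeholder and
    `model_override_args` expects a JSON string):

        config = ModelConfig(
            "meta-llama/Llama-3.1-8B-Instruct",
            model_override_args="{}",
        )
        print(config.dtype, config.context_len, config.get_num_kv_heads(1))
    """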
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[str] = None,
        is_embedding: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
    ) -> None:
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization

        # Parse args
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
        self.is_audio_model = is_audio_model(self.hf_config.architectures)
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                if get_bool_env_var(
                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                ):
                    logger.warning(
                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors."
                    )
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len
        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special judge for MLA architecture
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale
        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
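        """Return the quantization config embedded in the HF model config, if any."""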
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
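        """Validate `self.quantization` against the checkpoint's quantization
        config and the current platform, warning on unoptimized methods."""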
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
        ]
        compatible_quantization_methods = {
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()
        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which kind of checkpoint it is.
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
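        """Return the EOS token id(s) declared in the HF config as a set (or None)."""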
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids:
            # it can be either an int or a list of ints
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        return eos_ids

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """Pull the model config files to a temporary local directory
        when `self.model_path` is a remote URL; otherwise do nothing."""
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files need to exist for the whole lifetime,
            # we DO NOT use a `with` statement, to avoid closing the client.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
            self.model_weights = self.model_path
            self.model_path = client.get_local_dir()


def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
"""
class_name = config.architectures[0]
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
# `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
setattr(config, "torch_dtype", torch.float16)
return config
if hasattr(config, "text_config"):
# The code operates under the assumption that text_config should have
# `num_attention_heads` (among others). Assert here to fail early
# if transformers config doesn't align with this assumption.
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
if hasattr(config, "language_config"):
return config.language_config
else:
return config
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
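    """Resolve the requested dtype against the checkpoint's `torch_dtype`,
    logging when the value is up- or downcast."""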
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model:
    # 1. Check the model architecture.
    # 2. Check the `is_embedding` server arg.
    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
    ):
        return False
    else:
        return not is_embedding
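
# Model architectures that accept multimodal (image / video / audio) inputs.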
multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"CLIPModel",
]
def is_multimodal_model(model_architectures: List[str]):
    if any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    ):
        return True
    else:
        return False
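
# NOTE: The following capability checks are currently hard-coded to False;
# this module does not detect multimodal-generation, image-generation, or
# audio model architectures.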
def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
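    """YaRN mscale correction used with RoPE scaling: returns
    0.1 * mscale * ln(scale) + 1.0 when scale > 1, else 1.0."""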
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0