sglang0.4.5.post1/python/sglang/srt/hf_transformers_utils.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Huggingface Transformers."""
import contextlib
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union

from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from sglang.srt.configs import (
    ChatGLMConfig,
    DbrxConfig,
    DeepseekVL2Config,
    ExaoneConfig,
    MultiModalityConfig,
)
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ChatGLMConfig.model_type: ChatGLMConfig,
    DbrxConfig.model_type: DbrxConfig,
    ExaoneConfig.model_type: ExaoneConfig,
    DeepseekVL2Config.model_type: DeepseekVL2Config,
    MultiModalityConfig.model_type: MultiModalityConfig,
}

# Register the custom configs with AutoConfig so that from_pretrained can resolve
# them; the ValueError raised for an already-registered model type is suppressed.
for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)


def download_from_hf(model_path: str):
    if os.path.exists(model_path):
        return model_path

    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
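
# Usage sketch (hypothetical repo id; illustrative only, not executed at import time):
#
#   local_dir = download_from_hf("deepseek-ai/deepseek-llm-7b-base")
#   # Only *.json, *.bin, and *.model files are fetched; an existing local path
#   # is returned unchanged.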


def get_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
    **kwargs,
):
    is_gguf = check_gguf_file(model)
    if is_gguf:
        kwargs["gguf_file"] = model
        model = Path(model).parent

    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )

    # FIXME: Copy the fields of janus-pro's `language_config` up to the top
    # level of the config.
    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
        assert hasattr(config, "language_config")
        for key, val in config.language_config.__dict__.items():
            setattr(config, key, val)
        setattr(config, "architectures", ["MultiModalityCausalLM"])

    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
        # NOTE(HandH1998): Qwen2VL requires the `_name_or_path` attribute in `config`.
        setattr(config, "_name_or_path", model)
    if model_override_args:
        config.update(model_override_args)

    # Special architecture mapping check for GGUF models
    if is_gguf:
        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
        config.update({"architectures": [model_type]})

    return config
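
# Usage sketch (hypothetical model name and override; illustrative only):
#
#   config = get_config(
#       "Qwen/Qwen2-7B-Instruct",
#       trust_remote_code=False,
#       model_override_args={"max_position_embeddings": 8192},
#   )
#   assert config.max_position_embeddings == 8192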


# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_seq_len",
    "model_max_length",
    "max_position_embeddings",
]


def get_context_length(config):
    """Get the context length of a model from a Hugging Face model config."""
    text_config = config
    rope_scaling = getattr(text_config, "rope_scaling", None)
    if rope_scaling:
        rope_scaling_factor = rope_scaling.get("factor", 1)
        if "original_max_position_embeddings" in rope_scaling:
            rope_scaling_factor = 1
        if rope_scaling.get("rope_type", None) == "llama3":
            rope_scaling_factor = 1
    else:
        rope_scaling_factor = 1

    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(text_config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048
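
# Worked example (a sketch with hypothetical values): a plain rope_scaling factor
# multiplies the base length, while original_max_position_embeddings (or
# rope_type="llama3") disables the multiplier.
#
#   cfg = PretrainedConfig(max_position_embeddings=4096, rope_scaling={"factor": 4.0})
#   assert get_context_length(cfg) == 16384  # 4.0 * 4096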


# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = tokenizer_name
        tokenizer_name = Path(tokenizer_name).parent

    if is_remote_url(tokenizer_name):
        # BaseConnector implements __del__() to clean up the local dir.
        # Because the config files must remain available for the lifetime of the
        # process, we deliberately do NOT use a `with` statement, which would
        # close the connector and delete the local directory.
        client = create_remote_connector(tokenizer_name)
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        tokenizer_name = client.get_local_dir()

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            clean_up_tokenization_spaces=False,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )
    attach_additional_stop_token_ids(tokenizer)
    return tokenizer
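
# Usage sketch (hypothetical model name; illustrative only):
#
#   tokenizer = get_tokenizer("meta-llama/Llama-3.1-8B-Instruct", tokenizer_mode="auto")
#   input_ids = tokenizer("Hello, world!")["input_ids"]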


def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
):
    # Pop 'revision' from kwargs if present; otherwise fall back to tokenizer_revision.
    revision = kwargs.pop("revision", tokenizer_revision)

    config = AutoConfig.from_pretrained(
        tokenizer_name,
        trust_remote_code=trust_remote_code,
        revision=revision,
        **kwargs,
    )

    # fix: for the Qwen2-VL model, inject a default 'size' if not provided.
    if config.model_type in {"qwen2_vl"}:
        if "size" not in kwargs:
            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}

    processor = AutoProcessor.from_pretrained(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        revision=revision,
        **kwargs,
    )

    attach_additional_stop_token_ids(processor.tokenizer)
    return processor
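
# Usage sketch (hypothetical model name and PIL image variable; illustrative only):
#
#   processor = get_processor("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
#   inputs = processor(text=["Describe the image."], images=[image], return_tensors="pt")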


def attach_additional_stop_token_ids(tokenizer):
    # Special handling for stop token <|eom_id|> generated by Llama 3 tool use.
    if "<|eom_id|>" in tokenizer.get_added_vocab():
        tokenizer.additional_stop_token_ids = set(
            [tokenizer.get_added_vocab()["<|eom_id|>"]]
        )
    else:
        tokenizer.additional_stop_token_ids = None


def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
    """Check if the file is a GGUF model."""
    model = Path(model)
    if not model.is_file():
        return False
    elif model.suffix == ".gguf":
        return True

    with open(model, "rb") as f:
        header = f.read(4)
    return header == b"GGUF"
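
# Usage sketch (hypothetical file names; illustrative only):
#
#   check_gguf_file("llama-3-8b.Q4_K_M.gguf")  # True if the file exists
#   check_gguf_file("config.json")             # False
#   check_gguf_file("weights.bin")             # True only if it starts with the b"GGUF" magic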