# sglang0.4.5.post1/python/sglang/srt/configs/load_config.py
#
# 90 lines
# 3.2 KiB
# Python

# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
# Module-level logger; LoadConfig uses it to report user-supplied ignore patterns.
logger = logging.getLogger(__name__)
class LoadFormat(str, enum.Enum):
    """Enumeration of the weight-loading formats accepted by ``LoadConfig``.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, and ``LoadFormat("<value>")`` maps such a string back to the
    corresponding member.
    """

    # Formats described in the LoadConfig docstring.
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"

    # Remaining formats.
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
@dataclass
class LoadConfig:
    """Options controlling where model weights come from and how they load.

    Attributes:
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "bitsandbytes" will load nf4 type weights.
            A string is lower-cased and converted to a ``LoadFormat`` member
            during ``__post_init__``.
        model_loader_extra_config: Extra loader options, either a dict or a
            JSON string. A JSON string is decoded during ``__post_init__``;
            ``None`` or an empty string becomes ``{}``.
        ignore_patterns: The list of patterns to ignore when loading the
            model. Default to "original/**/*" to avoid repeated loading of
            llama's checkpoints.
        decryption_key_file: If set, decrypts the output files with a
            password read from this file (after PBKDF2).
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        """Normalize the fields after dataclass initialization.

        Raises:
            ValueError: If ``load_format`` is not a known format, if it is
                unsupported on ROCm, or if ``model_loader_extra_config`` is a
                string that is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError`` subclass).
        """
        extra_config = self.model_loader_extra_config
        if extra_config is None or extra_config == "":
            # Store the normalized value so the attribute is always usable as
            # a mapping, matching the field's dict default.
            self.model_loader_extra_config = {}
        elif isinstance(extra_config, str):
            self.model_loader_extra_config = json.loads(extra_config)

        self._verify_load_format()

        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            # None or empty: fall back to the documented default pattern.
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce ``load_format`` to a ``LoadFormat`` member and validate it.

        Non-string values are left untouched. ``LoadFormat`` members are
        themselves strings, so they take the same path as plain strings.

        Raises:
            ValueError: If the lower-cased value is not a ``LoadFormat``
                value, or if it is listed as unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Lower-case format values that cannot be used on ROCm (none so far).
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare and report member *values*: the unsupported list holds
            # lower-case values, whereas member names are upper-case.
            rocm_supported_load_format = [
                fmt.value
                for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )