90 lines
3.2 KiB
Python
90 lines
3.2 KiB
Python
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
|
import enum
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Optional, Union
|
|
|
|
from sglang.srt.utils import is_hip
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LoadFormat(str, enum.Enum):
    """Closed set of identifiers for the supported weight-loading formats.

    Members also subclass ``str``, so a member compares equal to its plain
    string value and ``LoadFormat("auto")`` resolves to ``LoadFormat.AUTO``.
    """

    # Checkpoint-file based formats.
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"

    # Randomly initialised weights instead of a real checkpoint.
    DUMMY = "dummy"

    # Remaining format identifiers.
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
|
|
|
|
|
|
@dataclass
class LoadConfig:
    """Configuration describing how model weights are located and loaded.

    Attributes:
        load_format: The format of the model weights to load, given either as
            a ``LoadFormat`` member or as its (case-insensitive) string value:
            "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
            mainly for profiling.
            "bitsandbytes" will load nf4 type weights.
            Any other ``LoadFormat`` value is accepted as well.
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        model_loader_extra_config: Extra loader options, either a dict or a
            JSON string. After ``__post_init__`` it is always the parsed
            value (an empty dict when nothing was supplied).
        ignore_patterns: The list of patterns to ignore when loading the
            model. Default to "original/**/*" to avoid repeated loading of
            llama's checkpoints.
        decryption_key_file: If set, decrypts the output files with a
            password read from this file (after PBKDF2).

    Raises:
        ValueError: If ``load_format`` is not a valid ``LoadFormat`` value.
        json.JSONDecodeError: If ``model_loader_extra_config`` is a string
            that is not valid JSON.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        """Normalize the extra config, load format and ignore patterns."""
        # Treat None / "" as "no extra config" and parse JSON strings, then
        # store the result so the attribute is consistently normalized
        # (previously a None input was left as None).
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        self.model_loader_extra_config = extra_config

        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            # Nothing supplied (None or empty): fall back to the default.
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce ``load_format`` to a ``LoadFormat`` and check ROCm support.

        Raises:
            ValueError: If the value is not a known load format, or if it is
                listed as unsupported while running on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        # LoadFormat members are str subclasses, so this handles both plain
        # strings (in any letter case) and enum members.
        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Lower-case LoadFormat *values* that cannot be used on ROCm.
        rocm_not_supported_load_format: List[str] = []
        # Cheap membership test first; the platform probe only runs when the
        # requested format is actually on the unsupported list.
        if load_format in rocm_not_supported_load_format and is_hip():
            # Compare and report member values (what users pass in), not the
            # upper-case member names exposed by ``LoadFormat.__members__``.
            rocm_supported_load_format = [
                fmt.value
                for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
|