# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union

from sglang.srt.utils import is_hip

logger = logging.getLogger(__name__)


class LoadFormat(str, enum.Enum):
    """Weight-loading formats accepted by ``LoadConfig.load_format``.

    Members also subclass ``str``, so each member compares equal to its
    lowercase string value (e.g. ``LoadFormat.AUTO == "auto"``).
    """

    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"


@dataclass
class LoadConfig:
    """Options controlling how model weights are located and loaded.

    Attributes:
        load_format: The format of the model weights to load, given either as
            a ``LoadFormat`` member or as its (case-insensitive) string value:
            "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is not
            available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store a
            numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
            mainly for profiling.
            "bitsandbytes" will load nf4 type weights.
            See ``LoadFormat`` for the full list of accepted values.
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        model_loader_extra_config: Extra loader configuration. A non-empty
            string is parsed as JSON in ``__post_init__``; any other value
            (dict, empty string, ``None``) is stored unchanged.
        ignore_patterns: The list of patterns to ignore when loading the
            model. Default to "original/**/*" to avoid repeated loading of
            llama's checkpoints.
        decryption_key_file: If set, decrypts the output files with a password
            read from this file (after PBKDF2).
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        """Normalize the fields once the dataclass has been initialized.

        Raises:
            json.JSONDecodeError: If ``model_loader_extra_config`` is a
                non-empty string that is not valid JSON.
            ValueError: If ``load_format`` is not a valid ``LoadFormat`` value
                (see ``_verify_load_format``).
        """
        extra_config = self.model_loader_extra_config
        # Only a non-empty string is decoded; "" and None are falsy and are
        # deliberately left as they are so json.loads never sees empty input.
        if extra_config and isinstance(extra_config, str):
            self.model_loader_extra_config = json.loads(extra_config)

        self._verify_load_format()

        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            # Nothing supplied (None or empty): fall back to the default.
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce ``load_format`` to a ``LoadFormat`` member and validate it.

        Raises:
            ValueError: If the lower-cased string is not a ``LoadFormat``
                value, or if ``is_hip()`` is true and the format is listed as
                unsupported on ROCm.
        """
        # LoadFormat members are str instances too, so this guard only skips
        # values that are not strings at all.
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Lower-case format values that cannot be used on ROCm (none today).
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare and report by *value* ("auto"), matching the entries of
            # the unsupported list; member names ("AUTO") would never match.
            rocm_supported_load_format = [
                fmt.value
                for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )