# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Common utilities.""" import base64 import builtins import ctypes import dataclasses import io import ipaddress import itertools import json import logging import os import pickle import random import re import resource import shutil import signal import socket import subprocess import sys import tempfile import threading import time import warnings from contextlib import contextmanager from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from importlib.util import find_spec from io import BytesIO from multiprocessing.reduction import ForkingPickler from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union import numpy as np import psutil import requests import torch import torch.distributed import torch.distributed as dist import triton import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version from PIL import Image from starlette.routing import Mount from torch import nn from torch.func import functional_call from torch.library import Library from torch.profiler import ProfilerActivity, profile, record_function from torch.utils._contextlib import _DecoratorContextManager from triton.runtime.cache import ( FileCacheManager, default_cache_dir, default_dump_dir, default_override_dir, ) 
logger = logging.getLogger(__name__)

# Toggled by enable_show_time_cost(); checked by mark_start()/mark_end().
show_time_cost = False
# Timer name -> TimeInfo accumulator used by mark_start()/mark_end().
time_infos = {}

# FP8 E4M3 max magnitude used on HIP builds (the "FNUZ" variant, per the name).
HIP_FP8_E4M3_FNUZ_MAX = 224.0


def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Read environment variable ``name`` as a boolean.

    Only the case-insensitive strings "true" and "1" are treated as True.
    ``default`` is used when the variable is unset.
    """
    value = os.getenv(name, default)
    return value.lower() in ("true", "1")


# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
    """Return True when PyTorch was built for HIP/ROCm."""
    return torch.version.hip is not None


if is_hip():
    FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
else:
    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

FP8_E4M3_MIN = -FP8_E4M3_MAX

# Published on the builtins module so these names resolve without an import.
builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
builtins.FP8_E4M3_MIN = FP8_E4M3_MIN


def is_rocm() -> bool:
    """Return a truthy value when a GPU is visible and PyTorch is a HIP build.

    NOTE: when true this returns ``torch.version.hip`` (a string), not a strict
    bool; callers should rely on truthiness only.
    """
    return torch.cuda.is_available() and torch.version.hip


def is_cuda():
    """Return a truthy value (``torch.version.cuda``) on a CUDA build with a visible GPU."""
    return torch.cuda.is_available() and torch.version.cuda


def is_cuda_alike():
    """Return a truthy value for CUDA devices or HIP builds."""
    return is_cuda() or is_hip()


def is_hpu() -> bool:
    """Return True if the ``torch.hpu`` backend exists and reports availability."""
    return hasattr(torch, "hpu") and torch.hpu.is_available()


def is_xpu() -> bool:
    """Return True if the ``torch.xpu`` backend exists and reports availability."""
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_flashinfer_available():
    """
    Check whether flashinfer is available.
    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.

    Can be forced off with SGLANG_IS_FLASHINFER_AVAILABLE=false.
    """
    if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
        return False
    return is_cuda()


def is_cuda_available():
    """Alias of is_cuda()."""
    return is_cuda()


_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
)


class DynamicGradMode(_DecoratorContextManager):
    """
    A combination of torch.no_grad and torch.inference_mode, with their behavior
    controlled by an environment variable. Just refer to them.

    When ``_ENABLE_TORCH_INFERENCE_MODE`` is true the context enters
    ``torch._C._InferenceMode(mode)``; otherwise it disables grad like no_grad.
    Usable as a context manager or as a decorator (with or without parentheses).
    """

    @staticmethod
    def set_inference_mode(mode: bool):
        """Override the module-wide inference-mode switch at runtime."""
        if isinstance(mode, bool):
            global _ENABLE_TORCH_INFERENCE_MODE

            _ENABLE_TORCH_INFERENCE_MODE = mode
        else:
            logger.warning("mode is not a boolean object")

    def __init__(self, mode=True):
        if not torch._jit_internal.is_scripting():
            super().__init__()
        if _ENABLE_TORCH_INFERENCE_MODE:
            self.mode = mode
        else:
            self.prev = False

    # NOTE: this default is evaluated once, at class-definition time, from the
    # env-var value; set_inference_mode() does not change it.
    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
        # Used bare as ``@DynamicGradMode`` the argument is the decorated function:
        # build an instance and immediately apply it as a decorator.
        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
            return super().__new__(cls)
        return cls()(mode_or_orig_func)

    def __enter__(self) -> None:
        if _ENABLE_TORCH_INFERENCE_MODE:
            self._inference_mode_context = torch._C._InferenceMode(self.mode)
            self._inference_mode_context.__enter__()
        else:
            self.prev = torch.is_grad_enabled()
            torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        if _ENABLE_TORCH_INFERENCE_MODE:
            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
        else:
            torch.set_grad_enabled(self.prev)

    def clone(self) -> "DynamicGradMode":
        r"""
        Create a copy of this class
        """
        if _ENABLE_TORCH_INFERENCE_MODE:
            return self.__class__(self.mode)
        else:
            return self.__class__()


def enable_show_time_cost():
    """Turn on the mark_start()/mark_end() timers."""
    global show_time_cost
    show_time_cost = True


class TimeInfo:
    """Accumulated wall time for one named timer, printed at most every ``interval`` s."""

    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name = name
        self.interval = interval
        self.color = color
        self.indent = indent

        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True (and advance the checkpoint) once ``interval`` more seconds accumulated."""
        if self.acc_time - self.last_acc_time > self.interval:
            self.last_acc_time = self.acc_time
            return True
        return False

    def pretty_print(self):
        """Print the accumulated time using an ANSI color code and dash indentation."""
        print(f"\x1b[{self.color}m", end="")
        print("-" * self.indent * 2, end="")
        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")


def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) the accumulating timer ``name``; no-op unless enabled."""
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    if time_infos.get(name, None) is None:
        time_infos[name] = TimeInfo(name, interval, color, indent)
    # Subtract the start time now; mark_end() adds the end time, leaving the delta.
    time_infos[name].acc_time -= time.time()


def mark_end(name):
    """Stop timer ``name`` and print it if its reporting interval elapsed."""
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    time_infos[name].acc_time += time.time()
    if time_infos[name].check():
        time_infos[name].pretty_print()


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that synchronizes CUDA around ``func`` and optionally prints its cost.

    Args:
        show: print the elapsed time when True.
        min_cost_ms: only print when the call took longer than this many ms.
    """
    from functools import wraps

    def wrapper(func):
        @wraps(func)  # keep __name__/__doc__ of the wrapped function
        def inner_func(*args, **kwargs):
            torch.cuda.synchronize()
            if show:
                # perf_counter is monotonic, unlike time.time().
                start_time = time.perf_counter()
            result = func(*args, **kwargs)
            torch.cuda.synchronize()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.

    Args:
        device: one of "cuda", "xpu", "hpu" or "cpu".
        gpu_id: index of the device to query.
        distributed: all-reduce the value with MIN across the default process group.
        empty_cache: empty the allocator cache first (cuda/xpu only).

    Returns:
        Free memory in GiB.

    Raises:
        ValueError: if ``device`` is not one of the supported kinds.
    """
    if device == "cuda":
        num_gpus = cuda_device_count_stateless()
        assert gpu_id < num_gpus

        if torch.cuda.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
                "which may cause useless memory allocation for torch CUDA context.",
            )

        if empty_cache:
            torch.cuda.empty_cache()
        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    elif device == "xpu":
        num_gpus = torch.xpu.device_count()
        assert gpu_id < num_gpus

        if torch.xpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
                "which may cause useless memory allocation for torch XPU context.",
            )

        if empty_cache:
            torch.xpu.empty_cache()
        # NOTE(review): this subtracts only torch-allocated memory from the total,
        # so memory used by other processes is presumably not accounted for.
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

    elif device == "hpu":
        num_gpus = torch.hpu.device_count()
        assert gpu_id < num_gpus

        if torch.hpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
                "which may cause useless memory allocation for torch HPU context.",
            )

        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

    elif device == "cpu":
        # TODO: rename the variables in the current function to be not GPU specific
        free_gpu_memory = psutil.virtual_memory().available

    else:
        # Previously fell through to an UnboundLocalError on free_gpu_memory.
        raise ValueError(f"Unsupported device for memory query: {device}")

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
            torch.device(device, gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


def is_pin_memory_available() -> bool:
    """Pinned host memory is used only when CUDA is available."""
    return torch.cuda.is_available()


# Bytes already offloaded / the offload budget, shared by maybe_offload_to_cpu().
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Reset the offload counter and set the CPU offload budget in bytes."""
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Move ``module``'s parameters to CPU while the global offload budget allows.

    Parameters are offloaded one by one until the budget is used up, so a module
    may end up partially offloaded. When anything was offloaded, ``module.forward``
    is wrapped to copy the state dict back to the original device for each call.

    Returns:
        The same ``module`` object (possibly with a wrapped ``forward``).
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    first_param = next(module.parameters(), None)
    if first_param is None:
        # No parameters: nothing to offload (previously raised StopIteration).
        return module
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the original forward so functional_call(module, ...) does
            # not re-enter this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                output = functional_call(module, device_state, args=args, kwargs=kwargs)
            finally:
                # Re-install the wrapper even if the forward pass raised.
                module.forward = forward
            return output

        module.forward = forward

    return module


class LayerFn(Protocol):
    """Factory called as ``layer_fn(idx=..., prefix=...)`` by make_layers()."""

    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str = "",
) -> torch.nn.ModuleList:
    """Make a list of layers with the given layer function.

    Each layer is built with ``layer_fn(idx=idx, prefix=add_prefix(idx, prefix))``
    and passed through maybe_offload_to_cpu().
    """
    modules = torch.nn.ModuleList(
        [
            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
            for idx in range(num_hidden_layers)
        ]
    )
    return modules


def set_random_seed(seed: int) -> None:
    """Set the random seed for all libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def is_port_available(port):
    """Return whether a port is available."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", port))
            s.listen(1)
            return True
        except (socket.error, OverflowError):
            # OverflowError: port outside 0-65535.
            return False


def decode_video_base64(video_base64):
    """Split a base64 blob of back-to-back image files into a stacked frame array.

    Frames are located by their file signature and decoded with PIL.

    Returns:
        ``(frames, size)`` where ``frames`` is ``np.stack`` of the decoded frames
        and ``size`` is the PIL ``.size`` of the last frame; ``(np.array([]), (0, 0))``
        when no frame signature is found.
    """
    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    # Hard-coded for now; the JPEG branch is kept for when this becomes configurable.
    frame_format = "PNG"
    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        signature = b"\x89PNG\r\n\x1a\n"  # 89 50 4E 47 0D 0A 1A 0A
    else:
        signature = b"\xff\xd8"  # JPEG SOI marker

    # Non-overlapping scan for every signature occurrence (same result as the
    # byte-by-byte loop that skipped past each match).
    img_starts = []
    pos = video_bytes.find(signature)
    while pos != -1:
        img_starts.append(pos)
        pos = video_bytes.find(signature, pos + len(signature))

    # Images are assumed back-to-back: each ends where the next begins, the
    # last one at the end of the byte string.
    img_ends = img_starts[1:] + [len(video_bytes)]

    frames = []
    img = None
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    # Ensure there's at least one frame to avoid errors with np.stack
    if frames:
        return np.stack(frames, axis=0), img.size
    # Return an empty array and size tuple if no frames were found
    return np.array([]), (0, 0)


def load_audio(
    audio_file: Union[str, bytes], sr: int = 16000, mono: bool = True
) -> np.ndarray:
    """Load audio from raw bytes, a ``data:`` URL, or a file path.

    The signal is resampled to ``sr`` when needed and averaged over the
    second axis when ``mono`` is set and it has more than one dimension.

    Raises:
        ValueError: if ``audio_file`` is neither ``bytes`` nor ``str``.
    """
    # Use soundfile here, since librosa use it under the hood,
    # and librosa will not support audio loading in the future
    import soundfile as sf
    from scipy.signal import resample

    # Load audio data
    if isinstance(audio_file, bytes):
        audio, original_sr = sf.read(BytesIO(audio_file))
    elif isinstance(audio_file, str) and audio_file.startswith("data:"):
        audio_file = audio_file.split(",")[1]
        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
    elif isinstance(audio_file, str):
        audio, original_sr = sf.read(audio_file)
    else:
        raise ValueError(f"Invalid audio format: {audio_file}")

    # Resample audio if the original sample rate is different from the desired sample rate
    if original_sr != sr:
        num_samples = int(len(audio) * float(sr) / original_sr)
        audio = resample(audio, num_samples)

    # Convert to mono if requested and audio is stereo
    if mono and len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)

    return audio


def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
    """Load an image from bytes, an http(s) URL, a file path, a ``data:`` URL,
    a ``video:``-prefixed base64 blob, or a bare base64 string.

    Returns:
        ``(image, image_size)``; ``image_size`` is only set for ``video:`` inputs
        (see decode_video_base64), otherwise ``None``.

    Raises:
        ValueError: if ``image_file`` is neither ``bytes`` nor ``str``.
    """
    image = image_size = None

    if isinstance(image_file, bytes):
        image = Image.open(BytesIO(image_file))
    elif not isinstance(image_file, str):
        # Checked before any str method is called so bad input gets a clear error.
        raise ValueError(f"Invalid image: {image_file}")
    elif image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, stream=True, timeout=timeout).raw
        image = Image.open(response)
        response.close()
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size


def suppress_other_loggers():
    """Quiet noisy vLLM loggers and a NumPy non-writable-array warning (if vLLM is installed)."""
    try:
        from vllm.logger import logger as vllm_default_logger
    except ImportError:
        return

    vllm_default_logger.setLevel(logging.WARN)
    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
        logging.WARN
    )
    logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
        logging.WARN
    )
    logging.getLogger("vllm.config").setLevel(logging.ERROR)

    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
    )


def assert_pkg_version(pkg: str, min_version: str, message: str):
    """Raise if ``pkg`` is missing or older than ``min_version``; ``message`` is appended."""
    try:
        installed_version = version(pkg)
        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
            raise Exception(
                f"{pkg} is installed with version {installed_version}, which "
                f"is less than the minimum required version {min_version}. " + message
            )
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        ) from None


def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """Kill the process and all its child processes.

    Args:
        parent_pid: root pid; ``None`` means the current process, whose
            children are killed but which itself is kept alive.
        include_parent: also kill ``parent_pid`` itself.
        skip_pid: a child pid to leave alone.
    """
    # Remove sigchld handler to avoid spammy logs.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return

    children = itself.children(recursive=True)
    for child in children:
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            if parent_pid == os.getpid():
                itself.kill()
                sys.exit(0)

            itself.kill()

            # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
            # so we send an additional signal to kill them.
            itself.send_signal(signal.SIGQUIT)
        except psutil.NoSuchProcess:
            pass


def monkey_patch_p2p_access_check():
    """
    Monkey patch the slow p2p access check.
    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
    """

    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

    # Suppress the warnings from this delete function when using sglang.bench_one_batch
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )

    setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


def monkey_patch_vllm_gguf_config():
    """Make vLLM's GGUFConfig recognise SGLang's LinearBase / VocabParallelEmbedding layers."""
    try:
        from vllm.model_executor.layers.quantization.gguf import (
            GGUFConfig,
            GGUFEmbeddingMethod,
            GGUFLinearMethod,
        )
    except ImportError:
        return

    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

    def get_quant_method_with_embedding_replaced(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
            # patch to own VocabParallelEmbedding
            return GGUFEmbeddingMethod(self)
        return None

    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)


def maybe_set_triton_cache_manager() -> None:
    """Set environment variable to tell Triton to use a custom cache manager"""
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager is None:
        manager = "sglang.srt.utils:CustomCacheManager"
        logger.debug("Setting Triton cache manager to: %s", manager)
        os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    """Triton FileCacheManager whose default cache dir is suffixed with the current pid."""

    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
    def __init__(self, key, override=False, dump=False):

        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = default_dump_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = default_override_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = (
                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            )
            if self.cache_dir:
                # Per-process directory so concurrent processes do not share it.
                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")


def set_ulimit(target_soft_limit=65535):
    """Raise the RLIMIT_NOFILE soft limit to ``target_soft_limit`` if it is lower."""
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")


def add_api_key_middleware(app, api_key: str):
    """Install an HTTP middleware requiring ``Authorization: Bearer <api_key>``.

    OPTIONS requests and the exact paths in ``exempt_paths`` bypass the check;
    any other request without the exact header receives a 401 JSON response.
    """
    import hmac

    # Exact-match paths that never require authentication.
    exempt_paths = frozenset({"/health", "/api/tags"})
    expected = ("Bearer " + api_key).encode("utf-8")

    @app.middleware("http")
    async def authentication(request, call_next):
        if request.method == "OPTIONS":
            return await call_next(request)
        if request.url.path in exempt_paths:
            return await call_next(request)
        provided = request.headers.get("Authorization")
        # Constant-time comparison so response timing does not reveal how much
        # of the key matched. UTF-8 encoding is injective, so bytes equality
        # matches the original str equality.
        if provided is None or not hmac.compare_digest(
            provided.encode("utf-8"), expected
        ):
            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)


def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    """Download model/tokenizer from ModelScope when SGLANG_USE_MODELSCOPE is set and the path is not local."""
    if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
        if not os.path.exists(model_path):
            from modelscope import snapshot_download

            model_path = snapshot_download(model_path)
            tokenizer_path = snapshot_download(
                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
            )
    return model_path, tokenizer_path


def configure_logger(server_args, prefix: str = ""):
    """Configure root logging from SGLANG_LOGGING_CONFIG_PATH, else basicConfig.

    Args:
        server_args: object whose ``log_level`` names a ``logging`` level.
        prefix: text inserted after the timestamp in each log line.

    Raises:
        Exception: if SGLANG_LOGGING_CONFIG_PATH points to a missing file.
    """
    if SGLANG_LOGGING_CONFIG_PATH := os.getenv("SGLANG_LOGGING_CONFIG_PATH"):
        if not os.path.exists(SGLANG_LOGGING_CONFIG_PATH):
            raise Exception(
                "Setting SGLANG_LOGGING_CONFIG_PATH from env with "
                f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
            )
        with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
            custom_config = json.loads(file.read())
        # `import logging` does not load the logging.config submodule, so import
        # dictConfig explicitly instead of relying on `logging.config` attribute.
        from logging.config import dictConfig

        dictConfig(custom_config)
        return
    log_format = f"[%(asctime)s{prefix}] %(message)s"
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )


# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
        setattr(weight, key, value)


def broadcast_pyobj(
    data: List[Any],
    rank: int,
    dist_group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
):
    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.

    The payload is pickled into a uint8 tensor preceded by a length broadcast;
    an empty list is sent as length 0. Returns ``data`` on rank 0 and the
    received list elsewhere.
    """

    if rank == 0:
        if len(data) == 0:
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(tensor_size, src=src, group=dist_group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            tensor_data = torch.ByteTensor(
                np.frombuffer(serialized_data, dtype=np.uint8)
            )
            tensor_size = torch.tensor([size], dtype=torch.long)

            dist.broadcast(tensor_size, src=src, group=dist_group)
            dist.broadcast(tensor_data, src=src, group=dist_group)
        return data
    else:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=src, group=dist_group)
        size = tensor_size.item()

        if size == 0:
            return []

        tensor_data = torch.empty(size, dtype=torch.uint8)
        dist.broadcast(tensor_data, src=src, group=dist_group)

        serialized_data = bytes(tensor_data.cpu().numpy())
        # SECURITY NOTE: unpickles bytes from the source rank; only safe when all
        # ranks in the group are trusted.
        data = pickle.loads(serialized_data)
        return data


step_counter = 0


def pytorch_profile(name, func, *args, data_size=-1):
    """
    Args:
        name (string): the name of recorded function.
        func: the function to be profiled.
        args: the arguments of the profiled function.
        data_size (int): some measurement of the computation complexity.
            Usually, it could be the batch size.

    Writes ``trace/size_<step>.json`` and ``trace/<name>_<step>.json`` and
    returns ``func(*args)``.
    """
    global step_counter
    os.makedirs("trace", exist_ok=True)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function(name):
            with open(f"trace/size_{step_counter}.json", "w") as f:
                json.dump({"size": data_size}, f)
            result = func(*args)
    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
    step_counter += 1
    return result


def get_zmq_socket(
    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
    """Create a PUSH/PULL/DEALER socket with unlimited HWM, then bind or connect it.

    Kernel buffer sizes are set to 0.5 GiB when the host has >32 GiB total and
    >16 GiB available memory, otherwise left at the OS default (-1).

    Raises:
        ValueError: for any other socket type.
    """
    mem = psutil.virtual_memory()
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)
    else:
        buf_size = -1

    socket = context.socket(socket_type)

    def set_send_opt():
        socket.setsockopt(zmq.SNDHWM, 0)
        socket.setsockopt(zmq.SNDBUF, buf_size)

    def set_recv_opt():
        socket.setsockopt(zmq.RCVHWM, 0)
        socket.setsockopt(zmq.RCVBUF, buf_size)

    if socket_type == zmq.PUSH:
        set_send_opt()
    elif socket_type == zmq.PULL:
        set_recv_opt()
    elif socket_type == zmq.DEALER:
        set_send_opt()
        set_recv_opt()
    else:
        raise ValueError(f"Unsupported socket type: {socket_type}")

    if bind:
        socket.bind(endpoint)
    else:
        socket.connect(endpoint)

    return socket


def dump_to_file(dirpath, name, value):
    """Save tensor ``value`` as ``pytorch_dump_<name>.npy`` (TP rank 0 only; bf16 upcast to fp32)."""
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() != 0:
        return

    os.makedirs(dirpath, exist_ok=True)
    if value.dtype is torch.bfloat16:
        value = value.float()
    value = value.cpu().numpy()
    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
    logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
    np.save(output_filename, value)


def is_triton_3():
    """Return True if the installed Triton major version is 3."""
    return triton.__version__.startswith("3.")


def maybe_torch_compile(*args, **kwargs):
    """
    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
    Therefore, we disable it here.
    """

    def decorator(func):
        if is_triton_3():
            return torch.compile(*args, **kwargs)(func)
        return func

    return decorator


def delete_directory(dirpath):
    """Recursively delete ``dirpath``, printing a warning instead of raising on OSError."""
    try:
        # This will remove the directory and all its contents
        shutil.rmtree(dirpath)
    except OSError as e:
        print(f"Warning: {dirpath} : {e.strerror}")


# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir: tempfile.TemporaryDirectory


def set_prometheus_multiproc_dir():
    """Create a temp dir for prometheus multiprocess mode and export PROMETHEUS_MULTIPROC_DIR.

    If the user already set PROMETHEUS_MULTIPROC_DIR, the temp dir is created
    inside it.
    """
    # Set prometheus multiprocess directory
    # sglang uses prometheus multiprocess mode
    # we need to set this before importing prometheus_client
    # https://prometheus.github.io/client_python/multiprocess/
    global prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
        )
    else:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def add_prometheus_middleware(app):
    """Mount a multiprocess prometheus ``/metrics`` ASGI app on ``app``."""
    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

    # Workaround for 307 Redirect for /metrics.
    # Starlette's Mount reads the remaining path from a group named "path";
    # "(?P.*)" without a name is not even a valid regex.
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def bind_port(port):
    """Bind to a specific port, assuming it's available."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
    sock.bind(("", port))
    sock.listen(1)
    return sock


def get_amdgpu_memory_capacity():
    """Return the smallest GPU memory pool size parsed from ``rocminfo`` output.

    Raw ``Size:`` values are divided by 1024; per the original comment the
    result is in MiB.
    """
    try:
        # Run rocminfo and capture the output (shell pipeline, so pass a string).
        result = subprocess.run(
            "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values in MiB; skip blank lines so
        # empty output reaches the "No GPU memory values found." error.
        memory_values = [
            float(mem.split("(")[0].strip()) / 1024
            for mem in result.stdout.strip().split("\n")
            if mem.strip()
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
        )


def get_device_sm():
    """Return the CUDA compute capability as ``major * 10 + minor`` (0 without CUDA)."""
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0


def get_nvgpu_memory_capacity():
    """Return the smallest ``memory.total`` reported by nvidia-smi (MiB, ``nounits``)."""
    try:
        # Run nvidia-smi and capture the output
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values
        memory_values = [
            float(mem)
            for mem in result.stdout.strip().split("\n")
            if re.match(r"^\d+(\.\d+)?$", mem.strip())
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
        )


def get_hpu_memory_capacity():
    """Return the smallest "Total" memory value reported by ``hl-smi --query``."""
    try:
        # Run hl-smi and capture the output (shell pipeline, so pass a string).
        result = subprocess.run(
            "hl-smi --query | grep 'Total'",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values in MiB; skip blank lines.
        memory_values = [
            float(mem.split(" ")[-2])
            for mem in result.stdout.strip().split("\n")
            if mem.strip()
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "hl-smi not found. Ensure Habana drivers are installed and accessible."
        )


# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    """Create an extra process group without touching the default group.

    Ranks map identically (``i -> i``) inside the new group. ``store`` and
    ``init_method`` are mutually exclusive; ``init_method`` defaults to ``env://``.
    """
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # Compare numeric release tuples: a plain string comparison ranks "2.10"
    # below "2.6". Pre-releases such as "2.6.0a0" still count as 2.6.
    torch_release = pkg_version.parse(str(torch.__version__)).release
    pg_options_param_name = (
        "backend_options" if torch_release >= (2, 6) else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg


def crash_on_warnings():
    """Crash on warning if we are running CI tests (SGLANG_IS_IN_CI)."""
    return get_bool_env_var("SGLANG_IS_IN_CI")


@lru_cache(maxsize=None)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` as a warning the first time it is seen; repeats are suppressed."""
    # Set the stacklevel to 2 to print the caller's line info
    logger.warning(msg, stacklevel=2)


def get_device_name(device_id: int = 0) -> Optional[str]:
    """Return the accelerator name for ``device_id`` (CUDA, then XPU, then HPU), else None."""
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return torch.cuda.get_device_name(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.get_device_name(device_id)

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return torch.hpu.get_device_name(device_id)

    return None


@lru_cache(maxsize=1)
def is_habana_available() -> bool:
    """Return True if the ``habana_frameworks`` package is importable."""
    return find_spec("habana_frameworks") is not None


@lru_cache(maxsize=8)
def get_device(device_id: Optional[int] = None) -> str:
    """Return a device string such as ``"cuda"`` or ``"cuda:1"`` for the first available backend.

    Raises:
        ImportError: habana_frameworks is installed but its torch HPU module fails to import.
        RuntimeError: no CUDA, XPU or HPU device is available.
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        if device_id is None:
            return "cuda"
        return f"cuda:{device_id}"

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        if device_id is None:
            return "xpu"
        return f"xpu:{device_id}"

    if is_habana_available():
        try:
            import habana_frameworks.torch.hpu  # noqa: F401  (registers torch.hpu)

            if torch.hpu.is_available():
                if device_id is None:
                    return "hpu"
                return f"hpu:{device_id}"
        except ImportError as e:
            raise ImportError(
                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
            ) from e

    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")


@lru_cache(maxsize=1)
def get_device_count() -> int:
    """Return the device count of the first available backend (CUDA, XPU, HPU), else 0."""
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        try:
            return torch.cuda.device_count()
        except RuntimeError:
            return 0

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        try:
            return torch.xpu.device_count()
        except RuntimeError:
            return 0

    if is_habana_available():
        try:
            import habana_frameworks.torch.hpu  # noqa: F401

            if torch.hpu.is_available():
                return torch.hpu.device_count()
        except (ImportError, RuntimeError):
            return 0

    return 0  # No accelerators available


def get_device_core_count(device_id: int = 0) -> int:
    """Return the CUDA multiprocessor count for ``device_id`` (0 without CUDA)."""
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    return 0


def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return ``(major, minor)`` compute capability; ``(None, None)`` when unknown (e.g. HPU)."""
    major, minor = None, None
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
            "."
        )
        major, minor = int(major), int(minor)

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        try:
            # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
            # Update this once the support is available.
            # major, minor = torch.hpu.get_device_capability(device_id)
            major, minor = None, None
        except Exception as e:
            raise RuntimeError(
                f"An error occurred while getting device capability of hpu: {e}."
            ) from e

    return major, minor


def get_compiler_backend() -> str:
    """Return the torch.compile backend name: "hpu_backend" on HPU, else "inductor"."""
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"

    return "inductor"


sglang_lib = Library("sglang", "FRAGMENT")  # noqa


# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Return True if this PyTorch build provides ``torch.library.custom_op``.

    Some backends use PyTorch < 2.4.0, which does not have that API.
    """
    return hasattr(torch.library, "custom_op")


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the sglang library
    (`sglang_lib`, the "sglang" FRAGMENT library of this module).
    If you want to register it to a different library, you can pass the
    library object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: name of the op inside the target library.
        op_func: eager implementation; its signature is used to infer the
            op schema.
        mutates_args: names of the arguments `op_func` mutates.
        fake_impl: optional implementation registered via `_register_fake`.
        target_lib: library to register into; defaults to `sglang_lib`.
    """
    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or sglang_lib
    my_lib.define(op_name + schema_str)
    my_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)


def set_gpu_proc_affinity(
    tp_size: int,
    nnodes: int,
    gpu_id: int,
):
    """Pin the current process to a contiguous slice of physical CPU cores.

    The physical cores are divided evenly among the TP ranks of this node,
    and `gpu_id` selects the slice (wrapping modulo the core count, so
    several DP groups per node are handled). When logical and physical CPU
    counts differ (hyper-threading on), CPU ids offset by the physical core
    count are bound as well. This assumes the OS numbers hyper-thread
    siblings that way; verify on unusual topologies.
    """
    # current process
    pid = os.getpid()
    p = psutil.Process(pid)

    tp_size_per_node = tp_size // nnodes

    # total physical cores
    total_pcores = psutil.cpu_count(logical=False)
    # physical cores per TP (N.B. more Cores than GPUs on node)
    num_cores_bind = total_pcores // tp_size_per_node

    # able to handle multiple DP per node
    start_cpu_id = (gpu_id * num_cores_bind) % total_pcores
    end_cpu_id = start_cpu_id + num_cores_bind

    if psutil.cpu_count() != psutil.cpu_count(logical=False):
        # HT on
        lower_cpu_ids = list(range(start_cpu_id, end_cpu_id))
        upper_cpu_ids = [cpu_id + total_pcores for cpu_id in lower_cpu_ids]
        bind_cpu_ids = lower_cpu_ids + upper_cpu_ids
    else:
        # HT off
        bind_cpu_ids = list(range(start_cpu_id, end_cpu_id))

    # set cpu_affinity to current process
    p.cpu_affinity(bind_cpu_ids)
    logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")


@lru_cache(maxsize=2)
def disable_request_logging() -> bool:
    """Return True when SGLANG_DISABLE_REQUEST_LOGGING is truthy (cached)."""
    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
    """Count GPUs via NVML (CUDA) or amdsmi (ROCm), without initializing CUDA.

    If the NVML/amdsmi count is negative (amdsmi unavailable yields -1), this
    falls back to ``torch._C._cuda_getDeviceCount()``. Returns 0 when torch
    was not compiled with CUDA/ROCm support.
    """
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.version

    if not torch.cuda._is_compiled():
        return 0
    if is_hip():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = (
            torch.cuda._device_count_amdsmi()
            if (hasattr(torch.cuda, "_device_count_amdsmi"))
            else -1
        )
    else:
        raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


def dataclass_to_string_truncated(
    data, max_length=2048, skip_names: Optional[Set[str]] = None
):
    """Render `data` as a string, truncating long strings/lists/tuples.

    Strings and sequences longer than `max_length` are shown as their first
    and last `max_length // 2` items joined by " ... ". Dicts and dataclasses
    are rendered recursively. `skip_names` omits dict keys / dataclass
    fields, but only at the top level: recursive calls do not forward it.
    """
    if skip_names is None:
        skip_names = set()
    if isinstance(data, str):
        if len(data) > max_length:
            half_length = max_length // 2
            # Index from len(data) instead of using -half_length: when
            # half_length == 0, data[-0:] would return the whole string.
            tail = data[len(data) - half_length :]
            return f"{repr(data[:half_length])} ... {repr(tail)}"
        return repr(data)
    elif isinstance(data, (list, tuple)):
        if len(data) > max_length:
            half_length = max_length // 2
            tail = data[len(data) - half_length :]
            return str(data[:half_length]) + " ... " + str(tail)
        return str(data)
    elif isinstance(data, dict):
        return (
            "{"
            + ", ".join(
                f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                for k, v in data.items()
                if k not in skip_names
            )
            + "}"
        )
    elif dataclasses.is_dataclass(data):
        fields = dataclasses.fields(data)
        return (
            f"{data.__class__.__name__}("
            + ", ".join(
                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
                for f in fields
                if f.name not in skip_names
            )
            + ")"
        )
    else:
        return str(data)


def permute_weight(x: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D weight (b, n, k) into a tiled layout of the same shape.

    The tile shape depends on dtype: (16, 32 -> 4x8) for bf16/fp16,
    (16, 64 -> 4x16) for float8_e4m3fnuz/int8, and (16, 8 -> 2x4) otherwise.
    n must be a multiple of 16 and k a multiple of the per-dtype k-tile, or
    the `view` calls fail. Presumably this targets a ROCm kernel layout
    (fnuz fp8 is the ROCm format); confirm with the consumer.
    """
    b_ = x.shape[0]
    n_ = x.shape[1]
    k_ = x.shape[2]

    x_ = x
    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
        x_ = x_.view(int(b_), n_ // 16, 16, k_ // 32, 4, 8)
    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
        x_ = x_.view(int(b_), n_ // 16, 16, k_ // 64, 4, 16)
    else:
        # return x_
        x_ = x_.view(int(b_), n_ // 16, 16, k_ // 8, 2, 4)

    # Move the 16-row tile axis inside the k tiles, then restore the shape.
    x_ = x_.permute(0, 1, 3, 4, 2, 5)
    x_ = x_.contiguous()
    x_ = x_.view(*x.shape)
    return x_


class MultiprocessingSerializer:
    """Serialize objects with ForkingPickler (multiprocessing-aware reducers).

    WARNING: `deserialize` unpickles its input; only pass trusted data.
    """

    @staticmethod
    def serialize(obj):
        buf = io.BytesIO()
        ForkingPickler(buf).dump(obj)
        return buf.getvalue()

    @staticmethod
    def deserialize(data):
        return ForkingPickler.loads(data)


def debug_timing(func):
    """Decorator: log CUDA-event timing and throughput of `func` at DEBUG level.

    The token count is `len(indices)`, where `indices` is the keyword argument
    of that name or else the second positional argument. When DEBUG logging
    is disabled, `func` is called directly with no events recorded.
    """
    from functools import wraps

    # todo: replace with a more organized instrumentation
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not logger.isEnabledFor(logging.DEBUG):
            return func(*args, **kwargs)

        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        result = func(*args, **kwargs)
        toc.record()
        # Wait for the function to complete without synchronizing all ops on the GPU
        toc.synchronize()
        elapsed = tic.elapsed_time(toc)
        indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
        num_tokens = len(indices) if indices is not None else 0
        throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
        logger.debug(
            f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
        )
        return result

    return wrapper


def nullable_str(val: str):
    """Map "", None and the literal string "None" to None; else return `val`."""
    if not val or val == "None":
        return None
    return val


def pyspy_dump_schedulers():
    """py-spy dump on all scheduler in a local node."""
    try:
        pid = psutil.Process().pid
        # Command to run py-spy with the PID. The PID is an int from psutil,
        # so the shell string cannot be injected into.
        cmd = f"py-spy dump --pid {pid}"
        result = subprocess.run(
            cmd, shell=True, capture_output=True, text=True, check=True
        )
        logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
    except subprocess.CalledProcessError as e:
        logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")


def kill_itself_when_parent_died():
    """Ask the Linux kernel to SIGKILL this process when its parent dies.

    Uses prctl(PR_SET_PDEATHSIG); on other platforms only a warning is logged.
    """
    if sys.platform == "linux":
        # sigkill this process when parent worker manager dies
        PR_SET_PDEATHSIG = 1
        libc = ctypes.CDLL("libc.so.6")
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
    else:
        logger.warning("kill_itself_when_parent_died is only supported in linux.")


def set_uvicorn_logging_configs():
    """Add timestamps to uvicorn's default and access log formats (in place)."""
    from uvicorn.config import LOGGING_CONFIG

    LOGGING_CONFIG["formatters"]["default"][
        "fmt"
    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
    LOGGING_CONFIG["formatters"]["access"][
        "fmt"
    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


def get_ip() -> str:
    """Best-effort discovery of this host's IP address.

    Precedence: the SGLANG_HOST_IP env var, then HOST_IP, then the local
    address chosen for a UDP socket "connected" to a public IPv4 (then IPv6)
    DNS address, and finally "0.0.0.0" with a warning.
    """
    # SGLANG_HOST_IP env can be ignore
    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    # The sockets are closed via `with`; the original code leaked them.

    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
            # Google's public DNS server, see
            # https://developers.google.com/speed/public-dns/docs/using#addresses
            s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable"
        " SGLANG_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"


def get_open_port() -> int:
    """Return a TCP port that was bindable at the time of the check.

    If SGLANG_PORT is set, probe upward from it until a bind succeeds.
    Otherwise let the OS pick an ephemeral port (IPv4, falling back to IPv6).
    The probe socket is closed before returning, so the port is not reserved.
    """
    port = os.getenv("SGLANG_PORT")
    if port is not None:
        port = int(port)
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d", port - 1, port)
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def is_valid_ipv6_address(address: str) -> bool:
    """Return True if `address` (without brackets) parses as an IPv6 address."""
    try:
        ipaddress.IPv6Address(address)
        return True
    except ValueError:
        return False


def configure_ipv6(dist_init_addr):
    """Parse "[ipv6]:port" into (port, host), where host keeps its brackets.

    Raises:
        ValueError: missing ']', invalid address, missing/invalid port, or a
            character other than ':' after ']'.
    """
    addr = dist_init_addr
    end = addr.find("]")
    if end == -1:
        raise ValueError("invalid IPv6 address format: missing ']'")

    host = addr[: end + 1]

    # this only validates the address without brackets: we still need the below checks.
    # if it's invalid, immediately raise an error so we know it's not formatting issues.
    if not is_valid_ipv6_address(host[1:end]):
        raise ValueError(f"invalid IPv6 address: {host}")

    port_str = None
    if len(addr) > end + 1:
        if addr[end + 1] == ":":
            port_str = addr[end + 2 :]
        else:
            raise ValueError("received IPv6 address format: expected ':' after ']'")

    if not port_str:
        raise ValueError(
            "a port must be specified in IPv6 address (format: [ipv6]:port)"
        )

    try:
        port = int(port_str)
    except ValueError as e:
        raise ValueError(f"invalid port in IPv6 address: '{port_str}'") from e
    return port, host


def rank0_print(msg: str):
    """Print `msg` (flushed) only on tensor-model-parallel rank 0."""
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() == 0:
        print(msg, flush=True)


def get_cuda_version():
    """Return torch's CUDA version as an int tuple ("12.4" -> (12, 4)).

    Returns (0, 0) when `torch.version.cuda` is unset (non-CUDA build).
    """
    if torch.version.cuda:
        return tuple(map(int, torch.version.cuda.split(".")))
    return (0, 0)


def launch_dummy_health_check_server(host, port):
    """Run a minimal FastAPI app via `uvicorn.run` whose /health and
    /health_generate endpoints always answer 200."""
    import uvicorn
    from fastapi import FastAPI, Response

    app = FastAPI()

    @app.get("/health")
    async def health():
        """Check the health of the http server."""
        return Response(status_code=200)

    @app.get("/health_generate")
    async def health_generate():
        """Check the health of the http server."""
        return Response(status_code=200)

    uvicorn.run(
        app,
        host=host,
        port=port,
        timeout_keep_alive=5,
        loop="uvloop",
    )


def create_checksum(directory: str):
    """Not implemented yet; always raises NotImplementedError."""
    raise NotImplementedError()


def set_cuda_arch():
    """Set TORCH_CUDA_ARCH_LIST to the current device's compute capability.

    Only runs when flashinfer is available; "+PTX" is appended for 9.0.
    """
    if is_flashinfer_available():
        capability = torch.cuda.get_device_capability()
        arch = f"{capability[0]}.{capability[1]}"
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"


def next_power_of_2(n: int):
    """Return the smallest power of two >= n (1 for n <= 0)."""
    return 1 << (n - 1).bit_length() if n > 0 else 1


# Install this implementation as `triton.next_power_of_2`, replacing any
# existing attribute on the triton module.
setattr(triton, "next_power_of_2", next_power_of_2)


@contextmanager
def empty_context(*args, **kwargs):
    """No-op context manager; accepts and ignores any arguments."""
    yield


def add_prefix(name: str, prefix: str) -> str:
    """Add a weight path prefix to a module name.

    Args:
        name: base module name.
        prefix: weight prefix str to be added to the front of `name`,
            joined with `.`.

    Returns:
        The string `prefix.name` if prefix is non-empty, otherwise just `name`.
    """
    return name if not prefix else f"{prefix}.{name}"


def is_remote_url(url: Union[str, Path]) -> bool:
    """
    Check if the URL is a remote URL of the format:
    <connector_type>://<host>:<port>/<model_name>

    Path objects are always treated as local.
    """
    if isinstance(url, Path):
        return False

    pattern = r"(.+)://(.*)"
    m = re.match(pattern, url)
    return m is not None


def parse_connector_type(url: str) -> str:
    """
    Parse the connector type from the URL of the format:
    <connector_type>://<path>

    Returns "" when `url` does not contain "://".
    """
    pattern = r"(.+)://(.*)"
    m = re.match(pattern, url)
    if m is None:
        return ""

    return m.group(1)