# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Generator, List, Optional, Tuple, Union
from urllib.parse import urlparse

import torch

from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db

logger = logging.getLogger(__name__)

# Fallback Redis endpoint used when the connector URL omits the host or port.
_DEFAULT_REDIS_HOST = "localhost"
_DEFAULT_REDIS_PORT = 6379

# Characters that carry special meaning in Redis glob-style MATCH patterns.
_REDIS_GLOB_SPECIAL = frozenset("\\*?[]")


def _escape_redis_glob(text: str) -> str:
    """Backslash-escape glob metacharacters so *text* matches literally in SCAN MATCH."""
    return "".join("\\" + ch if ch in _REDIS_GLOB_SPECIAL else ch for ch in text)


class RedisConnector(BaseKVConnector):
    """Key-value connector that stores tensors and strings in a Redis server.

    The connector URL is parsed with ``urllib.parse.urlparse``. Its host, port
    and optional username/password select the Redis server. The URL path, minus
    its leading ``/``, is the model name that namespaces the keys read by
    :meth:`weight_iterator` and :meth:`pull_files`. The URL scheme is not
    inspected.

    Tensors are converted to/from bytes with the serializer/deserializer pair
    returned by ``create_serde("safe")``.
    """

    def __init__(self, url: str, device: Union[str, torch.device] = "cpu"):
        """Connect to the Redis server described by *url*.

        Args:
            url: Connector URL, e.g. ``redis://[user:pass@]host[:port]/model_name``.
            device: Forwarded unchanged to ``BaseKVConnector.__init__``.
        """
        # Deferred import: presumably so `redis` is only required when this
        # connector is actually instantiated.
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(
            host=parsed_url.hostname or _DEFAULT_REDIS_HOST,
            # urlparse yields None when the URL has no explicit port; fall back
            # to the standard Redis port instead of handing None to the client.
            port=parsed_url.port or _DEFAULT_REDIS_PORT,
            # Credentials embedded in the URL were previously ignored.
            username=parsed_url.username,
            password=parsed_url.password,
        )
        self.model_name = parsed_url.path.lstrip("/")

        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Fetch *key* and deserialize it into a tensor.

        Returns:
            The deserialized tensor, or ``None`` (after logging an error) when
            the key does not exist.
        """
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Fetch *key* and decode its raw value as UTF-8 text.

        Returns:
            The decoded string, or ``None`` (after logging an error) when the
            key does not exist.
        """
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize *tensor* and store it under *key*, overwriting any previous value."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store the string *obj* under *key*; read it back with :meth:`getstr`."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return every key that starts with *prefix*, decoded as UTF-8.

        Keys are collected with the incremental SCAN command. The prefix is
        glob-escaped so it is matched literally. SCAN may report the same key
        more than once, so the result is deduplicated in first-seen order.
        """
        pattern = f"{_escape_redis_glob(prefix)}*"
        cursor = 0
        # A dict used as an insertion-ordered set removes SCAN duplicates.
        seen: dict = {}
        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=pattern
            )  # type: ignore
            cursor, keys = ret
            seen.update(dict.fromkeys(keys))
            # A returned cursor of 0 means the full iteration is complete.
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in seen]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, Optional[torch.Tensor]], None, None]:
        """Yield ``(name, tensor)`` for every weight stored for *rank*.

        Weights live under ``{model_name}/keys/rank_{rank}/<name>``. The yielded
        name has that prefix stripped. The tensor is whatever :meth:`get`
        returns, so it is ``None`` if a key vanished between listing and fetching.
        """
        prefix = f"{self.model_name}/keys/rank_{rank}/"
        for key in self.list(prefix):
            val = self.get(key)
            yield key.removeprefix(prefix), val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Delegate to ``pull_files_from_db`` for this connector's model name.

        *allow_pattern* and *ignore_pattern* are forwarded unchanged; presumably
        they filter which stored files are pulled — see ``pull_files_from_db``.
        """
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        """Close the Redis connection, then run the base-class cleanup."""
        self.connection.close()
        super().close()