# sglang0.4.5.post1/python/sglang/srt/connector/redis.py
#
# 86 lines
# 2.5 KiB
# Python

# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse
import torch
from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db
# Module-level logger; RedisConnector uses it to report keys missing from Redis.
logger = logging.getLogger(__name__)
class RedisConnector(BaseKVConnector):
    """KV connector that reads and writes tensors and strings in a Redis server.

    The connector URL is parsed with ``urlparse``: its host, port, username and
    password select the Redis server, and its path (leading ``/`` stripped) is
    used as the model name that namespaces keys, e.g. weights for a rank are
    stored under ``{model_name}/keys/rank_{rank}/``.
    """

    # Translation table that backslash-escapes the characters Redis treats as
    # glob metacharacters in a SCAN MATCH pattern, so a prefix matches literally.
    _GLOB_ESCAPES = str.maketrans({c: "\\" + c for c in "\\*?[]"})

    def __init__(self, url: str, device: torch.device = "cpu"):
        """Open a Redis client for ``url``.

        Args:
            url: Connector URL; server location/credentials plus ``/<model_name>``.
            device: Forwarded unchanged to ``BaseKVConnector.__init__``.
        """
        # NOTE(review): presumably imported locally so ``redis`` is only
        # required when this connector is actually used.
        import redis
        from urllib.parse import unquote

        super().__init__(url, device)
        parsed_url = urlparse(url)
        username = parsed_url.username
        password = parsed_url.password
        # urlparse reports missing components as None; fall back to redis-py's
        # default server (localhost:6379) instead of passing None through.
        # Credentials are percent-decoded, since urlparse leaves them encoded.
        self.connection = redis.Redis(
            host=parsed_url.hostname or "localhost",
            port=parsed_url.port or 6379,
            username=unquote(username) if username else None,
            password=unquote(password) if password else None,
        )
        self.model_name = parsed_url.path.lstrip("/")
        # Serializer (to_bytes) / deserializer (from_bytes) pair for tensors.
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def _get_raw(self, key: str) -> Optional[bytes]:
        """Return the raw value stored at ``key``; log an error and return None if absent."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
        return val

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the tensor stored at ``key``, or None if the key does not exist."""
        val = self._get_raw(key)
        return None if val is None else self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Return the UTF-8 string stored at ``key``, or None if the key does not exist."""
        val = self._get_raw(key)
        return None if val is None else val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize ``tensor`` and store it at ``key`` (overwriting any value)."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store the string ``obj`` at ``key`` (overwriting any value)."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return all keys starting with ``prefix``, decoded as UTF-8.

        Keys are collected with an incremental SCAN until the cursor returns
        to 0. SCAN may report the same key more than once, so duplicates are
        dropped (first-seen order kept). Glob metacharacters in ``prefix`` are
        escaped so it is matched literally rather than as a pattern.
        """
        pattern = prefix.translate(self._GLOB_ESCAPES) + "*"
        cursor = 0
        all_keys: List[bytes] = []
        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=pattern
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            if cursor == 0:
                break
        return [key.decode("utf-8") for key in dict.fromkeys(all_keys)]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, Optional[torch.Tensor]], None, None]:
        """Yield ``(name, tensor)`` for every weight stored for ``rank``.

        Weights live under ``{model_name}/keys/rank_{rank}/``; ``name`` is the
        key with that prefix removed. Each tensor is fetched with its own GET
        after listing, so a key deleted in between yields None (and is logged).
        """
        prefix = f"{self.model_name}/keys/rank_{rank}/"
        for full_key in self.list(prefix):
            yield full_key.removeprefix(prefix), self.get(full_key)

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download this model's stored files via ``pull_files_from_db``."""
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        """Close the Redis client, then let the base connector clean up."""
        self.connection.close()
        super().close()