86 lines
2.5 KiB
Python
86 lines
2.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import logging
|
|
from typing import Generator, List, Optional, Tuple
|
|
from urllib.parse import urlparse
|
|
|
|
import torch
|
|
|
|
from sglang.srt.connector import BaseKVConnector
|
|
from sglang.srt.connector.serde import create_serde
|
|
from sglang.srt.connector.utils import pull_files_from_db
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisConnector(BaseKVConnector):
    """Key-value connector that stores tensors and strings in a Redis server.

    Tensors are converted to and from bytes with the serializer/deserializer
    pair returned by ``create_serde("safe")``; strings are stored as-is.  The
    path component of the connection URL (without its leading ``/``) is kept
    as ``model_name`` and is used to namespace weight keys as
    ``<model_name>/keys/rank_<rank>/<name>``.
    """

    # Port used when the connection URL does not carry one (Redis' standard port).
    _DEFAULT_PORT = 6379

    def __init__(self, url: str, device: torch.device = "cpu"):
        """Open a Redis client for the server named in *url*.

        Args:
            url: Connection URL; host and port select the server and the path
                is used as the model name.
            device: Forwarded unchanged to the base class constructor.
        """
        # Imported here rather than at module level, so redis is only required
        # once a connector is actually constructed.
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)

        # urlparse reports ``None`` when the URL has no explicit port; fall
        # back to the standard Redis port instead of handing ``None`` over.
        port = parsed_url.port
        if port is None:
            port = self._DEFAULT_PORT

        self.connection = redis.Redis(host=parsed_url.hostname, port=port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so *text* matches literally."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in text)

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Fetch *key* and deserialize its value.

        Returns:
            The deserialized value, or ``None`` (after logging an error) when
            the key does not exist.
        """
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Fetch *key* and decode its value as UTF-8 text.

        Returns:
            The decoded string, or ``None`` (after logging an error) when the
            key does not exist.
        """
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize *tensor* and store it under *key*."""
        assert tensor is not None, "tensor must not be None"
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store the string *obj* under *key*."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return every key that starts with *prefix*, each exactly once.

        The keyspace is walked with SCAN until the server hands back cursor 0.
        *prefix* is matched literally: glob metacharacters in it are escaped
        before the trailing ``*`` wildcard is appended.

        Returns:
            The matching keys decoded as UTF-8, in first-seen order.
        """
        pattern = self._escape_glob(prefix) + "*"
        cursor = 0
        # SCAN may report the same key more than once during a full iteration;
        # a dict gives order-preserving de-duplication.
        seen: dict = {}

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=pattern
            )  # type: ignore
            cursor, keys = ret
            for key in keys:
                seen.setdefault(key, None)
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in seen]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, Optional[torch.Tensor]], None, None]:
        """Yield ``(name, tensor)`` for every weight stored for *rank*.

        Keys are looked up under ``<model_name>/keys/rank_<rank>/`` and that
        prefix is stripped from the yielded name.  The tensor is whatever
        ``get`` returns, so it is ``None`` if a key disappears between the
        listing and the fetch (``get`` logs that case).
        """
        prefix = f"{self.model_name}/keys/rank_{rank}/"
        for key in self.list(prefix):
            val = self.get(key)
            yield key.removeprefix(prefix), val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Delegate to ``pull_files_from_db`` for this connector's model name.

        Args:
            allow_pattern: Passed through to ``pull_files_from_db``.
            ignore_pattern: Passed through to ``pull_files_from_db``.
        """
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        """Close the Redis connection, then run the base class cleanup."""
        try:
            self.connection.close()
        finally:
            # Base-class cleanup must run even if closing the client fails.
            super().close()
|