sglang0.4.5.post1/python/sglang/srt/connector/base_connector.py

113 lines
2.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple
import torch
class BaseConnector(ABC):
    """Abstract base for connectors addressed by a URL.

    URL layouts:

    For fs connector such as s3:
        <connector_type>://<path>/<filename>

    For kv connector such as redis:
        <connector_type>://<host>:<port>/<model_name>/keys/<key>
        <connector_type>://<host>:<port>/<model_name>/files/<filename>

    Every instance creates a temporary directory (``local_dir``) that is
    deleted by :meth:`close`. ``close`` runs when a ``with`` block exits,
    when the object is finalized, and on SIGINT/SIGTERM when the signal
    handlers could be installed.
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        """Store the connection settings and create the temp directory.

        Args:
            url: Connector URL (see the class docstring for the layouts).
            device: Stored on ``self.device``; defaults to the string
                ``"cpu"``.
        """
        self.url = url
        self.device = device
        # Create the directory before setting ``closed`` so that the mere
        # presence of ``closed`` guarantees ``local_dir`` exists too.
        self.local_dir = tempfile.mkdtemp()
        self.closed = False

        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            try:
                signal.signal(sig, self._close_by_signal(existing_handler))
            except ValueError:
                # signal.signal() is only permitted in the main thread of
                # the main interpreter. Cleanup then relies on close(),
                # the context manager protocol, or __del__.
                break

    def get_local_dir(self):
        """Return the path of this connector's temporary directory."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(str, torch.Tensor)`` pairs; implemented by subclasses.

        Args:
            rank: Integer rank, default 0 (meaning defined by subclasses).
        """
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Pull files selected by the patterns; implemented by subclasses.

        Args:
            allow_pattern: Optional list of patterns to include.
            ignore_pattern: Optional list of patterns to exclude.
        """
        raise NotImplementedError()

    def close(self):
        """Delete the temporary directory. Safe to call more than once."""
        # A missing ``closed`` attribute means __init__ never got far enough
        # to create anything, so there is nothing to clean up.
        if getattr(self, "closed", True):
            return
        self.closed = True
        try:
            shutil.rmtree(self.local_dir)
        except FileNotFoundError:
            # Directory already removed by someone else; nothing left to do.
            pass

    def __enter__(self):
        """Return ``self`` so the connector can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connector when the ``with`` block exits."""
        self.close()

    def __del__(self):
        """Close the connector when the object is finalized."""
        self.close()

    def _close_by_signal(self, existing_handler=None):
        """Build a signal handler that closes this connector first.

        Args:
            existing_handler: The previously installed handler as returned
                by ``signal.getsignal``: a callable, ``signal.SIG_DFL``,
                ``signal.SIG_IGN`` or ``None``.

        Returns:
            A ``handler(signum, frame)`` callable that calls :meth:`close`
            and then defers to the previous disposition.
        """

        def new_handler(signum, frame):
            self.close()
            if callable(existing_handler):
                existing_handler(signum, frame)
            elif existing_handler is signal.SIG_DFL:
                # Restore the default disposition and re-deliver the signal
                # so the default action (e.g. terminating on SIGTERM) still
                # happens instead of being silently swallowed.
                signal.signal(signum, signal.SIG_DFL)
                os.kill(os.getpid(), signum)
            # SIG_IGN / None: the signal was ignored before; keep ignoring.

        return new_handler
class BaseKVConnector(BaseConnector):
    """Abstract interface for key-value style connectors.

    Subclasses provide tensor and string accessors keyed by string, plus a
    way to list keys by prefix.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the tensor stored under ``key``.

        Annotated as optional; presumably ``None`` means the key is absent
        (confirm against concrete implementations).
        """
        raise NotImplementedError

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Return the string stored under ``key`` (annotated as optional)."""
        raise NotImplementedError

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store the tensor ``obj`` under ``key``."""
        raise NotImplementedError

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store the string ``obj`` under ``key``."""
        raise NotImplementedError

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return a list of strings for ``prefix``.

        Presumably the keys that start with ``prefix`` (confirm against
        concrete implementations).
        """
        raise NotImplementedError
class BaseFileConnector(BaseConnector):
    """Abstract interface for file-system style connectors.

    Subclasses implement :meth:`glob` to list remote files.
    """

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        """List full file names from the remote fs path, filtered by pattern.

        Args:
            allow_pattern: Pattern selecting which files to list (annotated
                as a single string).

        Returns:
            List[str]: Full paths allowed by the pattern.
        """
        raise NotImplementedError