# Source: sglang 0.4.5.post1 — python/sglang/srt/connector/s3.py
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple
import torch
from sglang.srt.connector import BaseFileConnector
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List object keys under an S3 path and filter them by glob patterns.

    Args:
        s3: boto3 S3 client to use.
        path: The S3 path ("s3://bucket/prefix") to list from.
        allow_pattern: A list of patterns of which files to pull
            (all files are kept when None).
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is the key prefix (the path inside the
              bucket, as a dir-like string)
            - The third element is the list of object keys allowed or
              disallowed by the patterns
    """
    parts = path.removeprefix("s3://").split("/")
    bucket_name = parts[0]
    prefix = "/".join(parts[1:])

    # Use a paginator: a single list_objects_v2 call returns at most 1000
    # keys, which would silently truncate large model directories.
    paths: list[str] = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        paths.extend(obj["Key"] for obj in page.get("Contents", []))

    # Drop "directory" placeholder keys (keys ending in "/").
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)
    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths
class S3Connector(BaseFileConnector):
    """File connector that pulls model files from S3 (``s3://`` URLs)."""

    def __init__(self, url: str) -> None:
        # Imported lazily so boto3 is only required when S3 is actually used.
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        """Return fully-qualified s3:// URLs under self.url that match the patterns."""
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files under self.url from S3 storage into self.local_dir.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        for file in files:
            # Strip any leading "/" left over when base_dir lacks a trailing
            # slash; otherwise os.path.join would treat the remainder as an
            # absolute path and discard self.local_dir entirely.
            relative_path = file.removeprefix(base_dir).lstrip("/")
            destination_file = os.path.join(self.local_dir, relative_path)
            os.makedirs(Path(destination_file).parent, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (name, tensor) pairs from the checkpoint files under self.url."""
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        """Release the boto3 client, then run base-class cleanup."""
        self.client.close()
        super().close()