"""Common utilities"""
|
|
|
|
import base64
|
|
import importlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import signal
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import traceback
|
|
import urllib.request
|
|
import weakref
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from io import BytesIO
|
|
from json import dumps
|
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
|
|
|
import numpy as np
|
|
import requests
|
|
from IPython.display import HTML, display
|
|
from pydantic import BaseModel
|
|
from tqdm import tqdm
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
|
|
logger = logging.getLogger(__name__)


def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
    """Convert a JSON schema to a string.

    Parameters
    ----------
    json_schema
        The JSON schema.

    Returns
    -------
    str
        The JSON schema converted to a string.

    Raises
    ------
    ValueError
        If the schema is not a dictionary, a string or a Pydantic class.
    """
    if isinstance(json_schema, dict):
        schema_str = json.dumps(json_schema)
    elif isinstance(json_schema, str):
        schema_str = json_schema
    elif issubclass(json_schema, BaseModel):
        schema_str = json.dumps(json_schema.model_json_schema())
    else:
        raise ValueError(
            f"Cannot parse schema {json_schema}. The schema must be either "
            + "a Pydantic class, a dictionary or a string that contains the JSON "
            + "schema specification"
        )
    return schema_str
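
# Usage sketch (the `Answer` model below is hypothetical, defined only for
# this example):
#
#     class Answer(BaseModel):
#         answer: str
#         confidence: float
#
#     schema_str = convert_json_schema_to_str(Answer)
#     assert json.loads(schema_str)["title"] == "Answer"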


def get_exception_traceback():
    """Return the current exception's traceback as a string."""
    etype, value, tb = sys.exc_info()
    err_str = "".join(traceback.format_exception(etype, value, tb))
    return err_str


def is_same_type(values: list):
    """Return whether the elements in values are of the same type."""
    if len(values) <= 1:
        return True
    else:
        t = type(values[0])
        return all(isinstance(v, t) for v in values[1:])


def read_jsonl(filename: str):
    """Read a JSONL file, skipping comment lines that start with '#'."""
    with open(filename) as fin:
        for line in fin:
            if line.startswith("#"):
                continue
            yield json.loads(line)
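
# Usage sketch (the file name and record key are hypothetical):
#
#     for record in read_jsonl("data.jsonl"):
#         print(record["text"])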


def dump_state_text(filename: str, states: list, mode: str = "w"):
    """Dump program states to a text file."""
    from sglang.lang.interpreter import ProgramState

    with open(filename, mode) as fout:
        for i, s in enumerate(states):
            if isinstance(s, str):
                pass
            elif isinstance(s, ProgramState):
                s = s.text()
            else:
                s = str(s)

            fout.write(
                "=" * 40 + f" {i} " + "=" * 40 + "\n" + s + "\n" + "=" * 80 + "\n\n"
            )


class HttpResponse:
    """A thin wrapper that gives urllib responses a requests-like interface."""

    def __init__(self, resp):
        self.resp = resp

    def json(self):
        return json.loads(self.resp.read())

    @property
    def status_code(self):
        return self.resp.status


def http_request(
    url,
    json=None,
    stream=False,
    api_key=None,
    verify=None,
    method: Optional[str] = None,
):
    """A faster version of requests.post that uses the low-level urllib API."""
    headers = {"Content-Type": "application/json; charset=utf-8"}

    # Add the Authorization header if an API key is provided.
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    if stream:
        return requests.post(url, json=json, stream=True, headers=headers)
    else:
        req = urllib.request.Request(url, headers=headers, method=method)
        if json is None:
            data = None
        else:
            data = bytes(dumps(json), encoding="utf-8")

        try:
            # Note: `cafile` was deprecated in Python 3.6 and removed in 3.13;
            # newer code should pass an ssl.SSLContext via `context=` instead.
            resp = urllib.request.urlopen(req, data=data, cafile=verify)
            return HttpResponse(resp)
        except urllib.error.HTTPError as e:
            return HttpResponse(e)
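
# Usage sketch (the endpoint URL and payload are illustrative):
#
#     resp = http_request(
#         "http://localhost:30000/generate",
#         json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
#     )
#     if resp.status_code == 200:
#         print(resp.json())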


def encode_image_base64(image_path: Union[str, bytes]):
    """Encode an image in base64."""
    if isinstance(image_path, str):
        with open(image_path, "rb") as image_file:
            data = image_file.read()
            return base64.b64encode(data).decode("utf-8")
    elif isinstance(image_path, bytes):
        return base64.b64encode(image_path).decode("utf-8")
    else:
        # Assume image_path is already a PIL image object
        # (e.g., PIL.WebPImagePlugin.WebPImageFile).
        image = image_path
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
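
# Usage sketch (the file path is hypothetical):
#
#     b64 = encode_image_base64("example.png")
#     data_uri = f"data:image/png;base64,{b64}"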


def encode_frame(frame):
    import cv2  # pip install opencv-python-headless
    from PIL import Image

    # Convert the frame to RGB (OpenCV uses BGR by default)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Convert the frame to a PIL Image to easily convert to bytes
    im_pil = Image.fromarray(frame)

    # Convert to bytes
    buffered = BytesIO()
    im_pil.save(buffered, format="PNG")
    frame_bytes = buffered.getvalue()

    # Return the bytes of the frame
    return frame_bytes


def encode_video_base64(video_path: str, num_frames: int = 16):
    import cv2  # pip install opencv-python-headless

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"target_frames: {num_frames}")

    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

    # Read all frames; frames that fail to decode are simply skipped.
    frames = []
    for _ in range(total_frames):
        ret, frame = cap.read()
        if ret:
            frames.append(frame)

    cap.release()

    if not frames:
        raise ValueError(f"Could not read any frames from video file: {video_path}")

    # Safely select frames based on frame_indices, avoiding IndexError
    frames = [frames[i] for i in frame_indices if i < len(frames)]

    # If there are not enough frames, duplicate the last frame until we reach the target
    while len(frames) < num_frames:
        frames.append(frames[-1])

    # Use ThreadPoolExecutor to process and encode frames in parallel
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))

    # Concatenate the bytes of all frames
    video_bytes = b"".join(encoded_frames)

    # Encode the concatenated bytes to base64
    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")

    return video_base64
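
# Usage sketch (the video path is hypothetical):
#
#     video_b64 = encode_video_base64("clip.mp4", num_frames=8)
#     # -> "video:<base64 of 8 PNG-encoded frames, concatenated>"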


def _is_chinese_char(cp: int):
    """Check whether cp is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as are Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)
        or (cp >= 0x20000 and cp <= 0x2A6DF)
        or (cp >= 0x2A700 and cp <= 0x2B73F)
        or (cp >= 0x2B740 and cp <= 0x2B81F)
        or (cp >= 0x2B820 and cp <= 0x2CEAF)
        or (cp >= 0xF900 and cp <= 0xFAFF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True

    return False


def find_printable_text(text: str):
    """Return the longest printable substring of text that contains only entire words."""
    # Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99

    # After the symbol for a new line, we flush the cache.
    if text.endswith("\n"):
        return text
    # If the last token is a CJK character, we print the characters.
    elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
        return text
    # Otherwise, if the penultimate token is a CJK character, we print the characters except for the last one.
    elif len(text) > 1 and _is_chinese_char(ord(text[-2])):
        return text[:-1]
    # Otherwise, print up to the last space character (a simple heuristic to avoid printing incomplete words,
    # which may change with the subsequent token -- there are probably smarter ways to do this!)
    else:
        return text[: text.rfind(" ") + 1]
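
# Behavior sketch: for a partial stream like "Hello wor", only the completed
# word is considered printable; a newline flushes everything.
#
#     find_printable_text("Hello wor")      # -> "Hello "
#     find_printable_text("Hello world\n")  # -> "Hello world\n"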


def graceful_registry(sub_module_name: str):
    """Register a graceful SIGTERM handler for the given submodule."""

    def graceful_shutdown(signum, frame):
        logger.info(
            f"{sub_module_name} received signal to shutdown. Performing graceful shutdown..."
        )
        if signum == signal.SIGTERM:
            logger.info(f"{sub_module_name} received SIGTERM")

    signal.signal(signal.SIGTERM, graceful_shutdown)


class LazyImport:
    """Lazy import to make `import sglang` run faster."""

    def __init__(self, module_name: str, class_name: str):
        self.module_name = module_name
        self.class_name = class_name
        self._module = None

    def _load(self):
        if self._module is None:
            module = importlib.import_module(self.module_name)
            self._module = getattr(module, self.class_name)
        return self._module

    def __getattr__(self, name: str):
        module = self._load()
        return getattr(module, name)

    def __call__(self, *args, **kwargs):
        module = self._load()
        return module(*args, **kwargs)
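
# Usage sketch, mirroring how sglang's top-level package defers backend imports
# (the model name string is illustrative):
#
#     Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
#     backend = Anthropic("claude-3-haiku-20240307")  # real import happens here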


def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Download and cache a file from a URL."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    # Stream the response to show the progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    # Total size of the file in bytes
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024  # Download in chunks of 1 KiB

    # Use tqdm to display the progress bar
    with open(filename, "wb") as f, tqdm(
        desc=filename,
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            bar.update(len(chunk))

    return filename
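
# Usage sketch (the URL is illustrative):
#
#     path = download_and_cache_file("https://example.com/datasets/sample.jsonl")
#     rows = list(read_jsonl(path))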


def is_in_ci():
    """Return whether the code is running in a CI environment."""
    from sglang.test.test_utils import is_in_ci

    return is_in_ci()


def print_highlight(html_content: str):
    """Print a message, rendered as highlighted HTML when running in CI notebooks."""
    if is_in_ci():
        html_content = str(html_content).replace("\n", "<br>")
        display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
    else:
        print(html_content)


# Maps a server process to the socket holding its reserved port; entries vanish
# automatically when the process object is garbage collected.
process_socket_map = weakref.WeakKeyDictionary()


def reserve_port(host, start=30000, end=40000):
    """
    Reserve an available port by trying to bind a socket.
    Returns a tuple (port, lock_socket) where `lock_socket` is kept open to hold the lock.
    """
    candidates = list(range(start, end))
    random.shuffle(candidates)

    for port in candidates:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            # Attempt to bind to the port on the given host
            sock.bind((host, port))
            return port, sock
        except socket.error:
            sock.close()  # Failed to bind, try the next port
            continue
    raise RuntimeError("No free port available.")


def release_port(lock_socket):
    """
    Release the reserved port by closing the lock socket.
    """
    try:
        lock_socket.close()
    except Exception as e:
        print(f"Error closing socket: {e}")


def execute_shell_command(command: str) -> subprocess.Popen:
    """
    Execute a shell command and return its process handle.
    Backslash line continuations are flattened into spaces before splitting.
    """
    command = command.replace("\\\n", " ").replace("\\", " ")
    parts = command.split()
    return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)


def launch_server_cmd(command: str, host: str = "0.0.0.0", port: Optional[int] = None):
    """
    Launch the server using the given command.
    If no port is specified, a free port is reserved.
    """
    if port is None:
        port, lock_socket = reserve_port(host)
    else:
        lock_socket = None

    full_command = f"{command} --port {port}"
    process = execute_shell_command(full_command)

    if lock_socket is not None:
        process_socket_map[process] = lock_socket

    return process, port


def terminate_process(process):
    """
    Terminate the process and automatically release the reserved port.
    """
    kill_process_tree(process.pid)

    lock_socket = process_socket_map.pop(process, None)
    if lock_socket is not None:
        release_port(lock_socket)


def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
    """Wait for the server to be ready by polling the /v1/models endpoint.

    Args:
        base_url: The base URL of the server
        timeout: Maximum time to wait in seconds. None means wait forever.
    """
    start_time = time.time()
    while True:
        # Check the timeout outside the try block so that connection errors
        # (which are swallowed below) cannot keep us waiting forever.
        if timeout and time.time() - start_time > timeout:
            raise TimeoutError("Server did not become ready within timeout period")

        try:
            response = requests.get(
                f"{base_url}/v1/models",
                headers={"Authorization": "Bearer None"},
            )
            if response.status_code == 200:
                time.sleep(5)
                print_highlight(
                    """\n
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
"""
                )
                break
        except requests.exceptions.RequestException:
            time.sleep(1)
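
# Lifecycle sketch (the launch command and model path are illustrative):
#
#     server_process, port = launch_server_cmd(
#         "python -m sglang.launch_server --model-path Qwen/Qwen2.5-0.5B-Instruct"
#     )
#     wait_for_server(f"http://localhost:{port}")
#     ...  # send requests to the server
#     terminate_process(server_process)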


class TypeBasedDispatcher:
    """Dispatch an object to a handler function based on its type."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj}")


def trim_overlap(existing_text, new_chunk):
    """
    Finds the largest suffix of 'existing_text' that is a prefix of 'new_chunk'
    and removes that overlap from the start of 'new_chunk'.
    """
    max_overlap = 0
    max_possible = min(len(existing_text), len(new_chunk))
    for i in range(max_possible, 0, -1):
        if existing_text.endswith(new_chunk[:i]):
            max_overlap = i
            break
    return new_chunk[max_overlap:]
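
# Example: the streamed chunk repeats the tail of the text seen so far.
#
#     trim_overlap("The quick brown", "brown fox")  # -> " fox"
#     trim_overlap("abc", "xyz")                    # -> "xyz" (no overlap)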


def stream_and_merge(llm, prompt, sampling_params):
    """
    1) Streams the text,
    2) Removes chunk overlaps,
    3) Returns the merged text.
    """
    final_text = ""
    for chunk in llm.generate(prompt, sampling_params, stream=True):
        chunk_text = chunk["text"]
        cleaned_chunk = trim_overlap(final_text, chunk_text)
        final_text += cleaned_chunk
    return final_text


async def async_stream_and_merge(llm, prompt, sampling_params):
    """
    Streams tokens asynchronously, removes chunk overlaps,
    and yields the cleaned chunk in real time for printing.
    """
    final_text = ""
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    async for chunk in generator:
        chunk_text = chunk["text"]
        cleaned_chunk = trim_overlap(final_text, chunk_text)
        final_text += cleaned_chunk
        yield cleaned_chunk  # yield the non-overlapping portion