import atexit
import json
import multiprocessing
import warnings
from typing import Dict, List, Optional, Union

import aiohttp
import requests

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import (
    REGEX_BOOL,
    REGEX_FLOAT,
    REGEX_INT,
    REGEX_STR,
    SglSamplingParams,
)
from sglang.utils import http_request

class RuntimeEndpoint(BaseBackend):
    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        res = http_request(
            self.base_url + "/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        self.model_info = res.json()

        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )
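    # A minimal usage sketch (the URL is illustrative and assumes an SRT
    # server is already running there):
    #
    #   import sglang as sgl
    #   backend = RuntimeEndpoint("http://localhost:30000")
    #   sgl.set_default_backend(backend)
    #
    # After this, @sgl.function programs execute against the server behind
    # `base_url`.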
    def get_model_name(self):
        return self.model_info["model_path"]

    def flush_cache(self):
        res = http_request(
            self.base_url + "/flush_cache",
            api_key=self.api_key,
            verify=self.verify,
            method="POST",
        )
        self._assert_success(res)

    def get_server_info(self):
        res = http_request(
            self.base_url + "/get_server_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def commit_lazy_operations(self, s: StreamExecutor):
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def fill_image(self, s: StreamExecutor):
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        if sampling_params.dtype is None:
            return

        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype_regex = None
        if sampling_params.dtype in ["int", int]:
            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["float", float]:
            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["str", str]:
            dtype_regex = REGEX_STR
        elif sampling_params.dtype in ["bool", bool]:
            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. "
                f"dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex
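    # Illustration of the mapping above (the frontend call is a sketch):
    # `sgl.gen("age", dtype=int)` makes REGEX_INT the effective `regex`, so
    # constrained decoding can only emit an integer, and " " / "\n" are added
    # as stop strings so generation halts at the first token boundary.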
    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        self._handle_dtype_to_regex(sampling_params)
        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]
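    # Sketch of the request body built above (values are illustrative):
    #   {
    #       "text": "Once upon a time",
    #       "sampling_params": {"max_new_tokens": 128, "temperature": 0.7, ...},
    #       "return_logprob": True,   # present only if set on sampling_params
    #   }
    # The response carries the completion in "text" and token/latency
    # bookkeeping in "meta_info".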
    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        data["stream"] = True
        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        pos = 0

        for chunk in res.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info
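    # The loop above consumes server-sent events: each event line looks like
    # `data: {"text": ..., "meta_info": ...}` and the stream terminates with
    # `data: [DONE]`. Because "text" is cumulative, `pos` tracks how much has
    # already been yielded, so each iteration emits only the new delta.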
    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
            for r in obj
        ]
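        # Worked sketch of the adjustment below (numbers made up): if the
        # healed token's logprob h was included in an n-token mean m, the mean
        # without it is (m * n - h) / (n - 1); e.g. n=4, m=-1.0, h=-2.5 gives
        # (-4.0 + 2.5) / 3 = -0.5.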
        # Remove extra token if no token healing occurred
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]

        # Compute unconditional logprobs if required
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )
    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _generate_http_request(self, s: StreamExecutor, data):
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def _add_images(self, s: StreamExecutor, data):
        if s.images_:
            assert len(s.images_) == 1, "Only one image is supported."
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        if res.status_code != 200:
            raise RuntimeError(res.json())


def compute_normalized_prompt_logprobs(input_logprobs):
    values = [x[0] for x in input_logprobs if x[0]]
    return sum(values) / len(values)
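# Quick illustration of compute_normalized_prompt_logprobs (values made up):
# for input_logprobs [(None, 1, "A"), (-0.5, 2, "B"), (-1.5, 3, "C")] the
# leading None entry is filtered out by the truthiness check and the result
# is (-0.5 + -1.5) / 2 = -1.0. Note the check also drops exact 0.0 logprobs.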
class Runtime:
    """
    A wrapper for the HTTP server that launches the server inside a Python
    program, without using the command line interface.

    It is mainly used for the frontend language. For normal offline
    processing without the frontend language, use the Engine class instead.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users
        # can run client code without installing the SRT server and its dependencies.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the proc to fix some issues during `__del__`
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        proc = multiprocessing.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        # Call shutdown implicitly before the Python program terminates, so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        from sglang.srt.utils import kill_process_tree

        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        from sglang.srt.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate
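    # Consuming the async stream above (a sketch; assumes `runtime` is a
    # constructed Runtime, and the prompt/params are illustrative):
    #
    #   import asyncio
    #
    #   async def main():
    #       async for chunk in runtime.async_generate(
    #           "Hello", {"max_new_tokens": 16}
    #       ):
    #           print(chunk, end="", flush=True)
    #
    #   asyncio.run(main())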
    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()
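# A minimal end-to-end sketch for Runtime (the model path is illustrative;
# any ServerArgs keyword is forwarded to the server process):
#
#   runtime = Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")
#   print(runtime.generate("The capital of France is"))
#   runtime.shutdown()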