"""The interpreter that executes SGL programs"""
import asyncio
import contextvars
import copy
import multiprocessing
import queue
import threading
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional

import tqdm

from sglang.global_config import global_config
from sglang.lang.ir import (
    SglCommitLazy,
    SglConcateAndAppend,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglGen,
    SglImage,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
    SglVariable,
    SglVarScopeBegin,
    SglVarScopeEnd,
    SglVideo,
)
from sglang.utils import (
    encode_image_base64,
    encode_video_base64,
    get_exception_traceback,
)

def run_internal(state, program, func_args, func_kwargs, sync):
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    finally:
        state.stream_executor.end()

    if sync:
        state.stream_executor.sync()

    if global_config.verbosity >= 2:
        print(state.text())

def run_program(
    program,
    backend,
    func_args,
    func_kwargs,
    default_sampling_para,
    stream,
    sync=False,
    use_thread=True,
):
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint
    assert backend is not None, "Please specify a backend"
    func_kwargs.update(program.bind_arguments)
    stream_executor = StreamExecutor(
        backend,
        func_kwargs,
        default_sampling_para,
        chat_template=None,
        stream=stream,
        num_api_spec_tokens=program.num_api_spec_tokens,
        use_thread=use_thread,
    )
    state = ProgramState(stream_executor)

    if stream:
        t = threading.Thread(
            target=run_internal, args=(state, program, func_args, func_kwargs, sync)
        )
        t.start()
        return state
    else:
        run_internal(state, program, func_args, func_kwargs, sync)
        return state

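# Illustrative sketch of how `run_program` is usually reached. The decorator
# and program below are hypothetical examples, not defined in this file:
#
#     import sglang as sgl
#
#     @sgl.function
#     def qa(s, question):
#         s += sgl.user(question)
#         s += sgl.assistant(sgl.gen("answer", max_tokens=64))
#
#     state = qa.run(question="What is SGLang?")  # dispatches to run_program
#     print(state["answer"])
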
def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
    generator_style=False,
):
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint

    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
        cache_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = max(96, multiprocessing.cpu_count() * 16)
    num_threads = min(num_threads, len(batch_arguments))

    if generator_style:
        return _run_program_batch_generator(
            program,
            backend,
            batch_arguments,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    # Original code path when generator_style=False
    if num_threads == 1:
        rets = []
        if progress_bar:
            for arguments in tqdm.tqdm(batch_arguments):
                rets.append(
                    run_program(
                        program,
                        backend,
                        (),
                        arguments,
                        default_sampling_para,
                        False,
                        True,
                    )
                )
        else:
            for arguments in batch_arguments:
                rets.append(
                    run_program(
                        program,
                        backend,
                        (),
                        arguments,
                        default_sampling_para,
                        False,
                        True,
                    )
                )
    else:
        if progress_bar:
            pbar = tqdm.tqdm(total=len(batch_arguments))

        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                futures.append(
                    executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        arguments,
                        default_sampling_para,
                        False,
                        True,
                    )
                )
                if progress_bar:
                    futures[-1].add_done_callback(lambda _: pbar.update())

            rets = [f.result() for f in futures]
            rets[-1].sync()

        if progress_bar:
            pbar.close()

    return rets

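# Illustrative sketch of a batch call. `run_batch` is the typical public entry
# point that forwards here; the program name is hypothetical:
#
#     states = qa.run_batch(
#         [{"question": "Q1"}, {"question": "Q2"}],
#         num_threads="auto",
#         progress_bar=True,
#     )
#     answers = [s["answer"] for s in states]
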
def _run_program_batch_generator(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Helper that yields results one by one, submitting work in chunks to
    avoid overwhelming the ThreadPoolExecutor."""
    if num_threads == 1:
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        for arguments in iterator:
            yield run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
    else:
        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None

        # Process in chunks to avoid overwhelming the ThreadPoolExecutor.
        # Otherwise, ThreadPoolExecutor.submit can block after a certain number
        # of tasks has been added, so we would never reach "yield" until all
        # tasks are done.
        chunk_size = 200

        with ThreadPoolExecutor(num_threads) as executor:
            for chunk_start in range(0, len(batch_arguments), chunk_size):
                chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
                chunk_futures = []

                # Submit a chunk of tasks
                for i in range(chunk_start, chunk_end):
                    future = executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        batch_arguments[i],
                        default_sampling_para,
                        False,
                        True,
                    )
                    if pbar:
                        future.add_done_callback(lambda _: pbar.update())
                    chunk_futures.append(future)

                # Yield results from this chunk as they complete
                for future in chunk_futures:
                    yield future.result()

        if pbar:
            pbar.close()

def cache_program(program, backend):
    from sglang.lang.tracer import extract_prefix_by_tracing

    prefix = extract_prefix_by_tracing(program, backend)
    if prefix and len(prefix) > 64:
        backend.cache_prefix(prefix)

class StreamExecutor:
    """A stream executor that executes SGL expressions in a background thread."""
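    # Design note (summary of the mechanics below): `submit()` enqueues SGL
    # expressions; a worker thread pops and executes them in order, appending
    # to `text_`/`messages_`. Per-variable `threading.Event`s let readers block
    # in `get_var()` until the corresponding generation finishes, so program
    # code can keep issuing expressions without waiting on the backend.
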
    def __init__(
        self,
        backend,
        arguments,
        default_sampling_para,
        chat_template,
        stream,
        num_api_spec_tokens=None,
        use_thread=True,
    ):
        from sglang.lang.backend.base_backend import BaseBackend

        self.sid = uuid.uuid4().hex
        self.backend: BaseBackend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream

        self.variables = {}  # Dict[name: str -> value: str]
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> info: str]
        self.is_finished = False
        self.error_ = None

        # For completion
        self.text_ = ""  # The full text

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
        self.chat_template = chat_template or self.backend.get_chat_template()
        self.cur_role = None
        self.cur_role_begin_pos = None

        # For vision
        self.images_ = []
        self.cur_images = []

        # For fork/join
        self.fork_start_text_pos = None

        # For speculative execution
        self.num_api_spec_tokens = num_api_spec_tokens
        self.speculated_text = ""

        # Worker thread
        self.use_thread = use_thread
        if self.use_thread:
            self.queue = queue.Queue()

            def _run_worker_in_context():
                self._thread_worker_func()

            self.worker = threading.Thread(
                target=contextvars.copy_context().run, args=(_run_worker_in_context,)
            )
            self.worker.start()

        # For streaming
        if stream:
            self.stream_text_event = threading.Event()
            self.stream_var_event = {}
        else:
            self.stream_text_event = None
            self.stream_var_event = None

    def submit(self, expr: SglExpr):
        self._init_var_event(expr)

        if self.use_thread:
            self.queue.put(expr)
        else:
            self._execute(expr)

    def sync(self):
        if self.use_thread:
            self.queue.join()

    def get_var(self, name):
        if name in self.variable_event:
            self.variable_event[name].wait()
        return self.variables[name]

    def set_var(self, name, value):
        self.variables[name] = value

    def get_meta_info(self, name, timeout=None):
        if name in self.variable_event:
            got = self.variable_event[name].wait(timeout)
            if not got:
                raise TimeoutError(f"Timeout while waiting for event '{name}'")
        ret = self.meta_info.get(name, None)
        return ret

    def fork(
        self,
        size: int = 1,
        position_ids_offset: Optional[List[int]] = None,
    ):
        if size > 1 and str(self.text_):
            self.submit(SglCommitLazy())

        self.sync()
        size = int(size)

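        # Each child executor shares the backend and chat template but gets
        # copies of the accumulated text, variables, and messages, so forked
        # branches can diverge without mutating the parent's state.
        # `fork_start_text_pos` records where the branch started; the
        # concatenate-and-append join uses it to extract only the new text.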
        exes = [
            StreamExecutor(
                self.backend,
                self.arguments,
                self.default_sampling_para,
                self.chat_template,
                self.stream,
            )
            for _ in range(size)
        ]
        for i in range(size):
            exes[i].variables = dict(self.variables)
            exes[i].text_ = str(self.text_)
            exes[i].messages_ = list(self.messages_)
            exes[i].cur_role = self.cur_role
            exes[i].cur_role_begin_pos = self.cur_role_begin_pos
            exes[i].fork_start_text_pos = len(self.text_)
            exes[i].images_ = list(self.images_)

        # TODO(ying): handle API speculative execution

        return exes

    def text(self):
        self.sync()
        return self.text_

    def messages(self):
        self.sync()
        return self.messages_

    def error(self):
        self.sync()
        return self.error_

    def end(self):
        if self.use_thread:
            if self.worker.is_alive():
                self.queue.put(None)
        self.backend.end_program(self)

    def _thread_worker_func(self):
        error = None

        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break

            try:
                self._execute(expr)
            except Exception as e:
                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                error = e
                break
            self.queue.task_done()
            if self.stream_text_event:
                self.stream_text_event.set()

        # Clean the queue and events
        if error is not None:
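            # Drain the queue so that a pending `queue.join()` in `sync()` can
            # return. The ordering is deliberate: the expr that raised was
            # get() but never task_done(), so the first task_done() accounts
            # for it, and the loop exits when get_nowait() raises queue.Empty
            # with the get/task_done counts balanced.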
            try:
                while True:
                    self.queue.task_done()
                    self.queue.get_nowait()
            except queue.Empty:
                pass
            for name in self.variable_event:
                self.variable_event[name].set()
            if self.stream_var_event:
                for name in self.stream_var_event:
                    self.stream_var_event[name].set()
            self.error_ = error

        if self.stream_text_event:
            self.stream_text_event.set()

        self.is_finished = True

    def _execute(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)

        assert isinstance(other, SglExpr), f"{other}"

        if isinstance(other, SglConstantText):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglImage):
            self._execute_image(other)
        elif isinstance(other, SglVideo):
            self._execute_video(other)
        elif isinstance(other, SglVariable):
            self._execute_variable(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        elif isinstance(other, SglCommitLazy):
            self._execute_commit_lazy_operations(other)
        elif isinstance(other, SglConcateAndAppend):
            if (
                global_config.enable_parallel_encoding
                and self.backend.support_concate_and_append
            ):
                self._execute_concatenate_and_append_kv_cache(other)
            else:
                self._execute_concatenate_and_append_text(other)
        else:
            raise ValueError(f"Unknown type: {type(other)}")

    def _execute_fill(self, value: str, prefix=False):
        value = str(value)

        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
            and not prefix
        ):
            self.backend.spec_fill(value)
            return

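        # If the fill matches the beginning of the previously speculated text,
        # consume that prefix; otherwise the speculation diverged from the real
        # program, so discard it.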
        if self.speculated_text.startswith(value):
            self.speculated_text = self.speculated_text[len(value) :]
        else:
            self.speculated_text = ""

        self.text_ += value

    def _execute_image(self, expr: SglImage):
        path = expr.path

        base64_data = encode_image_base64(path)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

    def _execute_video(self, expr: SglVideo):
        path = expr.path
        num_frames = expr.num_frames

        base64_data = encode_video_base64(path, num_frames)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

    def _spec_gen(self, sampling_params):
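        # Speculative generation against a completion API: generate once with
        # at least `num_api_spec_tokens` tokens and no stop strings, then serve
        # this and subsequent calls out of the cached `speculated_text` buffer,
        # regenerating only when the buffer runs dry or misses a stop string.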
        stop = sampling_params.stop
        max_new_tokens = sampling_params.max_new_tokens
        meta_info = {}

        def regen():
            nonlocal meta_info

            sampling_params.max_new_tokens = max(
                sampling_params.max_new_tokens, self.num_api_spec_tokens
            )
            sampling_params.stop = None
            self.speculated_text, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )

        def find_stop():
            if isinstance(stop, str):
                return self.speculated_text.find(stop)
            elif isinstance(stop, (tuple, list)):
                pos = -1
                for stop_str in stop:
                    stop_pos = self.speculated_text.find(stop_str)
                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                        pos = stop_pos
                return pos
            else:
                raise ValueError("Wrong type of stop in sampling parameters.")

        if stop is None:
            if len(self.speculated_text) < max_new_tokens:
                regen()
            comp = self.speculated_text[:max_new_tokens]
            self.speculated_text = self.speculated_text[max_new_tokens:]
        elif isinstance(stop, (str, list, tuple)):
            if self.speculated_text == "":
                regen()
            stop_pos = find_stop()
            if stop_pos == -1:
                stop_pos = min(
                    sampling_params.max_new_tokens,
                    len(self.speculated_text),
                )
            comp = self.speculated_text[:stop_pos]
            self.speculated_text = self.speculated_text[stop_pos:]
        else:
            raise ValueError("Wrong type of stop in sampling parameters.")

        return comp, meta_info

    def _execute_gen(self, expr: SglGen):
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name
        if not self.stream:
            if self.num_api_spec_tokens is None:
                comp, meta_info = self.backend.generate(
                    self,
                    sampling_params=sampling_params,
                )
            else:
                if self.backend.is_chat_model:
                    # Speculative execution on models with only a chat interface.
                    # Store the calls into a temporary list.
                    # They will be lazily executed later.
                    comp, meta_info = self.backend.generate(
                        self,
                        sampling_params=sampling_params,
                        spec_var_name=name,
                    )
                    return
                else:  # Speculative execution on models with a completion interface
                    comp, meta_info = self._spec_gen(sampling_params)

            if isinstance(comp, list):
                self.text_ += comp[0]
            else:
                assert isinstance(comp, str)
                self.text_ += comp

            self.variables[name] = comp
            self.meta_info[name] = meta_info
            self.variable_event[name].set()
        else:
            assert (
                self.num_api_spec_tokens is None
            ), "stream is not supported with api speculative execution"
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
            )

            self.variables[name] = ""
            self.stream_var_event[name].set()

            for comp, meta_info in generator:
                self.text_ += comp
                self.variables[name] += comp
                self.meta_info[name] = meta_info
                self.stream_var_event[name].set()
                self.stream_text_event.set()

            self.variable_event[name].set()
            self.stream_var_event[name].set()

    def _execute_select(self, expr: SglSelect):
        choices_decision = self.backend.select(
            self, expr.choices, expr.temperature, expr.choices_method
        )
        if expr.name is not None:
            name = expr.name
            self.variables[name] = choices_decision.decision
            self.meta_info[name] = choices_decision.meta_info
            self.variable_event[name].set()
            if self.stream_var_event:
                self.stream_var_event[name].set()
        self.text_ += choices_decision.decision

    def _execute_variable(self, expr: SglVariable):
        src_executor = expr.source_stream_executor
        value = src_executor.get_var(expr.name)
        self._execute_fill(value)

    def _execute_role_begin(self, expr: SglRoleBegin):
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert the default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix, prefix=True)
        self.cur_role_begin_pos = len(self.text_)

    def _execute_role_end(self, expr: SglRoleEnd):
        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
        ):
            # Execute the stored lazy generation calls
            self.backend.role_end_generate(self)
        self.cur_role = None

        new_text = self.text_[self.cur_role_begin_pos :].lstrip()

        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        self._execute_fill(suffix)

        if self.cur_images:
            # OpenAI vision API format
            last_msg = {
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
            for image_path, image_base64_data in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64_data}"
                        },
                    }
                )
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            # OpenAI chat API format
            self.messages_.append({"role": expr.role, "content": new_text})

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        self.variables[expr.name] = int(len(self.text_))

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        self.variables[expr.name] = self.text_[self.variables[expr.name] :]
        self.variable_event[expr.name].set()

    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
        self.backend.commit_lazy_operations(self)

    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
        new_text = ""
        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            new_text += exe.text_[exe.fork_start_text_pos :]

        self._execute_fill(new_text)

    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
        self_len = len(self.text_)

        for s in expr.states:
            exe = s.stream_executor
            exe.submit(SglCommitLazy())

        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            assert exe.fork_start_text_pos == self_len
            self.text_ += exe.text_[exe.fork_start_text_pos :]

        src_rids = [state.stream_executor.sid for state in expr.states]
        self.backend.concatenate_and_append(src_rids, self.sid)

    def _init_var_event(self, expr):
        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
            self.variable_event[expr.name] = threading.Event()
            if self.stream:
                self.stream_var_event[expr.name] = threading.Event()
        elif isinstance(expr, SglExprList):
            for e in expr.expr_list:
                self._init_var_event(e)

    def _resolve_sampling_params(self, sampling_params):
        """
        Construct the sampling parameters from the default and override values.

        The default values are populated in `default_sampling_para` via
        `sgl.function.run(...)`, and `sampling_params` carries the override
        values from `sgl.gen()`.

        `default_sampling_para` is used as the base, and any value present in
        `sampling_params` overrides it. The stop strings are also extended
        based on the chat template.
        """
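        # For example (illustrative): run(..., temperature=0.7) sets the
        # default, while a later gen(..., temperature=0.0) overrides it for
        # that one call; both inherit the chat template's stop strings.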
        # deepcopy is required because the dict has lists inside
        clone = copy.deepcopy(self.default_sampling_para)

        for item in [
            "max_new_tokens",
            "min_new_tokens",
            "n",
            "stop",
            "stop_token_ids",
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "frequency_penalty",
            "presence_penalty",
            "ignore_eos",
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
            "dtype",
            "regex",
            "json_schema",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                setattr(clone, item, value)

        if self.chat_template.stop_str:
            if clone.stop == ():
                clone.stop = []
            elif isinstance(clone.stop, str):
                clone.stop = [clone.stop]
            clone.stop += self.chat_template.stop_str

        return clone

    def __del__(self):
        self.end()

class ProgramState:
    """The state of an SGL program."""

    def __init__(self, stream_executor: StreamExecutor):
        self.stream_executor = stream_executor

    def _role_common(self, name: str, expr: Optional[SglExpr] = None):
        if expr is not None:
            role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
            self.stream_executor.submit(role_expr)
            return role_expr
        else:

            @contextmanager
            def role_scope():
                self.stream_executor.submit(SglRoleBegin(name))
                yield
                self.stream_executor.submit(SglRoleEnd(name))

            return role_scope()

    def system(self, expr: Optional[SglExpr] = None):
        return self._role_common("system", expr)

    def user(self, expr: Optional[SglExpr] = None):
        return self._role_common("user", expr)

    def assistant(self, expr: Optional[SglExpr] = None):
        return self._role_common("assistant", expr)

    @contextmanager
    def var_scope(self, name: str):
        self.stream_executor.submit(SglVarScopeBegin(name))
        yield
        self.stream_executor.submit(SglVarScopeEnd(name))

    def fork(
        self,
        size: int = 1,
        position_ids_offset: Optional[List[int]] = None,
    ):
        stream_executors = self.stream_executor.fork(size, position_ids_offset)
        states = [ProgramState(x) for x in stream_executors]
        state_group = ProgramStateGroup(states, self)
        return state_group

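    # Illustrative fork/join sketch (hedged: written inside a hypothetical
    # @sgl.function body, where `s` is this ProgramState):
    #
    #     forks = s.fork(2)
    #     forks += lambda i: sgl.gen(f"draft_{i}", max_tokens=32)
    #     forks.join()  # gathers the forks' variables back into `s`
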
    @contextmanager
    def copy(self, position_ids_offset: Optional[List[int]] = None):
        state_group = self.fork(1, position_ids_offset)
        try:
            yield state_group[0]
        finally:
            state_group.join()

    def text(self):
        return self.stream_executor.text()

    def messages(self):
        return self.stream_executor.messages()

    def sync(self):
        return self.stream_executor.sync()

    def error(self):
        return self.stream_executor.error()

    def text_iter(self, var_name: Optional[str] = None):
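        # In streaming mode, block on the executor's events and yield only the
        # text appended since the previous wakeup; in non-streaming mode, fall
        # through to a single yield of the final text (or variable value).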
        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = None
                while not event:
                    if var_name in self.stream_executor.stream_var_event:
                        event = self.stream_executor.stream_var_event[var_name]
                    if self.stream_executor.is_finished:
                        yield ""
                        return

                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    async def text_async_iter(
        self, var_name: Optional[str] = None, return_meta_data: bool = False
    ):
        loop = asyncio.get_running_loop()

        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        break
            else:
                event = None
                while not event:
                    if var_name in self.stream_executor.stream_var_event:
                        event = self.stream_executor.stream_var_event[var_name]
                    if self.stream_executor.is_finished:
                        yield ""
                        return

                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        if return_meta_data:
                            yield out, self.stream_executor.meta_info[var_name]
                        else:
                            yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    def get_var(self, name):
        return self.stream_executor.get_var(name)

    def set_var(self, name, value):
        return self.stream_executor.set_var(name, value)

    def get_meta_info(self, name):
        return self.stream_executor.get_meta_info(name)

    def __iadd__(self, other):
        if other is None:
            raise ValueError("Tried to append None to state.")
        self.stream_executor.submit(other)
        return self

    def __getitem__(self, name):
        return self.get_var(name)

    def __setitem__(self, name, value):
        self.set_var(name, value)

    def __contains__(self, name):
        return name in self.stream_executor.variables

    def __del__(self):
        self.stream_executor.end()

    def __repr__(self) -> str:
        return f"ProgramState({self.text()})"

class ProgramStateGroup:
    def __init__(
        self, states: List[ProgramState], src_state: Optional[ProgramState] = None
    ):
        self.states = states
        self.src_state = src_state

    def join(self, mode: str = "gather_variable"):
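        # "gather_variable": variables that first appeared in the forked
        # children are copied back into the source state as lists, e.g. after
        # fork(2) where each child sets "answer", src["answer"] becomes
        # [answer_of_child_0, answer_of_child_1].
        # "concate_and_append": the children's new text (and, on supporting
        # backends, their KV caches) is concatenated onto the source state.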
        if mode == "gather_variable":
            # Copy variables back
            src_vars = self.src_state.stream_executor.variables
            src_var_set = set(src_vars.keys())
            for child_state in self.states:
                child_state.stream_executor.sync()
                child_vars = child_state.stream_executor.variables
                new_vars = set(child_vars.keys()) - src_var_set

                for k in new_vars:
                    if k in src_vars:
                        src_vars[k].append(child_vars[k])
                    else:
                        src_vars[k] = [child_vars[k]]
        elif mode == "concate_and_append":
            # Concatenate and append KV cache
            self.src_state += SglConcateAndAppend(self.states)
            # Need a sync here. Otherwise, `states` can be deleted.
            self.src_state.stream_executor.sync()
        else:
            raise ValueError(f"Invalid join mode: {mode}")

        for s in self.states:
            s.stream_executor.end()

    def __getitem__(self, i: int):
        return self.states[i]

    def __setitem__(self, i: int, value):
        assert self.states[i] == value

    def __iadd__(self, other):
        if isinstance(other, Callable):
            # lambda function
            for i in range(len(self.states)):
                self.states[i] += other(i)
        elif isinstance(other, SglExpr):
            for i in range(len(self.states)):
                self.states[i] += other
        elif isinstance(other, (list, tuple)):
            for i in range(len(self.states)):
                self.states[i] += other[i]
        else:
            raise ValueError(f"Invalid value: {other}")

        return self