# sglang0.4.5.post1/python/sglang/lang/interpreter.py

"""The interpreter that executes SGL programs"""
import asyncio
import contextvars
import copy
import multiprocessing
import queue
import threading
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional
import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import (
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
def run_internal(state, program, func_args, func_kwargs, sync):
    """Run the program function on `state`, then finalize its stream executor."""
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    finally:
        # end() runs on both success and failure; exceptions still propagate.
        state.stream_executor.end()
    if sync:
        state.stream_executor.sync()
    if global_config.verbosity >= 2:
        print(state.text())
def run_program(
program,
backend,
func_args,
func_kwargs,
default_sampling_para,
stream,
sync=False,
use_thread=True,
):
    """Run a single SGL program and return its ProgramState.

    With stream=True the program body runs in a background thread and the state
    is returned immediately. Expressions are still executed asynchronously by
    the StreamExecutor, so pass sync=True to block until everything finishes.
    """
if hasattr(backend, "endpoint"):
backend = backend.endpoint
assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor(
backend,
func_kwargs,
default_sampling_para,
chat_template=None,
stream=stream,
num_api_spec_tokens=program.num_api_spec_tokens,
use_thread=use_thread,
)
state = ProgramState(stream_executor)
if stream:
t = threading.Thread(
target=run_internal, args=(state, program, func_args, func_kwargs, sync)
)
t.start()
return state
else:
run_internal(state, program, func_args, func_kwargs, sync)
return state
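
# Typical entry point for run_program() (a minimal sketch; `qa`, its arguments,
# and the endpoint URL are illustrative, not part of this module):
#
#     import sglang as sgl
#
#     @sgl.function
#     def qa(s, question):
#         s += sgl.user(question)
#         s += sgl.assistant(sgl.gen("answer", max_tokens=64))
#
#     sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
#     state = qa.run(question="What is SGLang?")  # dispatches to run_program()
#     print(state["answer"])
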
def run_program_batch(
program,
backend,
batch_arguments,
default_sampling_para,
num_threads,
progress_bar,
generator_style=False,
):
    """Run a program over a batch of argument dicts, optionally with a thread pool."""
if hasattr(backend, "endpoint"):
backend = backend.endpoint
# Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
cache_program(program, backend)
# Run all programs
if num_threads == "auto":
num_threads = max(96, multiprocessing.cpu_count() * 16)
num_threads = min(num_threads, len(batch_arguments))
if generator_style:
return _run_program_batch_generator(
program,
backend,
batch_arguments,
default_sampling_para,
num_threads,
progress_bar,
)
# Original code path when generator_style=False
if num_threads == 1:
rets = []
if progress_bar:
for arguments in tqdm.tqdm(batch_arguments):
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
else:
for arguments in batch_arguments:
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
else:
if progress_bar:
pbar = tqdm.tqdm(total=len(batch_arguments))
with ThreadPoolExecutor(num_threads) as executor:
futures = []
for arguments in batch_arguments:
futures.append(
executor.submit(
run_program,
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
if progress_bar:
futures[-1].add_done_callback(lambda _: pbar.update())
rets = [f.result() for f in futures]
rets[-1].sync()
if progress_bar:
pbar.close()
return rets
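
# Batch usage sketch (hedged; the argument values are illustrative):
#
#     states = qa.run_batch(
#         [{"question": "Q1"}, {"question": "Q2"}],
#         num_threads="auto",  # resolves to min(len(batch), max(96, 16 * cpu_count))
#         progress_bar=True,
#     )
#     answers = [s["answer"] for s in states]
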
def _run_program_batch_generator(
program,
backend,
batch_arguments,
default_sampling_para,
num_threads,
progress_bar,
):
"""Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor."""
if num_threads == 1:
iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
for arguments in iterator:
yield run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
else:
pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
        # Process in chunks to avoid overwhelming the ThreadPoolExecutor.
        # Otherwise, ThreadPoolExecutor.submit would block after adding a
        # certain number of tasks, so we would never reach "yield" until all
        # tasks are done.
chunk_size = 200
with ThreadPoolExecutor(num_threads) as executor:
for chunk_start in range(0, len(batch_arguments), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
chunk_futures = []
# Submit chunk of tasks
for i in range(chunk_start, chunk_end):
future = executor.submit(
run_program,
program,
backend,
(),
batch_arguments[i],
default_sampling_para,
False,
True,
)
if pbar:
future.add_done_callback(lambda _: pbar.update())
chunk_futures.append(future)
# Yield results from this chunk as they complete
for future in chunk_futures:
yield future.result()
if pbar:
pbar.close()
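
# Generator-style consumption sketch (assuming the caller passes
# `generator_style=True` through `run_batch`; `process` is illustrative):
#
#     for state in qa.run_batch(batch_args, generator_style=True):
#         process(state["answer"])  # states are yielded in submission order
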
def cache_program(program, backend):
    """Extract the program's common prefix by tracing it, then cache that prefix on the backend."""
from sglang.lang.tracer import extract_prefix_by_tracing
prefix = extract_prefix_by_tracing(program, backend)
if prefix and len(prefix) > 64:
backend.cache_prefix(prefix)
class StreamExecutor:
"""A stream executor that executes SGL expressions in a background thread."""
def __init__(
self,
backend,
arguments,
default_sampling_para,
chat_template,
stream,
num_api_spec_tokens=None,
use_thread=True,
):
from sglang.lang.backend.base_backend import BaseBackend
self.sid = uuid.uuid4().hex
self.backend: BaseBackend = backend
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
self.meta_info = {} # Dict[name: str -> info: str]
self.is_finished = False
self.error_ = None
# For completion
self.text_ = "" # The full text
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
self.cur_role = None
self.cur_role_begin_pos = None
# For vision
self.images_ = []
self.cur_images = []
# For fork/join
self.fork_start_text_pos = None
# For speculative execution
self.num_api_spec_tokens = num_api_spec_tokens
self.speculated_text = ""
# Worker thread
self.use_thread = use_thread
if self.use_thread:
self.queue = queue.Queue()
def _run_worker_in_context():
self._thread_worker_func()
self.worker = threading.Thread(
target=contextvars.copy_context().run, args=(_run_worker_in_context,)
)
self.worker.start()
# For streaming
if stream:
self.stream_text_event = threading.Event()
self.stream_var_event = {}
else:
self.stream_text_event = None
self.stream_var_event = None
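
    # Minimal usage sketch (hedged; in normal operation a StreamExecutor is
    # created by run_program() rather than constructed directly):
    #
    #     exe = StreamExecutor(backend, {}, default_params, None, stream=False)
    #     exe.submit("Hello, ")             # a plain str becomes SglConstantText
    #     exe.submit(SglGen("reply", ...))  # queued for the worker thread
    #     exe.sync()                        # block until the queue is drained
    #     print(exe.text_, exe.get_var("reply"))
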
def submit(self, expr: SglExpr):
self._init_var_event(expr)
if self.use_thread:
self.queue.put(expr)
else:
self._execute(expr)
def sync(self):
if self.use_thread:
self.queue.join()
def get_var(self, name):
if name in self.variable_event:
self.variable_event[name].wait()
return self.variables[name]
def set_var(self, name, value):
self.variables[name] = value
def get_meta_info(self, name, timeout=None):
if name in self.variable_event:
got = self.variable_event[name].wait(timeout)
if not got:
raise TimeoutError(f"Timeout while waiting for event '{name}'")
ret = self.meta_info.get(name, None)
return ret
def fork(
self,
size: int = 1,
position_ids_offset: Optional[List[int]] = None,
    ):
        """Fork into `size` child executors, each starting from a copy of the current state."""
if size > 1 and str(self.text_):
self.submit(SglCommitLazy())
self.sync()
size = int(size)
exes = [
StreamExecutor(
self.backend,
self.arguments,
self.default_sampling_para,
self.chat_template,
self.stream,
)
for _ in range(size)
]
for i in range(size):
exes[i].variables = dict(self.variables)
exes[i].text_ = str(self.text_)
exes[i].messages_ = list(self.messages_)
exes[i].cur_role = self.cur_role
exes[i].cur_role_begin_pos = self.cur_role_begin_pos
exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_)
# TODO(ying): handle API speculative execution
return exes
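
    # Each child receives shallow copies of the variables, text, messages, and
    # images, so appends to one fork do not leak into its siblings.
    # `fork_start_text_pos` marks where the fork began, so a later join can
    # splice back only the text generated after the fork point.
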
def text(self):
self.sync()
return self.text_
def messages(self):
self.sync()
return self.messages_
def error(self):
self.sync()
return self.error_
def end(self):
if self.use_thread:
if self.worker.is_alive():
self.queue.put(None)
self.backend.end_program(self)
def _thread_worker_func(self):
error = None
while True:
expr = self.queue.get()
if expr is None:
self.queue.task_done()
break
try:
self._execute(expr)
except Exception as e:
warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
error = e
break
self.queue.task_done()
if self.stream_text_event:
self.stream_text_event.set()
        # Clean the queue and unblock waiters.
        if error is not None:
            # Drain remaining submissions; the first task_done() balances the
            # get() of the expression that raised.
            try:
                while True:
                    self.queue.task_done()
                    self.queue.get_nowait()
            except queue.Empty:
                pass
for name in self.variable_event:
self.variable_event[name].set()
if self.stream_var_event:
for name in self.stream_var_event:
self.stream_var_event[name].set()
self.error_ = error
if self.stream_text_event:
self.stream_text_event.set()
self.is_finished = True
def _execute(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr), f"{other}"
if isinstance(other, SglConstantText):
self._execute_fill(other.value)
elif isinstance(other, SglGen):
self._execute_gen(other)
elif isinstance(other, SglSelect):
self._execute_select(other)
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x)
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other)
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other)
elif isinstance(other, SglImage):
self._execute_image(other)
elif isinstance(other, SglVideo):
self._execute_video(other)
elif isinstance(other, SglVariable):
self._execute_variable(other)
elif isinstance(other, SglVarScopeBegin):
self._execute_var_scope_begin(other)
elif isinstance(other, SglVarScopeEnd):
self._execute_var_scope_end(other)
elif isinstance(other, SglCommitLazy):
self._execute_commit_lazy_operations(other)
elif isinstance(other, SglConcateAndAppend):
if (
global_config.enable_parallel_encoding
and self.backend.support_concate_and_append
):
self._execute_concatenate_and_append_kv_cache(other)
else:
self._execute_concatenate_and_append_text(other)
else:
raise ValueError(f"Unknown type: {type(other)}")
def _execute_fill(self, value: str, prefix=False):
value = str(value)
if (
self.cur_role == "assistant"
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
and not prefix
):
self.backend.spec_fill(value)
return
if self.speculated_text.startswith(value):
self.speculated_text = self.speculated_text[len(value) :]
else:
self.speculated_text = ""
self.text_ += value
def _execute_image(self, expr: SglImage):
path = expr.path
base64_data = encode_image_base64(path)
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
def _execute_video(self, expr: SglVideo):
path = expr.path
num_frames = expr.num_frames
base64_data = encode_video_base64(path, num_frames)
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
def _spec_gen(self, sampling_params):
stop = sampling_params.stop
max_new_tokens = sampling_params.max_new_tokens
meta_info = {}
def regen():
nonlocal meta_info
sampling_params.max_new_tokens = max(
sampling_params.max_new_tokens, self.num_api_spec_tokens
)
sampling_params.stop = None
self.speculated_text, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
def find_stop():
if isinstance(stop, str):
return self.speculated_text.find(stop)
elif isinstance(stop, (tuple, list)):
pos = -1
for stop_str in stop:
stop_pos = self.speculated_text.find(stop_str)
if stop_pos != -1 and (pos == -1 or stop_pos < pos):
pos = stop_pos
return pos
            else:
                raise ValueError("Wrong type of stop in sampling parameters.")
if stop is None:
if len(self.speculated_text) < max_new_tokens:
regen()
comp = self.speculated_text[:max_new_tokens]
self.speculated_text = self.speculated_text[max_new_tokens:]
elif isinstance(stop, (str, list, tuple)):
if self.speculated_text == "":
regen()
stop_pos = find_stop()
if stop_pos == -1:
stop_pos = min(
sampling_params.max_new_tokens,
len(self.speculated_text),
)
comp = self.speculated_text[:stop_pos]
self.speculated_text = self.speculated_text[stop_pos:]
else:
raise ValueError("Wrong type of stop in sampling parameters.")
return comp, meta_info
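
    # Illustration of API speculative execution (hedged): with
    # `num_api_spec_tokens=64`, a program such as
    #
    #     s += "Name: " + sgl.gen("name", stop="\n")
    #     s += "Job: " + sgl.gen("job", stop="\n")
    #
    # issues one long API call at the first gen() with the stop strings removed.
    # Matched fill text is consumed from `speculated_text` by _execute_fill(),
    # and the second gen() is served from the leftover buffer; regen() is only
    # called again if the buffer runs out or a fill fails to match.
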
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
if self.num_api_spec_tokens is None:
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
)
else:
if self.backend.is_chat_model:
# Speculative execution on models with only chat interface.
# Store the calls into a temporary list.
# They will be lazily executed later.
comp, meta_info = self.backend.generate(
self,
sampling_params=sampling_params,
spec_var_name=name,
)
return
else: # Speculative execution on models with completion interface
comp, meta_info = self._spec_gen(sampling_params)
if isinstance(comp, list):
self.text_ += comp[0]
else:
assert isinstance(comp, str)
self.text_ += comp
self.variables[name] = comp
self.meta_info[name] = meta_info
self.variable_event[name].set()
else:
assert (
self.num_api_spec_tokens is None
), "stream is not supported with api speculative execution"
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
)
self.variables[name] = ""
self.stream_var_event[name].set()
for comp, meta_info in generator:
self.text_ += comp
self.variables[name] += comp
self.meta_info[name] = meta_info
self.stream_var_event[name].set()
self.stream_text_event.set()
self.variable_event[name].set()
self.stream_var_event[name].set()
def _execute_select(self, expr: SglSelect):
choices_decision = self.backend.select(
self, expr.choices, expr.temperature, expr.choices_method
)
if expr.name is not None:
name = expr.name
self.variables[name] = choices_decision.decision
self.meta_info[name] = choices_decision.meta_info
self.variable_event[name].set()
if self.stream_var_event:
self.stream_var_event[name].set()
self.text_ += choices_decision.decision
def _execute_variable(self, expr: SglVariable):
src_executor = expr.source_stream_executor
value = src_executor.get_var(expr.name)
self._execute_fill(value)
def _execute_role_begin(self, expr: SglRoleBegin):
assert self.cur_role is None, "Nested roles are not allowed."
if len(self.messages_) == 0 and expr.role != "system":
# Insert the default system message
default_system = self.chat_template.default_system_prompt
if default_system:
self._execute_role_begin(SglRoleBegin("system"))
self._execute_fill(default_system)
self._execute_role_end(SglRoleEnd("system"))
self.cur_role = expr.role
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(prefix, prefix=True)
self.cur_role_begin_pos = len(self.text_)
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
):
# Execute the stored lazy generation calls
self.backend.role_end_generate(self)
self.cur_role = None
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(suffix)
if self.cur_images:
# OpenAI vision API format
last_msg = {
"role": expr.role,
"content": [{"type": "text", "text": new_text}],
}
for image_path, image_base64_data in self.cur_images:
last_msg["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64_data}"
},
}
)
self.messages_.append(last_msg)
self.cur_images = []
else:
# OpenAI chat API format
self.messages_.append({"role": expr.role, "content": new_text})
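
    # The appended entry follows the OpenAI chat format, e.g. (illustrative):
    #
    #     {"role": "user", "content": "What is in this image?"}
    #
    # or, when images were attached during this turn:
    #
    #     {"role": "user", "content": [
    #         {"type": "text", "text": "What is in this image?"},
    #         {"type": "image_url",
    #          "image_url": {"url": "data:image/jpeg;base64,<data>"}},
    #     ]}
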
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
self.variables[expr.name] = int(len(self.text_))
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
self.variables[expr.name] = self.text_[self.variables[expr.name] :]
self.variable_event[expr.name].set()
def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
self.backend.commit_lazy_operations(self)
def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
new_text = ""
for s in expr.states:
exe = s.stream_executor
exe.sync()
new_text += exe.text_[exe.fork_start_text_pos :]
self._execute_fill(new_text)
def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
self_len = len(self.text_)
for i, s in enumerate(expr.states):
exe = s.stream_executor
exe.submit(SglCommitLazy())
for i, s in enumerate(expr.states):
exe = s.stream_executor
exe.sync()
assert exe.fork_start_text_pos == self_len
self.text_ += exe.text_[exe.fork_start_text_pos :]
src_rids = [state.stream_executor.sid for state in expr.states]
self.backend.concatenate_and_append(src_rids, self.sid)
def _init_var_event(self, expr):
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[expr.name] = threading.Event()
if self.stream:
self.stream_var_event[expr.name] = threading.Event()
elif isinstance(expr, SglExprList):
for e in expr.expr_list:
self._init_var_event(e)
    def _resolve_sampling_params(self, sampling_params):
        """Construct the final sampling parameters from defaults plus overrides.

        The defaults are populated in `default_sampling_para` via
        `sgl.function.run(...sampling_args)`, while `sampling_params` carries
        the per-call overrides from `sgl.gen()`. We take `default_sampling_para`
        as the base and override any value explicitly set in `sampling_params`,
        then extend the stop strings with those from the chat template.
        """
# deepcopy is required because the dict has lists inside
clone = copy.deepcopy(self.default_sampling_para)
for item in [
"max_new_tokens",
"min_new_tokens",
"n",
"stop",
"stop_token_ids",
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
"dtype",
"regex",
"json_schema",
]:
value = getattr(sampling_params, item, None)
if value is not None:
setattr(clone, item, value)
if self.chat_template.stop_str:
if clone.stop == ():
clone.stop = []
elif isinstance(clone.stop, str):
clone.stop = [clone.stop]
clone.stop += self.chat_template.stop_str
return clone
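
    # Merge example (a sketch; values are illustrative):
    #
    #     defaults from func.run(temperature=0.7, max_new_tokens=128)
    #     overrides from sgl.gen("x", temperature=0.0, stop="\n")
    #     resolved:  temperature=0.0, max_new_tokens=128,
    #                stop=["\n"] + chat-template stop strings
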
def __del__(self):
self.end()
class ProgramState:
"""The state of an SGL program."""
def __init__(self, stream_executor: StreamExecutor):
self.stream_executor = stream_executor
def _role_common(self, name: str, expr: Optional[SglExpr] = None):
if expr is not None:
role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
self.stream_executor.submit(role_expr)
return role_expr
else:
@contextmanager
def role_scope():
self.stream_executor.submit(SglRoleBegin(name))
yield
self.stream_executor.submit(SglRoleEnd(name))
return role_scope()
def system(self, expr: Optional[SglExpr] = None):
return self._role_common("system", expr)
def user(self, expr: Optional[SglExpr] = None):
return self._role_common("user", expr)
def assistant(self, expr: Optional[SglExpr] = None):
return self._role_common("assistant", expr)
@contextmanager
def var_scope(self, name: str):
self.stream_executor.submit(SglVarScopeBegin(name))
yield
self.stream_executor.submit(SglVarScopeEnd(name))
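
    # Example (a sketch): capture everything appended inside the scope into a
    # single variable.
    #
    #     with s.var_scope("reasoning"):
    #         s += "Let's think step by step. "
    #         s += sgl.gen("steps", max_tokens=128)
    #     print(s["reasoning"])  # all text appended within the scope
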
def fork(
self,
size: int = 1,
position_ids_offset: Optional[List[int]] = None,
):
stream_executors = self.stream_executor.fork(size, position_ids_offset)
states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self)
return state_group
@contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset)
try:
yield state_group[0]
finally:
state_group.join()
def text(self):
return self.stream_executor.text()
def messages(self):
return self.stream_executor.messages()
def sync(self):
return self.stream_executor.sync()
def error(self):
return self.stream_executor.error()
def text_iter(self, var_name: Optional[str] = None):
if self.stream_executor.stream:
prev = 0
if var_name is None:
event = self.stream_executor.stream_text_event
while True:
event.wait()
event.clear()
out = str(self.stream_executor.text_[prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.is_finished:
break
else:
                # Wait until the worker thread registers the stream event for
                # this variable (or the program finishes without producing it).
                event = None
while not event:
if var_name in self.stream_executor.stream_var_event:
event = self.stream_executor.stream_var_event[var_name]
if self.stream_executor.is_finished:
yield ""
return
while True:
event.wait()
event.clear()
out = str(self.stream_executor.variables[var_name][prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.variable_event[var_name].is_set():
break
else:
if var_name is None:
yield self.text()
else:
yield self.get_var(var_name)
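
    # Streaming sketch (assumes the program was started with stream=True):
    #
    #     state = qa.run(question="...", stream=True)
    #     for chunk in state.text_iter("answer"):
    #         print(chunk, end="", flush=True)
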
async def text_async_iter(
self, var_name: Optional[str] = None, return_meta_data: bool = False
):
loop = asyncio.get_running_loop()
if self.stream_executor.stream:
prev = 0
if var_name is None:
event = self.stream_executor.stream_text_event
while True:
await loop.run_in_executor(None, event.wait)
event.clear()
out = str(self.stream_executor.text_[prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.is_finished:
break
else:
                # As in text_iter: wait for the stream event to be registered.
                event = None
while not event:
if var_name in self.stream_executor.stream_var_event:
event = self.stream_executor.stream_var_event[var_name]
if self.stream_executor.is_finished:
yield ""
return
while True:
await loop.run_in_executor(None, event.wait)
event.clear()
out = str(self.stream_executor.variables[var_name][prev:])
prev += len(out)
if out:
if return_meta_data:
yield out, self.stream_executor.meta_info[var_name]
else:
yield out
if self.stream_executor.variable_event[var_name].is_set():
break
else:
if var_name is None:
yield self.text()
else:
yield self.get_var(var_name)
def get_var(self, name):
return self.stream_executor.get_var(name)
def set_var(self, name, value):
return self.stream_executor.set_var(name, value)
def get_meta_info(self, name):
return self.stream_executor.get_meta_info(name)
def __iadd__(self, other):
if other is None:
raise ValueError("Tried to append None to state.")
self.stream_executor.submit(other)
return self
def __getitem__(self, name):
return self.get_var(name)
def __setitem__(self, name, value):
self.set_var(name, value)
def __contains__(self, name):
return name in self.stream_executor.variables
def __del__(self):
self.stream_executor.end()
def __repr__(self) -> str:
return f"ProgramState({self.text()})"
class ProgramStateGroup:
def __init__(
self, states: List[ProgramState], src_state: Optional[ProgramState] = None
):
self.states = states
self.src_state = src_state
def join(self, mode: str = "gather_variable"):
if mode == "gather_variable":
# Copy variables back
src_vars = self.src_state.stream_executor.variables
src_var_set = set(src_vars.keys())
for child_state in self.states:
child_state.stream_executor.sync()
child_vars = child_state.stream_executor.variables
new_vars = set(child_vars.keys()) - src_var_set
for k in new_vars:
if k in src_vars:
src_vars[k].append(child_vars[k])
else:
src_vars[k] = [child_vars[k]]
elif mode == "concate_and_append":
# Concatenate and append KV cache
self.src_state += SglConcateAndAppend(self.states)
# Need a sync here. Otherwise, `states` can be deleted.
self.src_state.stream_executor.sync()
else:
raise ValueError(f"Invalid join mode: {mode}")
for s in self.states:
s.stream_executor.end()
def __getitem__(self, i: int):
return self.states[i]
def __setitem__(self, i: int, value):
assert self.states[i] == value
def __iadd__(self, other):
if isinstance(other, Callable):
# lambda function
for i in range(len(self.states)):
self.states[i] += other(i)
elif isinstance(other, SglExpr):
for i in range(len(self.states)):
self.states[i] += other
elif isinstance(other, (list, tuple)):
for i in range(len(self.states)):
self.states[i] += other[i]
else:
raise ValueError(f"Invalid value: {other}")
return self
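
# Fork/join sketch following the documented SGLang fork pattern (the function
# body is illustrative):
#
#     @sgl.function
#     def tips(s):
#         s += "Here are two tips:\n"
#         forks = s.fork(2)
#         forks += lambda i: f"Tip {i + 1}: " + sgl.gen("tip", stop="\n")
#         forks.join()  # gathers each child's "tip" back as a list on `s`
#         s += "In summary: " + sgl.gen("summary", max_tokens=64)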