284 lines
8.1 KiB
Python
284 lines
8.1 KiB
Python
"""Tracing a program."""
|
|
|
|
import uuid
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
from sglang.global_config import global_config
|
|
from sglang.lang.backend.base_backend import BaseBackend
|
|
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
|
from sglang.lang.ir import (
|
|
SglArgument,
|
|
SglCommitLazy,
|
|
SglConcateAndAppend,
|
|
SglConstantText,
|
|
SglExpr,
|
|
SglExprList,
|
|
SglFork,
|
|
SglFunction,
|
|
SglGen,
|
|
SglGetForkItem,
|
|
SglRoleBegin,
|
|
SglRoleEnd,
|
|
SglSelect,
|
|
SglVariable,
|
|
SglVarScopeBegin,
|
|
SglVarScopeEnd,
|
|
)
|
|
|
|
|
|
class StopTracing(Exception):
    """Control-flow signal used to abort a trace early.

    The tracer raises it in prefix-only mode when it reaches an operation
    it does not record (a fork or an unhandled expression type);
    ``extract_prefix_by_tracing`` catches it and keeps whatever was
    recorded up to that point.
    """
|
|
|
|
|
|
def extract_prefix_by_tracing(program, backend):
    """Return the constant text a program emits before anything else.

    The program function is run against a prefix-only tracer. Every
    declared argument is replaced by a placeholder ``SglArgument``;
    entries of ``program.bind_arguments`` override the placeholders.

    Args:
        program: Object exposing ``arg_names``, ``bind_arguments`` and
            ``func`` (called as ``func(tracer, **arguments)``).
        backend: Backend handed to the tracer.

    Returns:
        The concatenated values of the leading ``SglConstantText`` nodes
        recorded by the tracer; ``""`` when the first recorded node is
        not constant text or nothing was recorded.
    """
    # One placeholder per declared argument, then let bound values win.
    arguments = {
        arg_name: SglArgument(arg_name, None) for arg_name in program.arg_names
    }
    arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Best effort: the nodes recorded before the interruption are still
        # usable. Other exception types are not caught and propagate.
        pass

    # Collect constant text until the first node of any other kind.
    leading_text = []
    for node in tracer.flatten_nodes():
        if not isinstance(node, SglConstantText):
            break
        leading_text.append(node.value)
    return "".join(leading_text)
|
|
|
|
|
|
def trace_program(program, arguments, backend):
    """Trace a complete run of ``program`` and return the tracer state.

    Args:
        program: Object exposing ``arg_names``, ``bind_arguments`` and
            ``func`` (called as ``func(tracer, **arguments)``).
        arguments: Mapping of argument names to the values supplied by the
            caller. It is copied, so the caller's mapping is left untouched.
        backend: Backend handed to the tracer. ``None`` selects a fresh
            ``BaseBackend()``.

    Returns:
        The ``TracerProgramState`` used for the run, with ``ret_value`` set
        to whatever ``program.func`` returned.
    """
    # Create dummy backend
    if backend is None:
        backend = BaseBackend()

    # Work on a private copy: inserting placeholders and bound arguments
    # directly into the caller's dict would leak them back to the caller.
    call_arguments = dict(arguments)

    # Create dummy arguments for every declared name the caller did not supply.
    for name in program.arg_names:
        if name not in call_arguments:
            call_arguments[name] = SglArgument(name, None)

    # Bound arguments take precedence over both supplied and dummy values.
    call_arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, call_arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **call_arguments)
    return tracer
|
|
|
|
|
|
class TracerProgramState(ProgramState):
    """Program state that records expressions instead of executing them.

    Expressions added with ``+=`` are appended to ``nodes`` and linked
    through ``prev_node``. When ``only_trace_prefix`` is true, tracing is
    aborted with ``StopTracing`` at a fork or at any expression type that
    ``_execute`` has no dedicated handler for.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        """Set up an empty trace.

        Args:
            backend: Backend object; when it has an ``endpoint`` attribute,
                that endpoint is stored as ``self.backend`` instead.
            arguments: Mapping of argument names to values (or placeholder
                arguments) consulted by ``get_var``.
            only_trace_prefix: Whether to stop at the first fork or
                unhandled expression type.
        """
        # NOTE: ProgramState.__init__ is not invoked here; every attribute
        # the tracer uses is initialised below.
        self.pid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For completion

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        # Register with the active tracing scope (and its ancestors), if any.
        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Split the trace into ``size`` child tracer states.

        Args:
            size: Number of children to create; must be at least 1.
            position_ids_offset: Accepted but not used by the tracer.

        Returns:
            A ``ProgramStateGroup`` wrapping the children and this state.

        Raises:
            StopTracing: When only the prefix is being traced.
        """
        assert size >= 1

        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(size)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(size)
        ]

        for i, state in enumerate(states):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            state.last_node = node
            # Each child gets its own copy of the variables and messages so
            # that later changes in one branch do not affect the others.
            state.variables = dict(self.variables)
            state.messages_ = list(self.messages_)
            state.cur_role = self.cur_role
            state.chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        """Record ``other`` and link it to the previously recorded node."""
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch ``other`` to the handler for its expression type.

        Plain strings are wrapped in ``SglConstantText`` first. Expression
        types without a dedicated handler are appended as-is, or raise
        ``StopTracing`` when only the prefix is being traced.

        Returns:
            This tracer state.
        """
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            # NOTE(review): `_execute_var_scope_begin` is not defined in this
            # class; presumably it is inherited from ProgramState - confirm.
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)

        return self

    def __iadd__(self, other):
        """Support ``state += expr`` by tracing ``expr``."""
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        """Record constant text; a plain string is wrapped first."""
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        """Record a generation call and register its output variable."""
        # Unnamed generations get a name derived from the variable count.
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        """Record a select call and register its output variable."""
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role and record the template prefix for it."""
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close a chat role and record the template suffix for it."""
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(suffix)

        # Only the role matters while tracing; the content is left empty.
        self.messages_.append({"role": expr.role, "content": ""})

        self.cur_role = None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        """Register the variable named by a closing variable scope.

        The variable's source is the most recently recorded node.
        """
        # `name` was previously undefined here, so reaching this handler
        # raised NameError. Assumes SglVarScopeEnd exposes the scope name as
        # `.name`, like SglGen and SglSelect - confirm against sglang.lang.ir.
        name = expr.name
        new_node = SglVariable(name, source=self.last_node)
        self.variables[name] = new_node

    def get_var(self, name):
        """Look up ``name`` among the arguments, then the traced variables.

        Returns:
            The argument value when one exists and is not ``None``;
            otherwise a new ``SglVariable`` built from the recorded
            variable's name and source.

        Raises:
            KeyError: When ``name`` is neither a non-``None`` argument nor a
                recorded variable.
        """
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret

        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return the recorded nodes with every ``SglExprList`` expanded."""
        ret = []

        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        # Deliberately a no-op. NOTE(review): presumably this overrides a
        # finalizer on ProgramState that relies on state the tracer never
        # sets up - confirm.
        pass
|
|
|
|
|
|
class TracingScope:
    """Context manager that marks a tracer state as the active one.

    The active scope is stored in the class attribute ``cur_scope``; each
    scope remembers the scope it replaced in ``last_scope`` and restores
    it on exit.
    """

    cur_scope = None

    def __init__(self, tracer_state: "TracerProgramState"):
        """Create a scope for ``tracer_state`` without activating it."""
        self.tracer_state = tracer_state
        # Provisional parent; refreshed when the scope is actually entered.
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        """Make this scope the active one and return it."""
        # Capture the parent at entry time rather than construction time, so
        # a scope that was built before another scope was entered still
        # restores (and reports children to) the correct parent.
        self.last_scope = TracingScope.cur_scope
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Reactivate the scope that was active on entry.

        Returns ``None``, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the active scope, or ``None`` when no scope is active."""
        return TracingScope.cur_scope

    def add_child_state(self, state: "TracerProgramState"):
        """Append ``state`` to the tracer of this scope and of every ancestor."""
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope
|