238 lines
7.5 KiB
Python
238 lines
7.5 KiB
Python
import multiprocessing
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from queue import Queue
|
|
from typing import List, Union
|
|
|
|
from sglang.global_config import global_config
|
|
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
|
from sglang.lang.ir import (
|
|
SglArgument,
|
|
SglConstantText,
|
|
SglExpr,
|
|
SglSamplingParams,
|
|
SglVariable,
|
|
)
|
|
|
|
|
|
def compile_func(function, backend):
    """Trace ``function`` against ``backend`` and compile the resulting trace.

    Args:
        function: Object exposing ``trace(backend=...)``; it is also stored on
            the returned ``CompiledFunction``.
        backend: Backend handed to ``function.trace``.

    Returns:
        A ``CompiledFunction`` built from the trace and ``function``.
    """
    return CompiledFunction(function.trace(backend=backend), function)
|
|
|
|
|
|
class CompiledFunction:
    """A traced function turned into an ordered graph of expression nodes.

    The graph is discovered by walking backwards from the tracer's last
    expression, then reordered so that every node comes after the nodes it
    depends on. ``run`` / ``run_batch`` replay the ordered expressions on one
    ``StreamExecutor`` per stream id (``expr.pid``).
    """

    def __init__(self, tracer, function):
        """Build and topologically order the graph.

        Args:
            tracer: Trace result; ``tracer.last_node`` and ``tracer.variables``
                are read while building the graph.
            function: The traced function. Its ``bind_arguments`` mapping is
                overlaid on the call arguments in ``run`` and ``run_batch``.
        """
        self.function = function

        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def _get_or_add_node(self, expr, visited):
        """Return the graph node for ``expr``, creating and registering it if unseen."""
        if expr not in visited:
            visited.add(expr)
            node = CompGraphNode(expr)
            self.nodes.append(node)
            self.expr_to_node[expr] = node
        return self.expr_to_node[expr]

    def _with_bound_arguments(self, kwargs):
        """Return a new dict: ``kwargs`` overlaid with ``function.bind_arguments``.

        Bound values win on key clashes (the same precedence ``dict.update``
        gives). The input mapping is left untouched.
        """
        merged = dict(kwargs)
        merged.update(self.function.bind_arguments)
        return merged

    def build_graph(self, tracer):
        """Populate ``self.nodes`` and the node links, starting at the last node.

        ``self.nodes`` doubles as the work list: it is scanned by index while
        newly discovered nodes are appended to its end. For each node the
        expression's ``prev_node`` is linked, and for ``SglVariable``
        expressions the variable's source expression is linked as well.

        Side effect: ``expr.pid`` of every visited expression is rewritten in
        place to a consecutive integer, numbered in order of first visit.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.last_node

        rename_pid = {}
        visited = {tracer.last_node}

        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]
            head += 1
            cur_expr = cur_node.expr

            # Link the expression that precedes this one.
            prev_expr = cur_expr.prev_node
            if prev_expr is not None:
                prev = self._get_or_add_node(prev_expr, visited)
                cur_node.prev_node = prev
                prev.add_next_node(cur_node)

            # Link the expression a variable takes its value from. The
            # tracer's record for the variable name wins over the source
            # stored on the expression itself.
            if isinstance(cur_expr, SglVariable):
                if cur_expr.name in tracer.variables:
                    source_expr = tracer.variables[cur_expr.name].source
                else:
                    source_expr = cur_expr.source
                source = self._get_or_add_node(source_expr, visited)
                cur_node.source_node = source
                source.add_next_node(cur_node)

            # Renumber stream ids to 0, 1, 2, ... in order of first visit.
            if cur_expr.pid not in rename_pid:
                rename_pid[cur_expr.pid] = len(rename_pid)
            cur_expr.pid = rename_pid[cur_expr.pid]

    def topological_sort(self):
        """Reorder ``self.nodes`` so each node follows its prev/source nodes.

        Each node starts with a count of its incoming links (``prev_node``
        and ``source_node``); a node is emitted once that count drops to
        zero. Nodes whose count never reaches zero are left out of the
        result.
        """
        pending = {
            node: (node.prev_node is not None) + (node.source_node is not None)
            for node in self.nodes
        }

        # ``ordered`` is both the FIFO of ready nodes and the final result:
        # nodes are emitted in exactly the order they become ready.
        ordered = [node for node in self.nodes if pending[node] == 0]
        head = 0
        while head < len(ordered):
            ready = ordered[head]
            head += 1
            for successor in ready.next_nodes:
                pending[successor] -= 1
                if pending[successor] == 0:
                    ordered.append(successor)

        self.nodes = ordered

    def print_graph(self):
        """Print every node, one per line, in the current node order."""
        for node in self.nodes:
            print(node)

    def run_internal(self, backend, kwargs, default_sampling_para):
        """Replay the graph once and return the state of the final stream.

        Args:
            backend: Backend passed to every ``StreamExecutor``.
            kwargs: Call arguments. They are given to the executor of the
                last node's stream and used to substitute ``SglArgument``
                expressions by name.
            default_sampling_para: Sampling parameters passed to every
                ``StreamExecutor``.

        Returns:
            A ``ProgramState`` wrapping the executor of the last node's stream.

        Raises:
            KeyError: If an ``SglArgument`` names a key missing from ``kwargs``.
        """
        last_pid = self.last_node.expr.pid

        # One executor per stream id; only the final stream gets the call
        # arguments, the others start with an empty mapping.
        stream_executors = {}
        for pid in {node.expr.pid for node in self.nodes}:
            arguments = kwargs if pid == last_pid else {}
            stream_executors[pid] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )

        for node in self.nodes:
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Submit a fresh copy so the executor reference is attached
                # to this run's object rather than the shared traced one.
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute the placeholder with the caller's value.
                expr = kwargs[expr.name]
            stream_executors[node.expr.pid].submit(expr)

        for stream_executor in stream_executors.values():
            stream_executor.end()

        return ProgramState(stream_executors[last_pid])

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once.

        The sampling keyword arguments are bundled into an
        ``SglSamplingParams``. ``backend`` falls back to
        ``global_config.default_backend`` when falsy. ``kwargs`` are the call
        arguments, overlaid with ``function.bind_arguments``.

        Returns:
            The ``ProgramState`` produced by ``run_internal``.
        """
        backend = backend or global_config.default_backend

        kwargs = self._with_bound_arguments(kwargs)

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function once per argument dict in ``batch_kwargs``.

        Args:
            batch_kwargs: List or tuple of argument dicts, one per run. Each
                is copied and overlaid with ``function.bind_arguments`` so a
                batch run receives the same arguments ``run`` would; the
                caller's dicts are not modified.
            num_threads: Worker thread count, or ``"auto"`` for
                ``multiprocessing.cpu_count()``. Capped at the batch size.
            backend: Falls back to ``global_config.default_backend`` when
                falsy.
            (remaining keyword arguments): Bundled into ``SglSamplingParams``.

        Returns:
            A list with one ``run_internal`` result per input, in input
            order; ``[]`` for an empty batch.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if not batch_kwargs:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        # Keep batch runs consistent with ``run``: apply the bound arguments.
        batch_kwargs = [
            self._with_bound_arguments(arguments) for arguments in batch_kwargs
        ]

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = [
                self.run_internal(backend, arguments, default_sampling_para)
                for arguments in batch_kwargs
            ]
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = [
                    executor.submit(
                        self.run_internal, backend, arguments, default_sampling_para
                    )
                    for arguments in batch_kwargs
                ]
                rets = [future.result() for future in futures]
            # Only the threaded path syncs, and only on the last state
            # (unchanged from the original behavior).
            rets[-1].sync()

        return rets
|
|
|
|
|
|
class CompGraphNode:
    """A single expression in the compiled graph plus its links to other nodes."""

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        """Store the expression and its optional neighbours.

        Args:
            expr: The expression this node wraps.
            prev_node: Node of the preceding expression, if any.
            next_nodes: Nodes that depend on this one; a falsy value (``None``
                or an empty list) is replaced by a fresh empty list.
            source_node: Node this one takes its value from, if any.
        """
        self.expr = expr
        self.prev_node = prev_node
        self.source_node = source_node
        self.next_nodes = next_nodes if next_nodes else []

    def add_next_node(self, other):
        """Record ``other`` as a node that depends on this one."""
        self.next_nodes.append(other)

    def __repr__(self):
        """Format as ``stream <pid>: %<id> = [%<prev id> + ]<expr repr>``."""
        expr = self.expr
        pieces = [f"stream {expr.pid:2d}: ", f"%{expr.node_id} = "]
        if self.prev_node is not None:
            pieces.append(f"%{self.prev_node.expr.node_id} + ")
        pieces.append(repr(expr))
        return "".join(pieces)
|