# sglang0.4.5.post1/python/sglang/lang/compiler.py
#
# 238 lines
# 7.5 KiB
# Python

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
def compile_func(function, backend):
    """Trace ``function`` on ``backend`` and wrap the trace in a CompiledFunction.

    Args:
        function: object exposing ``trace(backend=...)``; presumably an
            ``SglFunction`` (TODO confirm against callers).
        backend: backend forwarded unchanged to ``function.trace``.

    Returns:
        The ``CompiledFunction`` built from the trace and ``function``.
    """
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
class CompiledFunction:
    """Dependency graph of a traced SGL program that can be replayed.

    Built from a tracer (as produced by ``compile_func``) and the traced
    ``function``. Every expression reachable from the tracer's last expression
    becomes a ``CompGraphNode``. Nodes are linked through each expression's
    ``prev_node`` and, for ``SglVariable`` reads, through the expression that
    produced the variable. The nodes are then topologically sorted so that
    ``run`` and ``run_batch`` can submit them in dependency order.
    """

    def __init__(self, tracer, function):
        self.function = function
        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        """Collect every expression reachable from ``tracer.last_node``.

        Walks backwards breadth-first through ``prev_node`` links and, for
        ``SglVariable`` expressions, through the variable's source expression
        (looked up in ``tracer.variables`` when present, otherwise taken from
        the expression itself). Each expression gets exactly one
        ``CompGraphNode``, and forward edges are recorded with
        ``add_next_node``.

        Side effect: stream ids (``expr.pid``) on the traced expressions are
        renumbered in place, densely from 0 in visitation order. The stream
        of the final expression therefore becomes 0.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.last_node
        rename_pid = {}
        visited = {tracer.last_node}

        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # Link to the previous expression on the same stream.
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # A variable read depends on the expression that produced it.
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)

            head += 1

            # Renumber stream ids densely (0, 1, 2, ...) in visitation order.
            # Each expression is visited exactly once, so the mapping is keyed
            # on the expression's original pid.
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        """Reorder ``self.nodes`` so that every node follows its dependencies.

        This is Kahn's algorithm. A node's in-degree is the number of its
        ``prev_node`` / ``source_node`` links that are set. Ready nodes are
        emitted first-in first-out, seeded in the current ``self.nodes``
        order.

        Raises:
            ValueError: if some nodes can never become ready (the dependency
                graph contains a cycle). Dropping them would make
                ``run_internal`` replay an incomplete program.
        """
        pending = {
            node: (node.prev_node is not None) + (node.source_node is not None)
            for node in self.nodes
        }
        order = [node for node in self.nodes if pending[node] == 0]

        # `order` doubles as the FIFO work list. Entries before `cursor` have
        # been processed; entries after it are ready but not yet expanded.
        cursor = 0
        while cursor < len(order):
            ready = order[cursor]
            cursor += 1
            for successor in ready.next_nodes:
                pending[successor] -= 1
                if pending[successor] == 0:
                    order.append(successor)

        if len(order) != len(self.nodes):
            raise ValueError(
                f"compiled graph contains a cycle: only {len(order)} of "
                f"{len(self.nodes)} nodes could be ordered"
            )
        self.nodes = order

    def print_graph(self):
        """Print one line per node, in the current ``self.nodes`` order."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Replay the compiled graph once and return the final program state.

        One ``StreamExecutor`` is created per stream id. Only the stream of
        ``last_node`` receives ``kwargs`` as its arguments; the others get an
        empty dict. Nodes are submitted in ``self.nodes`` order:
        ``SglVariable`` nodes are copied and bound to the executor of their
        source stream, and ``SglArgument`` nodes are replaced by
        ``kwargs[name]``. Every executor is then ended.

        Args:
            backend: backend passed to every ``StreamExecutor``.
            kwargs: program argument values for this run, keyed by name.
            default_sampling_para: ``SglSamplingParams`` used as defaults.

        Returns:
            ``ProgramState`` wrapping the executor of the final stream.

        Raises:
            KeyError: if an ``SglArgument`` has no value in ``kwargs``.
        """
        stream_executor_ids = {node.expr.pid for node in self.nodes}
        stream_executors = {}
        for x in stream_executor_ids:
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )

        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Copy the variable so the shared traced expression is not
                # mutated; concurrent run_internal calls reuse self.nodes.
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute the placeholder with the concrete argument value.
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)

        for stream_executor in stream_executors.values():
            stream_executor.end()

        return ProgramState(stream_executors[self.last_node.expr.pid])

    @staticmethod
    def _default_sampling_params(
        *,
        max_new_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
    ):
        """Build the default ``SglSamplingParams`` shared by run/run_batch."""
        return SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled program once.

        Extra keyword arguments are program arguments. The function's bound
        arguments (``function.bind_arguments``) are merged on top and win over
        same-named keyword arguments. ``backend`` defaults to
        ``global_config.default_backend``.

        Returns:
            The ``ProgramState`` from ``run_internal``.
        """
        backend = backend or global_config.default_backend
        kwargs.update(self.function.bind_arguments)
        default_sampling_para = self._default_sampling_params(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled program once per argument dict in ``batch_kwargs``.

        Bound arguments are merged into each input exactly as ``run`` does
        (bound values win). The merge builds new dicts, so the caller's dicts
        are left untouched. With more than one input, the program's prefix is
        cached first via ``cache_program``. ``num_threads="auto"`` uses
        ``multiprocessing.cpu_count()``, capped at the batch size. With more
        than one thread, inputs are run on a ``ThreadPoolExecutor`` and
        ``sync()`` is called on the last returned state.

        Args:
            batch_kwargs: list or tuple of dicts of program arguments.

        Returns:
            A list of ``ProgramState`` in input order; ``[]`` for empty input.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if not batch_kwargs:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend
        default_sampling_para = self._default_sampling_params(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Same precedence as `run`: bound arguments override per-call ones.
        bind_arguments = self.function.bind_arguments
        batch_arguments = [
            {**arguments, **bind_arguments} for arguments in batch_kwargs
        ]

        # Extract prefix by tracing and cache it
        if len(batch_arguments) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_arguments))

        if num_threads == 1:
            rets = [
                self.run_internal(backend, arguments, default_sampling_para)
                for arguments in batch_arguments
            ]
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = [
                    executor.submit(
                        self.run_internal, backend, arguments, default_sampling_para
                    )
                    for arguments in batch_arguments
                ]
                rets = [f.result() for f in futures]
            rets[-1].sync()

        return rets
class CompGraphNode:
    """One expression in a CompiledFunction's dependency graph.

    Attributes:
        expr: the wrapped expression.
        prev_node: node of the previous expression on the same stream, or None.
        next_nodes: nodes that depend on this one.
        source_node: for a variable read, the node whose expression produced
            the variable; otherwise None.
    """

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        # Use a fresh list when nothing (or an empty container) is passed,
        # so separate nodes never share one successor list by default.
        self.next_nodes = next_nodes if next_nodes else []
        self.prev_node = prev_node
        self.source_node = source_node

    def add_next_node(self, other):
        """Record ``other`` as a node that depends on this one."""
        self.next_nodes.append(other)

    def __repr__(self):
        # Rendered as "stream <pid>: %<id> = [%<prev id> + ]<expr repr>".
        header = f"stream {self.expr.pid:2d}: %{self.expr.node_id} = "
        if self.prev_node is None:
            link = ""
        else:
            link = f"%{self.prev_node.expr.node_id} + "
        return header + link + repr(self.expr)