# Benchmark JSON-schema-constrained generation with sglang.
import argparse
import json
import time
from typing import List, Tuple

import jsonschema
from datasets import load_dataset

import sglang as sgl
from sglang.global_config import global_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text


@sgl.function
|
|
def schema_gen(s, message: Tuple[str, str], json_schema: str):
|
|
system, user = message
|
|
s += sgl.system(system)
|
|
s += sgl.user(user)
|
|
s += sgl.assistant(
|
|
sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema)
|
|
)
|
|
|
|
|
|
def contains_formats(schema, formats: List[str]):
|
|
if isinstance(schema, dict):
|
|
if schema.get("format", None) in formats:
|
|
return True
|
|
for value in schema.values():
|
|
if contains_formats(value, formats):
|
|
return True
|
|
elif isinstance(schema, list):
|
|
for item in schema:
|
|
if contains_formats(item, formats):
|
|
return True
|
|
return False
|
|
|
|
|
|
def convert_dataset(path: str):
|
|
raw_dataset = load_dataset(path)
|
|
dataset = []
|
|
for data in raw_dataset["train"]:
|
|
messages = data["prompt"]
|
|
schema = data["schema"]
|
|
obj = json.loads(schema)
|
|
|
|
# skip some corrupted examples
|
|
if obj.get("type", None) is None:
|
|
continue
|
|
|
|
# skip schema with format "email"
|
|
# which is not supported by outlines for now
|
|
if contains_formats(obj, ["email"]):
|
|
continue
|
|
|
|
system = messages[0]
|
|
user = messages[1]
|
|
assert system["role"] == "system", "invalid role"
|
|
assert user["role"] == "user", "invalid role"
|
|
assert len(messages) == 2, "invalid message length"
|
|
message = json.dumps(system["content"]), json.dumps(user["content"])
|
|
dataset.append(
|
|
{
|
|
"message": message,
|
|
"json_schema": schema,
|
|
}
|
|
)
|
|
|
|
return dataset
|
|
|
|
|
|
def bench_schema(args):
|
|
arguments = convert_dataset(args.data_path)
|
|
|
|
if args.num_jsons < 0 or args.num_jsons > len(arguments):
|
|
args.num_jsons = len(arguments)
|
|
arguments = arguments[: args.num_jsons]
|
|
|
|
# Select backend
|
|
backend = select_sglang_backend(args)
|
|
sgl.set_default_backend(backend)
|
|
|
|
# Run requests
|
|
tic = time.perf_counter()
|
|
states = schema_gen.run_batch(
|
|
arguments,
|
|
temperature=0,
|
|
num_threads=args.parallel,
|
|
progress_bar=True,
|
|
)
|
|
latency = time.perf_counter() - tic
|
|
|
|
# Check if the outputs are valid
|
|
indexes = []
|
|
for i, state in enumerate(states):
|
|
try:
|
|
schema = json.loads(arguments[i]["json_schema"])
|
|
obj = json.loads(state["json_output"])
|
|
assert jsonschema.validate(obj, schema) is None
|
|
except Exception as e:
|
|
print(e)
|
|
indexes.append(i)
|
|
|
|
return states, latency
|
|
|
|
|
|
def main(args):
|
|
states, latency = bench_schema(args)
|
|
|
|
# Compute accuracy
|
|
tokenizer = get_tokenizer(
|
|
global_config.default_backend.get_server_info()["tokenizer_path"]
|
|
)
|
|
output_jsons = [state["json_output"] for state in states]
|
|
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
|
|
print(f"Latency: {latency:.3f}")
|
|
print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")
|
|
print(f"#output tokens: {num_output_tokens}")
|
|
|
|
# Write results
|
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
|
with open(f"{args.backend}.jsonl", "w") as fout:
|
|
for state in states:
|
|
fout.write(state["json_output"] + "\n")
|
|
|
|
with open(args.result_file, "a") as fout:
|
|
value = {
|
|
"task": "json_schema",
|
|
"backend": args.backend,
|
|
"latency": round(latency, 3),
|
|
"num_jsons": args.num_jsons,
|
|
"parallel": args.parallel,
|
|
}
|
|
fout.write(json.dumps(value) + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval")
|
|
parser.add_argument("--num-jsons", type=int, default=-1)
|
|
args = add_common_sglang_args_and_parse(parser)
|
|
main(args)
|