576 lines
18 KiB
Python
576 lines
18 KiB
Python
"""This file contains the SGL programs used for unit testing."""
|
|
|
|
import json
|
|
import re
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
import sglang as sgl
|
|
from sglang.utils import download_and_cache_file, read_jsonl
|
|
|
|
|
|
def test_few_shot_qa():
    """Answer capital-city questions from a fixed few-shot prompt.

    One question is answered with ``run`` and three more with ``run_batch``;
    the lower-cased answers are compared against the expected cities.
    """

    @sgl.function
    def few_shot_qa(s, question):
        s += "The following are questions with answers.\n\n"
        # Same prompt lines as writing each "Q:"/"A:" pair out by hand.
        for country, capital in (
            ("France", "Paris"),
            ("Germany", "Berlin"),
            ("Italy", "Rome"),
        ):
            s += "Q: What is the capital of " + country + "?\n"
            s += "A: " + capital + "\n"
        s += "Q: " + question + "\n"
        s += "A:" + sgl.gen("answer", stop="\n", temperature=0)

    single = few_shot_qa.run(question="What is the capital of the United States?")
    assert "washington" in single["answer"].strip().lower(), f"answer: {single['answer']}"

    batch_questions = [
        "What is the capital of Japan?",
        "What is the capital of the United Kingdom?",
        "What is the capital city of China?",
    ]
    states = few_shot_qa.run_batch(
        [{"question": q} for q in batch_questions],
        temperature=0.1,
    )
    answers = [state["answer"].strip().lower() for state in states]
    assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"
|
|
|
|
|
|
def test_mt_bench():
    """Run a two-turn MT-Bench style chat and check the message count.

    Either 4 or 5 messages are accepted; presumably the count depends on
    whether the chat template keeps the system message (assumption).
    """

    @sgl.function
    def answer_mt_bench(s, question_1, question_2):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question_1)
        s += sgl.assistant(sgl.gen("answer_1"))
        with s.user():
            s += question_2
        with s.assistant():
            s += sgl.gen("answer_2")

    first_turn = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
    second_turn = "Rewrite your previous response. Start every sentence with the letter A."

    state = answer_mt_bench.run(
        question_1=first_turn,
        question_2=second_turn,
        temperature=0.7,
        max_new_tokens=64,
    )
    assert len(state.messages()) in (4, 5)
|
|
|
|
|
|
def test_select(check_answer=True):
    """Classify statements as True / False / Unknown with ``sgl.select``.

    Args:
        check_answer: when True (default, matching the other tests in this
            module), each statement must receive its expected label; when
            False, the selection only has to be one of the three choices.
    """
    options = ["True", "False", "Unknown"]

    @sgl.function
    def true_or_false(s, statement):
        s += "Determine whether the statement below is True, False, or Unknown.\n"
        # The few-shot example is labelled True, so it must actually be true;
        # the previous misspelling ("Pairs") demonstrated a wrong label.
        s += "Statement: The capital of France is Paris.\n"
        s += "Answer: True\n"
        s += "Statement: " + statement + "\n"
        s += "Answer:" + sgl.select("answer", options)

    cases = [
        ("The capital of Germany is Berlin.", "True"),
        ("The capital of Canada is Tokyo.", "False"),
        ("Purple is a better color than green.", "Unknown"),
    ]
    for statement, expected in cases:
        ret = true_or_false.run(statement=statement)
        if check_answer:
            assert ret["answer"] == expected, ret.text()
        else:
            assert ret["answer"] in options
|
|
|
|
|
|
def test_decode_int():
    """Generate two integers with ``sgl.gen_int`` and check well-known values."""

    @sgl.function
    def decode_int(s):
        s += "The number of hours in a day is " + sgl.gen_int("hours") + "\n"
        s += "The number of days in a year is " + sgl.gen_int("days") + "\n"

    state = decode_int.run(temperature=0.1)
    for var_name, expected in (("hours", 24), ("days", 365)):
        assert int(state[var_name]) == expected, state.text()
|
|
|
|
|
|
def test_decode_json_regex():
    """Build a JSON object field by field with regex-constrained generation.

    The captured ``json_output`` must parse with ``json.loads``; ``name`` must
    be a string and ``population`` an integer.
    """

    @sgl.function
    def decode_json(s):
        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR

        s += "Generate a JSON object to describe the basic city information of Paris.\n"
        s += "Here are the JSON object:\n"

        # NOTE: we recommend using dtype gen or whole regex string to control the output

        with s.var_scope("json_output"):
            s += "{\n"
            s += '  "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
            s += '  "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += '  "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += '  "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
            s += "}"

    state = decode_json.run(temperature=0.0)
    raw_json = state["json_output"]
    try:
        parsed = json.loads(raw_json)
    except json.JSONDecodeError:
        # Show the offending text before propagating the parse error.
        print("JSONDecodeError", raw_json)
        raise

    assert isinstance(parsed["name"], str)
    assert isinstance(parsed["population"], int)
|
|
|
|
|
|
def test_decode_json():
    """Build a JSON object using the typed generation helpers.

    Mixes ``gen_string``/``gen_int`` with ``gen(dtype=...)``; the captured
    ``json_output`` must parse, with a string ``name`` and an int ``population``.
    """

    @sgl.function
    def decode_json(s):
        s += "Generate a JSON object to describe the basic city information of Paris.\n"

        with s.var_scope("json_output"):
            s += "{\n"
            s += '  "name": ' + sgl.gen_string() + ",\n"
            s += '  "population": ' + sgl.gen_int() + ",\n"
            s += '  "area": ' + sgl.gen(dtype=int) + ",\n"
            s += '  "country": ' + sgl.gen_string() + ",\n"
            s += '  "timezone": ' + sgl.gen(dtype=str) + "\n"
            s += "}"

    state = decode_json.run(max_new_tokens=64)
    raw_json = state["json_output"]
    try:
        parsed = json.loads(raw_json)
    except json.JSONDecodeError:
        # Show the offending text before propagating the parse error.
        print("JSONDecodeError", raw_json)
        raise

    assert isinstance(parsed["name"], str)
    assert isinstance(parsed["population"], int)
|
|
|
|
|
|
def test_expert_answer(check_answer=True):
    """Name an expert, then answer in that expert's voice.

    The generated ``expert`` variable is read back (``s["expert"]``) and
    spliced into the follow-up sentence.

    Args:
        check_answer: if True, the final text must mention "paris".
    """

    @sgl.function
    def expert_answer(s, question):
        s += "Question: " + question + "\n"
        s += (
            "A good person to answer this question is"
            + sgl.gen("expert", stop=[".", "\n"])
            + ".\n"
        )
        s += (
            "For example,"
            + s["expert"]
            + " would answer that "
            + sgl.gen("answer", stop=".")
            + "."
        )

    state = expert_answer.run(question="What is the capital of France?", temperature=0.1)

    if not check_answer:
        return
    full_text = state.text()
    assert "paris" in full_text.lower(), f"Answer: {full_text}"
|
|
|
|
|
|
def test_tool_use():
    """Let the model emit a calculator call and splice in the computed result.

    The model generates an arithmetic expression inside ``calculate(...)``.
    The expression is evaluated locally and written into the ``answer``
    variable, which must equal ``lhs * rhs``.
    """
    import ast
    import operator

    # The expression is model output, i.e. untrusted text. Instead of eval(),
    # only numeric literals combined with basic arithmetic are evaluated.
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        """Recursively evaluate a restricted arithmetic AST node."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError(f"unsupported calculator expression: {ast.dump(node)}")

    def calculate(expression):
        # eval() tolerated surrounding blanks; ast.parse does not, so strip.
        return f"{evaluate(ast.parse(expression.strip(), mode='eval'))}"

    @sgl.function
    def tool_use(s, lhs, rhs):
        s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
        s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
        s += "Question: What is the product of " + str(lhs) + " and " + str(rhs) + "?\n"
        s += (
            "Answer: The answer is calculate("
            + sgl.gen("expression", stop=")")
            + ") = "
        )
        with s.var_scope("answer"):
            s += calculate(s["expression"])

    lhs, rhs = 257, 983
    ret = tool_use(lhs=lhs, rhs=rhs, temperature=0)
    assert int(ret["answer"]) == lhs * rhs
|
|
|
|
|
|
def test_react():
    """Run a ReAct-style Thought/Action/Observation loop for up to 4 steps.

    Each step selects "Search" or "Finish"; on "Finish" the bracketed text is
    captured as ``answer``. The answer must mention Finland or the States.
    """
    demonstration = """
Question: Which country does the founder of Microsoft live in?
Thought 1: I need to search for the founder of Microsoft.
Action 1: Search [Founder of Microsoft].
Observation 1: The founder of Microsoft is Bill Gates.
Thought 2: I need to search for the country where Bill Gates lives in.
Action 2: Search [Where does Bill Gates live].
Observation 2: Bill Gates lives in the United States.
Thought 3: The answer is the United States.
Action 3: Finish [United States].\n
"""

    @sgl.function
    def react(s, question):
        s += demonstration

        s += "Question: " + question + "\n"

        for step in range(1, 5):
            s += f"Thought {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
            s += f"Action {step}: " + sgl.select(f"action_{step}", ["Search", "Finish"])

            if s[f"action_{step}"] != "Search":
                s += " [" + sgl.gen("answer", stop="]") + "].\n"
                break
            s += " [" + sgl.gen(stop="]") + "].\n"
            s += f"Observation {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"

    state = react.run(
        question="What country does the creator of Linux live in?",
        temperature=0.1,
    )
    answer = state["answer"].lower()
    assert any(keyword in answer for keyword in ("finland", "states"))
|
|
|
|
|
|
def test_parallel_decoding():
    """Skeleton-then-expand generation using forked program states.

    A numbered skeleton of short tips is generated first. Each tip is then
    expanded in its own fork, the expansions are joined back into the main
    state, and a summary is generated. Only checks that ``summary`` is a str.
    """
    max_tokens = 64
    fork_size = 5

    @sgl.function
    def parallel_decoding(s, topic):
        s += "Act as a helpful assistant.\n"
        s += "USER: Give some tips for " + topic + ".\n"
        s += (
            "ASSISTANT: Okay. Here are "
            + str(fork_size)
            + " concise tips, each under 8 words:\n"
        )

        # Generate skeleton
        for tip_no in range(1, 1 + fork_size):
            s += f"{tip_no}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"

        # Expand each tip in its own fork.
        forks = s.fork(fork_size)
        for idx in range(fork_size):
            tip_no = idx + 1
            forks[idx] += f"Now, I expand tip {tip_no} into a detailed paragraph:\nTip {tip_no}:"
            forks[idx] += sgl.gen("detailed_tip", max_tokens=max_tokens, stop=["\n\n"])
        forks.join()

        # Concatenate tips and summarize
        s += "Here are these tips with detailed explanation:\n"
        for idx in range(fork_size):
            s += f"Tip {idx + 1}:" + forks[idx]["detailed_tip"] + "\n"

        s += "\nIn summary," + sgl.gen("summary", max_tokens=512)

    state = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
    assert isinstance(state["summary"], str)
|
|
|
|
|
|
def test_parallel_encoding(check_answer=True):
    """Encode several context statements in parallel forks, then ask a question.

    The forks are joined with ``mode="concate_and_append"`` before the final
    question is appended and answered.

    Args:
        check_answer: if True, the answer must mention "Noah".
    """
    max_tokens = 64

    @sgl.function
    def parallel_encoding(s, question, context_0, context_1, context_2):
        s += "USER: I will ask a question based on some statements.\n"
        s += "ASSISTANT: Sure. I will give the answer.\n"
        s += "USER: Please memorize these statements.\n"

        contexts = [context_0, context_1, context_2]

        forks = s.fork(len(contexts))
        forks += lambda i: f"Statement {i}: " + contexts[i] + "\n"
        forks.join(mode="concate_and_append")

        s += "Now, please answer the following question. Do not list options."
        s += "\nQuestion: " + question + "\n"
        s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens)

    state = parallel_encoding.run(
        question="Who is the father of Julian?",
        context_0="Ethan is the father of Liam.",
        context_1="Noah is the father of Julian.",
        context_2="Oliver is the father of Carlos.",
        temperature=0,
    )
    generated = state["answer"]

    if check_answer:
        assert "Noah" in generated
|
|
|
|
|
|
def test_image_qa():
    """Describe ``example_image.png`` and expect a mention of a taxi or car.

    The image path is relative, so it is resolved against the current working
    directory.
    """

    @sgl.function
    def image_qa(s, question):
        s += sgl.user(sgl.image("example_image.png") + question)
        s += sgl.assistant(sgl.gen("answer"))

    state = image_qa.run(
        question="Please describe this image in simple words.",
        temperature=0,
        max_new_tokens=64,
    )

    reply = state.messages()[-1]["content"]
    assert "taxi" in reply or "car" in reply, f"{reply}"
|
|
|
|
|
|
def test_stream():
    """Exercise streaming output via ``text_iter``.

    The program is run twice with ``stream=True``. The first run consumes the
    whole-text stream and the second consumes only the ``answer`` variable's
    stream. Each stream must yield some text; previously the collected output
    was never checked, so a silently empty stream would pass.
    """
    prompt = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."

    @sgl.function
    def qa(s, question):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer"))

    # Whole-text stream.
    ret = qa(question=prompt, stream=True)
    out = "".join(ret.text_iter())
    assert out, "text_iter() yielded no text"

    # Stream restricted to the "answer" variable.
    ret = qa(question=prompt, stream=True)
    out = "".join(ret.text_iter("answer"))
    assert out, "text_iter('answer') yielded no text"
|
|
|
|
|
|
def test_regex():
    """Constrain generation to an IPv4 address and verify the result.

    The same pattern drives constrained decoding and the final check.
    """
    # The separators must be escaped: a bare "." matches any character, so
    # both the generation constraint and the check would accept "8a8b8c8".
    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

    @sgl.function
    def regex_gen(s):
        s += "Q: What is the IP address of the Google DNS servers?\n"
        s += "A: " + sgl.gen(
            "answer",
            temperature=0,
            regex=regex,
        )

    state = regex_gen.run()
    answer = state["answer"]
    # re.match only anchors at the start; fullmatch rejects trailing junk.
    assert re.fullmatch(regex, answer), f"answer: {answer!r}"
|
|
|
|
|
|
def test_dtype_gen():
    """Generate typed values with ``sgl.gen(dtype=...)`` and check they parse.

    The int, float and bool results are converted in place on the program
    state. Any conversion failure dumps the state and re-raises ``ValueError``.
    The string result is generated but not validated.
    """

    @sgl.function
    def dtype_gen(s):
        s += "Q: What is the full name of DNS?\n"
        s += "A: The full name is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
        s += "Q: Which year was DNS invented?\n"
        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
        s += "Q: What is the value of pi?\n"
        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
        s += "Q: Is the sky blue?\n"
        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"

    state = dtype_gen.run()

    try:
        state["int_res"] = int(state["int_res"])
        state["float_res"] = float(state["float_res"])
        # bool() of any non-empty string is True (bool("False") is True), so it
        # validates nothing; parse the literal explicitly instead.
        bool_text = state["bool_res"].strip()
        if bool_text not in ("True", "False"):
            raise ValueError(f"invalid bool literal: {bool_text!r}")
        state["bool_res"] = bool_text == "True"
    except ValueError:
        print(state)
        raise
|
|
|
|
|
|
def test_completion_speculative():
    """Compare prompt-token usage with and without API speculative execution.

    The same completion program is defined twice: once with
    ``num_api_spec_tokens=64`` and once without. Each is run after resetting
    the default backend's ``token_usage``. The speculative variant must
    consume strictly fewer prompt tokens.
    """
    instruction = "Construct a character within the following format:\n"
    example = "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    request = "\nPlease generate new Name, Birthday and Job.\n"

    @sgl.function(num_api_spec_tokens=64)
    def gen_character_spec(s):
        s += instruction
        s += example
        s += request
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    @sgl.function
    def gen_character_no_spec(s):
        s += instruction
        s += example
        s += request
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    usage = sgl.global_config.default_backend.token_usage

    def prompt_tokens_used(program):
        """Reset the usage counter, run ``program`` to completion, report prompt tokens."""
        usage.reset()
        program().sync()
        return usage.prompt_tokens

    usage_with_spec = prompt_tokens_used(gen_character_spec)
    usage_with_no_spec = prompt_tokens_used(gen_character_no_spec)

    assert (
        usage_with_spec < usage_with_no_spec
    ), f"{usage_with_spec} vs {usage_with_no_spec}"
|
|
|
|
|
|
def test_chat_completion_speculative():
    """Run a chat program with API speculative execution enabled.

    The program is started and waited on with ``sync()``; no generated value
    is asserted.
    """
    example_character = "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"

    @sgl.function(num_api_spec_tokens=256)
    def gen_character_spec(s):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant(example_character)
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        s += sgl.assistant(
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
            + "\nJob:"
            + sgl.gen("job", stop="\n")
        )

    state = gen_character_spec()
    state.sync()
|
|
|
|
|
|
def test_hellaswag_select():
    """Benchmark the accuracy of sgl.select on the HellaSwag dataset.

    The same 200-question batch is run twice, once with list-style
    ``run_batch`` and once with ``generator_style=True``. The two runs must
    agree in accuracy (within 0.1) and latency (within 1 second).

    Returns:
        ``(accuracy, latency)`` of the list-style run.
    """

    def get_one_example(lines, i, include_answer):
        ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
        if include_answer:
            ret += lines[i]["endings"][lines[i]["label"]]
        return ret

    def get_few_shot_examples(lines, k):
        return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))

    # Read data
    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = 200
    num_shots = 20
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    questions = []
    choices = []
    labels = []
    for i, line in enumerate(lines[:num_questions]):
        questions.append(get_one_example(lines, i, False))
        choices.append(line["endings"])
        labels.append(line["label"])
    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]

    # The module-level ``sgl`` import is used; re-importing here was redundant.
    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    def run_and_score(generator_style):
        """Run the batch once; return (accuracy, wall-clock latency in seconds)."""
        tic = time.perf_counter()
        rets = few_shot_hellaswag.run_batch(
            arguments,
            temperature=0,
            num_threads=64,
            progress_bar=True,
            generator_style=generator_style,
        )
        # With generator_style=True the results are produced while iterating,
        # so predictions are collected before the clock is stopped.
        preds = [choices[i].index(ret["answer"]) for i, ret in enumerate(rets)]
        latency = time.perf_counter() - tic
        accuracy = np.mean(np.array(preds) == np.array(labels))
        return accuracy, latency

    accuracy, latency = run_and_score(generator_style=False)
    accuracy_gen, latency_gen = run_and_score(generator_style=True)

    print(f"{accuracy=}, {accuracy_gen=}")
    assert np.abs(accuracy_gen - accuracy) < 0.1
    assert np.abs(latency_gen - latency) < 1

    return accuracy, latency
|
|
|
|
|
|
def test_gen_min_new_tokens():
    """
    Validate sgl.gen(min_tokens) functionality.

    The test asks a question where, without a min_tokens constraint, the generated answer is expected to be short.
    By enforcing the min_tokens parameter, we ensure the generated answer has at least the specified number of tokens.
    We verify that the number of tokens in the answer is >= the min_tokens threshold.
    """
    # ``sgl`` comes from the module-level import; only the tokenizer helper is local.
    from sglang.srt.hf_transformers_utils import get_tokenizer

    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    MIN_TOKENS, MAX_TOKENS = 64, 128

    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    def assert_min_tokens(tokenizer, text):
        # Count only the generated text: special tokens (e.g. BOS) added by a
        # default encode() would inflate the count and let a short answer pass.
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        assert (
            len(token_ids) >= MIN_TOKENS
        ), f"Generated {len(token_ids)} tokens, min required: {MIN_TOKENS}. Text: {text}"

    tokenizer = get_tokenizer(model_path)

    state = convo_1.run()
    assert_min_tokens(tokenizer, state["answer"])
|