# Source: sglang 0.4.5.post1, python/sglang/test/test_programs.py
# (original file: 576 lines, 18 KiB, Python)

"""This file contains the SGL programs used for unit testing."""
import json
import re
import time
import numpy as np
import sglang as sgl
from sglang.utils import download_and_cache_file, read_jsonl
def test_few_shot_qa():
    """Few-shot capital-city QA, checked through both ``run`` and ``run_batch``."""

    @sgl.function
    def few_shot_qa(s, question):
        s += "The following are questions with answers.\n\n"
        # Fixed demonstrations; each pair is appended as a "Q: ...\n" / "A: ...\n" line.
        for demo_q, demo_a in (
            ("What is the capital of France?", "Paris"),
            ("What is the capital of Germany?", "Berlin"),
            ("What is the capital of Italy?", "Rome"),
        ):
            s += "Q: " + demo_q + "\n"
            s += "A: " + demo_a + "\n"
        s += "Q: " + question + "\n"
        s += "A:" + sgl.gen("answer", stop="\n", temperature=0)

    single = few_shot_qa.run(question="What is the capital of the United States?")
    assert "washington" in single["answer"].strip().lower(), f"answer: {single['answer']}"

    batch_questions = (
        "What is the capital of Japan?",
        "What is the capital of the United Kingdom?",
        "What is the capital city of China?",
    )
    batch_states = few_shot_qa.run_batch(
        [{"question": q} for q in batch_questions],
        temperature=0.1,
    )
    answers = [state["answer"].strip().lower() for state in batch_states]
    assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"
def test_mt_bench():
    """Two-turn chat in MT-Bench style, mixing call-style and ``with``-style role APIs."""

    @sgl.function
    def answer_mt_bench(s, question_1, question_2):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question_1)
        s += sgl.assistant(sgl.gen("answer_1"))
        with s.user():
            s += question_2
        with s.assistant():
            s += sgl.gen("answer_2")

    first_turn = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, "
        "highlighting cultural experiences and must-see attractions."
    )
    second_turn = "Rewrite your previous response. Start every sentence with the letter A."

    state = answer_mt_bench.run(
        question_1=first_turn,
        question_2=second_turn,
        temperature=0.7,
        max_new_tokens=64,
    )
    # NOTE(review): 4 vs 5 messages presumably depends on whether the chat
    # template keeps the system message as its own entry — confirm.
    assert len(state.messages()) in [4, 5]
def test_select(check_answer):
    """Check ``sgl.select`` over a fixed True/False/Unknown label set.

    Args:
        check_answer: when truthy, each statement must receive its expected
            label; otherwise the selection only has to be one of the options.
    """
    options = ["True", "False", "Unknown"]

    @sgl.function
    def true_or_false(s, statement):
        s += "Determine whether the statement below is True, False, or Unknown.\n"
        # The one-shot example is labeled True, so its statement must actually
        # be true (and correctly spelled); otherwise it teaches the wrong mapping.
        s += "Statement: The capital of France is Paris.\n"
        s += "Answer: True\n"
        s += "Statement: " + statement + "\n"
        s += "Answer:" + sgl.select("answer", ["True", "False", "Unknown"])

    cases = [
        ("The capital of Germany is Berlin.", "True"),
        ("The capital of Canada is Tokyo.", "False"),
        ("Purple is a better color than green.", "Unknown"),
    ]
    for statement, expected in cases:
        ret = true_or_false.run(statement=statement)
        if check_answer:
            assert ret["answer"] == expected, ret.text()
        else:
            assert ret["answer"] in options
def test_decode_int():
    """Integer-constrained generation via ``sgl.gen_int`` for two well-known facts."""

    @sgl.function
    def decode_int(s):
        for prefix, var_name in (
            ("The number of hours in a day is ", "hours"),
            ("The number of days in a year is ", "days"),
        ):
            s += prefix + sgl.gen_int(var_name) + "\n"

    state = decode_int.run(temperature=0.1)
    for var_name, expected in (("hours", 24), ("days", 365)):
        assert int(state[var_name]) == expected, state.text()
def test_decode_json_regex():
    """Assemble a JSON object field by field, constraining each value with a regex.

    The whole object is captured in the ``json_output`` variable, which must
    parse with ``json.loads``.
    """

    @sgl.function
    def decode_json(s):
        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR

        s += "Generate a JSON object to describe the basic city information of Paris.\n"
        s += "Here are the JSON object:\n"
        # NOTE: we recommend using dtype gen or whole regex string to control the output
        with s.var_scope("json_output"):
            s += "{\n"
            s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
            s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
            s += "}"

    state = decode_json.run(temperature=0.0)
    raw_json = state["json_output"]
    try:
        parsed = json.loads(raw_json)
    except json.decoder.JSONDecodeError:
        # Show the offending text before failing the test.
        print("JSONDecodeError", raw_json)
        raise
    for key, expected_type in (("name", str), ("population", int)):
        assert isinstance(parsed[key], expected_type)
def test_decode_json():
    """Assemble a JSON object using the typed helpers ``gen_string``/``gen_int``/``gen(dtype=...)``."""

    @sgl.function
    def decode_json(s):
        s += "Generate a JSON object to describe the basic city information of Paris.\n"
        with s.var_scope("json_output"):
            s += "{\n"
            s += ' "name": ' + sgl.gen_string() + ",\n"
            s += ' "population": ' + sgl.gen_int() + ",\n"
            s += ' "area": ' + sgl.gen(dtype=int) + ",\n"
            s += ' "country": ' + sgl.gen_string() + ",\n"
            s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
            s += "}"

    state = decode_json.run(max_new_tokens=64)
    raw_json = state["json_output"]
    try:
        parsed = json.loads(raw_json)
    except json.decoder.JSONDecodeError:
        # Show the offending text before failing the test.
        print("JSONDecodeError", raw_json)
        raise
    for key, expected_type in (("name", str), ("population", int)):
        assert isinstance(parsed[key], expected_type)
def test_expert_answer(check_answer=True):
    """Two-step generation where the second prompt reuses the first step's output.

    Args:
        check_answer: when truthy, require "paris" somewhere in the final text.
    """

    @sgl.function
    def expert_answer(s, question):
        s += "Question: " + question + "\n"
        s += "A good person to answer this question is" + sgl.gen("expert", stop=[".", "\n"]) + ".\n"
        # Read back the generated "expert" text and splice it into the next prompt.
        expert = s["expert"]
        s += "For example," + expert + " would answer that " + sgl.gen("answer", stop=".") + "."

    state = expert_answer.run(question="What is the capital of France?", temperature=0.1)
    if not check_answer:
        return
    assert "paris" in state.text().lower(), f"Answer: {state.text()}"
def test_tool_use():
    """Let the model write an expression, evaluate it locally, and append the result."""

    def calculate(expression):
        # NOTE(review): eval() executes model-generated text with full builtins.
        # Tolerable only inside this test harness; do not reuse this pattern
        # anywhere the expression could come from an untrusted source.
        return f"{eval(expression)}"

    @sgl.function
    def tool_use(s, lhs, rhs):
        s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
        s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
        s += "Question: What is the product of " + str(lhs) + " and " + str(rhs) + "?\n"
        s += "Answer: The answer is calculate(" + sgl.gen("expression", stop=")") + ") = "
        # The locally computed value is captured under the "answer" variable.
        with s.var_scope("answer"):
            s += calculate(s["expression"])

    left, right = 257, 983
    state = tool_use(lhs=left, rhs=right, temperature=0)
    assert int(state["answer"]) == left * right
def test_react():
    """ReAct-style loop: up to four Thought/Action rounds, stopping at "Finish".

    Each round selects an action; "Search" produces a query and an observation,
    while any other choice (i.e. "Finish") generates the final ``answer``.
    """

    @sgl.function
    def react(s, question):
        s += """
Question: Which country does the founder of Microsoft live in?
Thought 1: I need to search for the founder of Microsoft.
Action 1: Search [Founder of Microsoft].
Observation 1: The founder of Microsoft is Bill Gates.
Thought 2: I need to search for the country where Bill Gates lives in.
Action 2: Search [Where does Bill Gates live].
Observation 2: Bill Gates lives in the United States.
Thought 3: The answer is the United States.
Action 3: Finish [United States].\n
"""
        s += "Question: " + question + "\n"
        for step in range(1, 5):
            s += f"Thought {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
            action_var = f"action_{step}"
            s += f"Action {step}: " + sgl.select(action_var, ["Search", "Finish"])
            if s[action_var] != "Search":
                s += " [" + sgl.gen("answer", stop="]") + "].\n"
                break
            s += " [" + sgl.gen(stop="]") + "].\n"
            s += f"Observation {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"

    state = react.run(
        question="What country does the creator of Linux live in?",
        temperature=0.1,
    )
    answer = state["answer"].lower()
    assert any(country in answer for country in ("finland", "states"))
def test_parallel_decoding():
    """Skeleton first, then expand every tip in a separate fork and summarize."""
    max_tokens = 64
    fork_size = 5

    @sgl.function
    def parallel_decoding(s, topic):
        s += "Act as a helpful assistant.\n"
        s += "USER: Give some tips for " + topic + ".\n"
        s += f"ASSISTANT: Okay. Here are {fork_size} concise tips, each under 8 words:\n"

        # Skeleton: one short numbered line per tip.
        for number in range(1, fork_size + 1):
            s += f"{number}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"

        # Expand each tip in its own fork of the current state.
        forks = s.fork(fork_size)
        for idx in range(fork_size):
            tip_no = idx + 1
            forks[idx] += f"Now, I expand tip {tip_no} into a detailed paragraph:\nTip {tip_no}:"
            forks[idx] += sgl.gen("detailed_tip", max_tokens, stop=["\n\n"])
        forks.join()

        # Merge the expanded tips back into the main state and summarize.
        s += "Here are these tips with detailed explanation:\n"
        for idx in range(fork_size):
            s += f"Tip {idx + 1}:" + forks[idx]["detailed_tip"] + "\n"
        s += "\nIn summary," + sgl.gen("summary", max_tokens=512)

    state = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
    assert isinstance(state["summary"], str)
def test_parallel_encoding(check_answer=True):
    """Encode three context statements in parallel forks, then answer a question over them.

    Args:
        check_answer: when truthy, require the answer to mention "Noah".
    """
    max_tokens = 64

    @sgl.function
    def parallel_encoding(s, question, context_0, context_1, context_2):
        s += "USER: I will ask a question based on some statements.\n"
        s += "ASSISTANT: Sure. I will give the answer.\n"
        s += "USER: Please memorize these statements.\n"
        contexts = [context_0, context_1, context_2]
        forks = s.fork(len(contexts))
        # Presumably the callable is invoked once per fork with that fork's index.
        forks += lambda i: f"Statement {i}: " + contexts[i] + "\n"
        forks.join(mode="concate_and_append")
        s += "Now, please answer the following question. Do not list options."
        s += "\nQuestion: " + question + "\n"
        s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens)

    state = parallel_encoding.run(
        question="Who is the father of Julian?",
        context_0="Ethan is the father of Liam.",
        context_1="Noah is the father of Julian.",
        context_2="Oliver is the father of Carlos.",
        temperature=0,
    )
    reply = state["answer"]
    if check_answer:
        assert "Noah" in reply
def test_image_qa():
    """Multimodal QA: ask the model to describe ``example_image.png``."""

    @sgl.function
    def image_qa(s, question):
        # NOTE(review): relative path, so it presumably resolves against the
        # current working directory — confirm how the test is launched.
        s += sgl.user(sgl.image("example_image.png") + question)
        s += sgl.assistant(sgl.gen("answer"))

    state = image_qa.run(
        question="Please describe this image in simple words.",
        temperature=0,
        max_new_tokens=64,
    )
    reply = state.messages()[-1]["content"]
    assert "taxi" in reply or "car" in reply, f"{reply}"
def test_stream():
    """Stream a chat program and check that streaming actually yields text.

    Exercises both ``text_iter()`` (whole program text) and
    ``text_iter("answer")`` (only the generated ``answer`` variable). The
    original version collected the chunks but never checked them, so a
    silently empty stream would have passed.
    """
    prompt = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, "
        "highlighting cultural experiences and must-see attractions."
    )

    @sgl.function
    def qa(s, question):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer"))

    # Whole-program stream.
    ret = qa(question=prompt, stream=True)
    out = "".join(ret.text_iter())
    assert out, "text_iter() yielded no text"

    # Stream restricted to the "answer" variable.
    ret = qa(question=prompt, stream=True)
    out = "".join(ret.text_iter("answer"))
    assert out, "text_iter('answer') yielded no text"
def test_regex():
    """Regex-constrained generation of a dotted-quad IPv4 address.

    The answer must fully match the IPv4 pattern (four 0-255 octets separated
    by literal dots).
    """
    # Octet: 250-255 | 200-249 | 0-199. The separator is escaped: a bare "."
    # would let constrained decoding emit any character between octets.
    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

    @sgl.function
    def regex_gen(s):
        s += "Q: What is the IP address of the Google DNS servers?\n"
        s += "A: " + sgl.gen(
            "answer",
            temperature=0,
            regex=regex,
        )

    state = regex_gen.run()
    answer = state["answer"]
    # fullmatch, not match: the output must be exactly one address, not merely start with one.
    assert re.fullmatch(regex, answer), answer
def test_dtype_gen():
    """Generate str/int/float/bool values via ``sgl.gen(dtype=...)`` and parse them.

    The int and float answers must parse with ``int``/``float``. The bool
    answer must be the literal text "True" or "False". On any parse failure
    the whole program state is printed before re-raising, to aid debugging.
    """

    @sgl.function
    def dtype_gen(s):
        s += "Q: What is the full name of DNS?\n"
        s += "A: The full name is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
        s += "Q: Which year was DNS invented?\n"
        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
        s += "Q: What is the value of pi?\n"
        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
        s += "Q: Is the sky blue?\n"
        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"

    state = dtype_gen.run()
    try:
        state["int_res"] = int(state["int_res"])
        state["float_res"] = float(state["float_res"])
        # bool() of any non-empty string is True (bool("False") is True), so it
        # validated nothing; parse the literal explicitly instead.
        bool_text = state["bool_res"].strip()
        if bool_text not in ("True", "False"):
            raise ValueError(f"bool_res is not 'True' or 'False': {bool_text!r}")
        state["bool_res"] = bool_text == "True"
    except ValueError:
        print(state)
        raise
def test_completion_speculative():
    """API speculative execution should consume fewer prompt tokens than plain execution.

    The same completion program is wrapped twice, once with
    ``num_api_spec_tokens=64`` and once without. Each is run once, and the
    backend's prompt-token counter is compared.
    """

    def character_program(s):
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    # Identical body, compiled with and without API speculative execution.
    gen_character_spec = sgl.function(num_api_spec_tokens=64)(character_program)
    gen_character_no_spec = sgl.function(character_program)

    token_usage = sgl.global_config.default_backend.token_usage

    def prompt_tokens_for(program):
        # Reset the shared counter, run the program to completion, read the total.
        token_usage.reset()
        program().sync()
        return token_usage.prompt_tokens

    usage_with_spec = prompt_tokens_for(gen_character_spec)
    usage_with_no_spec = prompt_tokens_for(gen_character_no_spec)
    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative():
    """Smoke test: a chat-formatted program with ``num_api_spec_tokens=256`` runs to completion."""

    @sgl.function(num_api_spec_tokens=256)
    def gen_character_spec(s):
        example_reply = "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant(example_reply)
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        generated_fields = (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
            + "\nJob:"
            + sgl.gen("job", stop="\n")
        )
        s += sgl.assistant(generated_fields)

    # No output is asserted; the test only checks that execution completes.
    gen_character_spec().sync()
def test_hellaswag_select():
    """Benchmark the accuracy of sgl.select on the HellaSwag dataset.

    Runs the first 200 HellaSwag validation items, each prefixed with 20
    few-shot examples, through ``run_batch`` twice: once returning a list
    and once in generator style. The two modes must agree on accuracy
    (within 0.1) and wall-clock latency (within 1 s).

    Returns:
        (accuracy, latency) of the non-generator run.
    """

    def get_one_example(lines, i, include_answer):
        ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
        if include_answer:
            ret += lines[i]["endings"][lines[i]["label"]]
        return ret

    def get_few_shot_examples(lines, k):
        return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))

    # Read data
    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = 200
    num_shots = 20
    # NOTE(review): the few-shot examples are lines[:num_shots], which are also
    # the first num_shots evaluation questions, so those are answered in-context.
    few_shot_examples = get_few_shot_examples(lines, num_shots)
    eval_lines = lines[:num_questions]
    questions = [get_one_example(lines, i, False) for i in range(len(eval_lines))]
    choices = [line["endings"] for line in eval_lines]
    labels = [line["label"] for line in eval_lines]
    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]

    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    def run_and_score(generator_style):
        """Run every question in the given ``run_batch`` mode; return (accuracy, latency)."""
        tic = time.time()
        rets = few_shot_hellaswag.run_batch(
            arguments,
            temperature=0,
            num_threads=64,
            progress_bar=True,
            generator_style=generator_style,
        )
        # Map each selected ending back to its index so it can be compared to the label.
        preds = [choices[i].index(ret["answer"]) for i, ret in enumerate(rets)]
        latency = time.time() - tic
        accuracy = np.mean(np.array(preds) == np.array(labels))
        return accuracy, latency

    accuracy, latency = run_and_score(generator_style=False)
    # Test generator style of run_batch
    accuracy_gen, latency_gen = run_and_score(generator_style=True)

    print(f"{accuracy=}, {accuracy_gen=}")
    assert np.abs(accuracy_gen - accuracy) < 0.1
    assert np.abs(latency_gen - latency) < 1
    return accuracy, latency
def test_gen_min_new_tokens():
    """Check that ``sgl.gen(min_tokens=...)`` forces a long enough answer.

    The question ("What is the capital of the United States?") would normally
    get a short reply. With ``min_tokens=64`` the generated answer, re-encoded
    with the served model's tokenizer, must be at least 64 tokens long.
    """
    from sglang.srt.hf_transformers_utils import get_tokenizer

    MIN_TOKENS, MAX_TOKENS = 64, 128
    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    tokenizer = get_tokenizer(model_path)

    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    text = convo_1.run()["answer"]
    # NOTE(review): encode() may also count special tokens (e.g. BOS) — confirm
    # whether that should be excluded from the length check.
    num_tokens = len(tokenizer.encode(text))
    assert (
        num_tokens >= MIN_TOKENS
    ), f"Generated {num_tokens} tokens, min required: {MIN_TOKENS}. Text: {text}"