import unittest

import sglang as sgl
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.test.test_utils import CustomTestCase
class TestTracing(CustomTestCase):
    """Tests for sglang's tracing/compilation front-end.

    Most cases only build a trace graph (no live backend needed); the
    multi-function and fork cases additionally execute against the OpenAI
    backend and therefore require network access and an API key.
    """

    def test_few_shot_qa(self):
        """A linear few-shot QA program can be traced."""

        @sgl.function
        def few_shot_qa(s, question):
            s += "The following are questions with answers.\n\n"
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        tracer = few_shot_qa.trace()
        # print(tracer.last_node.print_graph_dfs() + "\n")

    def test_select(self):
        """A program using sgl.select can be traced."""

        @sgl.function
        def capital(s):
            s += "The capital of France is"
            s += sgl.select("capital", ["Paris. ", "London. "])
            s += "It is a city" + sgl.gen("description", stop=".")

        tracer = capital.trace()
        # print(tracer.last_node.print_graph_dfs() + "\n")

    def test_raise_warning(self):
        """Tracing a program with an un-traceable argument raises TypeError."""

        @sgl.function
        def wrong(s, question):
            s += f"I want to ask {question}"

        # `question` has no default or annotation, so trace() presumably
        # cannot fabricate a placeholder value for it and raises.
        # assertRaises replaces the original try/except + bare `assert raised`
        # flag (bare assert is stripped under `python -O`).
        with self.assertRaises(TypeError):
            wrong.trace()

    def test_multi_function(self):
        """One sgl function calling another can be compiled and run."""

        @sgl.function
        def expand(s, tip):
            s += (
                "Please expand the following tip into a detailed paragraph:"
                + tip
                + "\n"
            )
            s += sgl.gen("detailed_tip")

        @sgl.function
        def tip_suggestion(s, topic):
            s += "Here are 2 tips for " + topic + ".\n"

            s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n"
            s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n"

            # Each nested call runs in its own state ("branch").
            branch1 = expand(tip=s["tip_1"])
            branch2 = expand(tip=s["tip_2"])

            s += "Tip 1: " + branch1["detailed_tip"] + "\n"
            s += "Tip 2: " + branch2["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        compiled = tip_suggestion.compile()
        # compiled.print_graph()

        # NOTE(review): everything below requires an OpenAI API key and
        # network access — this test cannot pass offline.
        sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
        state = compiled.run(topic="staying healthy")
        # print(state.text() + "\n")

        states = compiled.run_batch(
            [
                {"topic": "staying healthy"},
                {"topic": "staying happy"},
                {"topic": "earning money"},
            ],
            temperature=0,
        )
        # for s in states:
        #     print(s.text() + "\n")

    def test_role(self):
        """A chat-role program compiles against a backend with a chat template."""

        @sgl.function
        def multi_turn_chat(s):
            s += sgl.user("Who are you?")
            s += sgl.assistant(sgl.gen("answer_1"))
            s += sgl.user("Who created you?")
            s += sgl.assistant(sgl.gen("answer_2"))

        backend = BaseBackend()
        backend.chat_template = get_chat_template("llama-2-chat")

        compiled = multi_turn_chat.compile(backend=backend)
        # compiled.print_graph()

    def test_fork(self):
        """A forked program can be traced and executed."""

        @sgl.function
        def tip_suggestion(s):
            s += (
                "Here are three tips for staying healthy: "
                "1. Balanced Diet; "
                "2. Regular Exercise; "
                "3. Adequate Sleep\n"
            )

            # Each fork is an independent copy of the state, so reusing the
            # same variable name "detailed_tip" across forks is fine.
            forks = s.fork(3)
            for i in range(3):
                forks[i] += f"Now, expand tip {i+1} into a paragraph:\n"
                # Was f"detailed_tip" — f-string with no placeholders (F541);
                # the extraneous prefix is removed, value is identical.
                forks[i] += sgl.gen("detailed_tip")

            s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
            s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
            s += "Tip 3:" + forks[2]["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        tracer = tip_suggestion.trace()
        # print(tracer.last_node.print_graph_dfs())

        # NOTE(review): requires an OpenAI API key and network access.
        a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
        # print(a.text())
# Allow running this test module directly: `python this_file.py`.
if __name__ == "__main__":
    unittest.main()