139 lines
3.3 KiB
Python
139 lines
3.3 KiB
Python
import argparse
|
|
from enum import Enum
|
|
|
|
from pydantic import BaseModel, constr
|
|
|
|
import sglang as sgl
|
|
from sglang.srt.constrained.outlines_backend import build_regex_from_object
|
|
from sglang.test.test_utils import (
|
|
add_common_sglang_args_and_parse,
|
|
select_sglang_backend,
|
|
)
|
|
|
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
|
|
|
ip_jump_forward = (
|
|
r"The google's DNS sever address is "
|
|
+ IP_REGEX
|
|
+ r" and "
|
|
+ IP_REGEX
|
|
+ r". "
|
|
+ r"The google's website domain name is "
|
|
+ r"www\.(\w)+\.(\w)+"
|
|
+ r"."
|
|
)
|
|
|
|
|
|
# fmt: off
|
|
@sgl.function
|
|
def regex_gen(s):
|
|
s += "Q: What is the IP address of the Google DNS servers?\n"
|
|
s += "A: " + sgl.gen(
|
|
"answer",
|
|
max_tokens=128,
|
|
temperature=0,
|
|
regex=ip_jump_forward,
|
|
)
|
|
# fmt: on
|
|
|
|
json_jump_forward = (
|
|
r"""The information about Hogwarts is in the following JSON format\.\n"""
|
|
+ r"""\n\{\n"""
|
|
+ r""" "name": "[\w\d\s]*",\n"""
|
|
+ r""" "country": "[\w\d\s]*",\n"""
|
|
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
|
|
+ r""" "population": [-+]?[0-9]+,\n"""
|
|
+ r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
|
|
+ r"""\}\n"""
|
|
)
|
|
|
|
# fmt: off
|
|
@sgl.function
|
|
def json_gen(s):
|
|
s += sgl.gen(
|
|
"json",
|
|
max_tokens=128,
|
|
temperature=0,
|
|
regex=json_jump_forward,
|
|
)
|
|
# fmt: on
|
|
|
|
|
|
class Weapon(str, Enum):
|
|
sword = "sword"
|
|
axe = "axe"
|
|
mace = "mace"
|
|
spear = "spear"
|
|
bow = "bow"
|
|
crossbow = "crossbow"
|
|
|
|
|
|
class Armor(str, Enum):
|
|
leather = "leather"
|
|
chainmail = "chainmail"
|
|
plate = "plate"
|
|
|
|
|
|
class Character(BaseModel):
|
|
name: constr(max_length=10)
|
|
age: int
|
|
armor: Armor
|
|
weapon: Weapon
|
|
strength: int
|
|
|
|
|
|
@sgl.function
|
|
def character_gen(s):
|
|
s += "Give me a character description who is a wizard.\n"
|
|
s += sgl.gen(
|
|
"character",
|
|
max_tokens=128,
|
|
temperature=0,
|
|
regex=build_regex_from_object(Character),
|
|
)
|
|
|
|
|
|
def main(args):
|
|
# Select backend
|
|
backend = select_sglang_backend(args)
|
|
sgl.set_default_backend(backend)
|
|
|
|
state = regex_gen.run(temperature=0)
|
|
|
|
print("=" * 20, "IP TEST", "=" * 20)
|
|
print(state.text())
|
|
|
|
state = json_gen.run(temperature=0)
|
|
|
|
print("=" * 20, "JSON TEST", "=" * 20)
|
|
print(state.text())
|
|
|
|
state = character_gen.run(temperature=0)
|
|
|
|
print("=" * 20, "CHARACTER TEST", "=" * 20)
|
|
print(state.text())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
args = add_common_sglang_args_and_parse(parser)
|
|
main(args)
|
|
|
|
# ==================== IP TEST ====================
|
|
# Q: What is the IP address of the Google DNS servers?
|
|
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
|
|
# ==================== JSON TEST ====================
|
|
# The information about Hogwarts is in the following JSON format.
|
|
|
|
# {
|
|
# "name": "Hogwarts School of Witchcraft and Wizardry",
|
|
# "country": "Scotland",
|
|
# "latitude": 55.566667,
|
|
# "population": 1000,
|
|
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
|
|
# }
|
|
|
|
# ==================== CHARACTER TEST ====================
|
|
# Give me a character description who is a wizard.
|
|
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
|