sglang0.4.5.post1/scripts/deprecated/test_jump_forward.py

139 lines
3.3 KiB
Python

import argparse
from enum import Enum
from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained.outlines_backend import build_regex_from_object
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# Dotted-quad IPv4 address: four octets, each restricted to 0-255.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Regex that constrains the answer produced by ``regex_gen``: a fixed sentence
# with two IPv4 addresses and a ``www.<word>.<word>`` domain name filled in.
#
# The sentence-ending periods are escaped (``\.``). An unescaped ``.`` is a
# regex wildcard that matches any character, so the pattern would have accepted
# e.g. "!" where a literal period is intended.
#
# NOTE(review): "sever" looks like a typo for "server"; it is left unchanged
# because it is part of the text the model is constrained to emit.
ip_jump_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r"\. "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"\."
)
# fmt: off
@sgl.function
def regex_gen(s):
    """Ask for Google's DNS addresses, constraining the answer with ``ip_jump_forward``.

    Args:
        s: The sglang program state that the prompt and generation are appended to.
    """
    s += "Q: What is the IP address of the Google DNS servers?\n"
    # The generated text is stored under the "answer" key.
    constrained_answer = sgl.gen(
        "answer",
        regex=ip_jump_forward,
        temperature=0,
        max_tokens=128,
    )
    s += "A: " + constrained_answer
# fmt: on
# Regex describing a fixed introductory sentence followed by a JSON-like object
# about Hogwarts. The keys and punctuation are literal; only the values vary.
# Adjacent raw literals are concatenated by the parser into one pattern string.
json_jump_forward = (
    r'The information about Hogwarts is in the following JSON format\.\n'
    r'\n\{\n'
    r' "name": "[\w\d\s]*",\n'
    r' "country": "[\w\d\s]*",\n'
    r' "latitude": [-+]?[0-9]*\.?[0-9]+,\n'
    r' "population": [-+]?[0-9]+,\n'
    r' "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n'
    r'\}\n'
)
# fmt: off
@sgl.function
def json_gen(s):
    """Generate text constrained by the ``json_jump_forward`` regex.

    Args:
        s: The sglang program state that the generation is appended to.
    """
    # No prompt is added; the generated text is stored under the "json" key.
    hogwarts_json = sgl.gen(
        "json",
        regex=json_jump_forward,
        temperature=0,
        max_tokens=128,
    )
    s += hogwarts_json
# fmt: on
# Closed set of weapon names used by the ``Character`` model. Mixing in ``str``
# makes every member compare equal to its plain string value.
# (Documented with comments rather than a docstring so the class object, and
# anything derived from it, stays exactly as before.)
class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
# Closed set of armor names used by the ``Character`` model. Mixing in ``str``
# makes every member compare equal to its plain string value.
# (Documented with comments rather than a docstring so the class object, and
# anything derived from it, stays exactly as before.)
class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
# Pydantic model passed to ``build_regex_from_object`` in ``character_gen``.
# Field order and annotations are kept exactly as written.
# NOTE(review): the regex derived from this model presumably depends on both;
# confirm before reordering fields.
class Character(BaseModel):
    name: constr(max_length=10)  # string limited to 10 characters
    age: int
    armor: Armor  # one of the Armor enum values
    weapon: Weapon  # one of the Weapon enum values
    strength: int
@sgl.function
def character_gen(s):
    """Prompt for a wizard character and constrain the output to the ``Character`` model.

    Args:
        s: The sglang program state that the prompt and generation are appended to.
    """
    s += "Give me a character description who is a wizard.\n"
    # Derive the constraining regex from the pydantic model.
    character_regex = build_regex_from_object(Character)
    s += sgl.gen(
        "character",
        regex=character_regex,
        temperature=0,
        max_tokens=128,
    )
def main(args):
    """Run the three constrained-generation demos and print each resulting text.

    Args:
        args: Parsed command-line arguments, forwarded to ``select_sglang_backend``.
    """
    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Each entry pairs the banner title with the sglang program to run.
    demos = (
        ("IP TEST", regex_gen),
        ("JSON TEST", json_gen),
        ("CHARACTER TEST", character_gen),
    )
    rule = "=" * 20
    for title, program in demos:
        state = program.run(temperature=0)
        print(rule, title, rule)
        print(state.text())
if __name__ == "__main__":
    # Parse the shared sglang test options from the command line, then run
    # all demos with them.
    cli_args = add_common_sglang_args_and_parse(argparse.ArgumentParser())
    main(cli_args)
# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.
# {
# "name": "Hogwarts School of Witchcraft and Wizardry",
# "country": "Scotland",
# "latitude": 55.566667,
# "population": 1000,
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }
# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }