"""Benchmark constrained JSON generation on non-sglang backends.

Two workloads are provided, selected with ``--mode``:

* ``character`` -- fill in a JSON record for a Harry Potter character whose
  name is read from a text file (one name per line).
* ``city`` -- extract a JSON record for a city from a document read from a
  JSONL file (each record must have a ``"document"`` field).

The backend is chosen from ``args.backend`` (``outlines``, ``guidance`` or,
for the character workload only, ``lmql``).  The measured latency is printed
and appended as one JSON line to ``args.result_file``.
"""

import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import guidance
from tqdm import tqdm

from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl

# There are some FSM bugs with the JSON regex converted from a pydantic model,
# so a hand-written regex string is used instead of:
# regex_string = build_regex_from_object(HarryPoterRole)
#
# NOTE(review): the text under review had its whitespace collapsed; the
# literals below are reproduced exactly as they appeared there.  Confirm the
# indentation widths inside the patterns against a pristine copy of the file.
character_regex = (
    r"""\{\n"""
    + r""" "name": "[\w\d\s]{1,16}",\n"""
    + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r""" "wand": \{\n"""
    + r""" "wood": "[\w\d\s]{1,16}",\n"""
    + r""" "core": "[\w\d\s]{1,16}",\n"""
    + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r""" \},\n"""
    + r""" "alive": "(Alive|Deceased)",\n"""
    + r""" "patronus": "[\w\d\s]{1,16}",\n"""
    + r""" "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)

city_regex = (
    r"""\{\n"""
    + r""" "name": "[\w\d\s]{1,16}",\n"""
    + r""" "country": "[\w\d\s]{1,16}",\n"""
    + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
    + r""" "population": [-+]?[0-9]{1,9},\n"""
    + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
    + r"""\}"""
)

# Data file used for each mode when ``--data-path`` is not given.
_DEFAULT_DATA_PATHS = {
    "character": "dataset.txt",
    "city": "questions.jsonl",
}


# fmt: off
def character_gen(name, generate):
    """Build the character prompt and append a regex-constrained completion.

    Args:
        name: Character name placed at the start of the prompt.
        generate: Callable invoked as
            ``generate(prompt, max_tokens=256, regex=character_regex)``; its
            return value is concatenated onto the prompt.

    Returns:
        The prompt followed by the generated text.
    """
    s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += generate(s, max_tokens=256, regex=character_regex)
    return s
# fmt: on


# fmt: off
def city_gen(document, generate):
    """Build the city-extraction prompt and append a regex-constrained completion.

    Args:
        document: Page text embedded between the "Page begin." and
            "Page end." markers.
        generate: Callable invoked as
            ``generate(prompt, max_tokens=256, regex=city_regex)``; its return
            value is concatenated onto the prompt.

    Returns:
        The prompt followed by the generated text.
    """
    s = "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += generate(s, max_tokens=256, regex=city_regex)
    return s
# fmt: on


# NOTE(review): the line breaks and indentation inside the template below were
# lost in the text under review and have been reconstructed; confirm the exact
# whitespace against a pristine copy, since it is part of the prompt.
@guidance
def character_maker(lm, name):
    """Append the character prompt with constrained generation slots to ``lm``.

    Free-text fields use ``guidance.gen`` with a regex and a token cap;
    enumerated fields use ``guidance.select``.

    Args:
        lm: Guidance model state the template is appended to.
        name: Character name interpolated into the prompt.

    Returns:
        ``lm`` after the template has been appended.
    """
    regex_str_no_quote = r"[\w\d\s]+"
    regex_float = r"[0-9]+\.[0-9]+"
    lm += f"""\
    {name} is a character in Harry Potter. Please fill in the following information about this character.
    {{
        "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
        "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
        "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
        "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
        "wand": {{
            "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
            "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
            "length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
        }},
        "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
        "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
        "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
    }}
    """

    return lm


async def call_generate_lmql(
    prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
):
    """Run one regex-constrained LMQL query and return its answer.

    Args:
        prompt: Text passed to the query as ``question``.
        temperature: Forwarded to the query call.
        max_tokens: Upper bound used in the query's ``len(TOKENS(ANSWER))``
            constraint.
        regex: Pattern used in the query's ``REGEX(ANSWER, regex)`` constraint.
        max_len: Forwarded to the query call.
        model: LMQL model handle; required despite the ``None`` default.
        **kwargs: Forwarded to the query call unchanged.

    Returns:
        The ``ANSWER`` value returned by the LMQL program.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    # An explicit check instead of ``assert`` so it still fires under ``-O``.
    if model is None:
        raise ValueError("call_generate_lmql requires a model")
    import lmql

    # NOTE(review): the line breaks inside this query docstring were lost in
    # the text under review and have been reconstructed -- confirm.
    @lmql.query(model=model)
    async def program(question, max_tokens, regex):
        '''lmql
        """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
        return ANSWER
        '''

    return await program(
        question=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        max_len=max_len,
        regex=regex,
        **kwargs,
    )


# NOTE(review): the line breaks and indentation inside the template below were
# lost in the text under review and have been reconstructed; confirm the exact
# whitespace against a pristine copy, since it is part of the prompt.
@guidance
def city_maker(lm, document):
    """Append the city-extraction prompt with constrained slots to ``lm``.

    Args:
        lm: Guidance model state the template is appended to.
        document: Page text interpolated between the page markers.

    Returns:
        ``lm`` after the template has been appended.
    """
    regex_str_no_quote = r"[\w\d\s]+"
    regex_float = r"[0-9]+\.[0-9]+"
    lm += f"""\
    Please extract the information of a city from the following wikipedia page.
    Page begin.
    {document}
    Page end.
    Here is the name, country, and symbol of the city in JSON format.
    {{
        "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
        "country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
        "latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
        "population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
        "top 3 landmarks": [
            "{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}",
            "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}",
            "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
        ]
    }}
    """

    return lm


def _run_indexed(get_one_answer, num_items, parallel):
    """Call ``get_one_answer(i)`` for each ``i`` in ``range(num_items)``.

    Runs serially when ``parallel == 1`` and on a thread pool of ``parallel``
    workers otherwise, showing a progress bar in both cases.
    """
    if parallel == 1:
        for i in tqdm(range(num_items)):
            get_one_answer(i)
        return

    with ThreadPoolExecutor(parallel) as executor:
        results = executor.map(get_one_answer, range(num_items))
        # Draining the iterator is what surfaces an exception raised in a
        # worker; the values themselves are not needed.
        for _ in tqdm(results, total=num_items):
            pass


def bench_character(args):
    """Run the character workload and time it.

    Reads one character name per line from ``args.data_path``, keeps the first
    ``args.num_jsons`` of them and generates one record per name with the
    backend named by ``args.backend``.

    Args:
        args: Parsed command-line namespace.  Reads ``data_path``,
            ``num_jsons``, ``backend`` and ``parallel``, plus ``model_path``
            and ``n_ctx`` (guidance) or ``model_path``, ``host`` and ``port``
            (lmql).

    Returns:
        A ``(states, latency)`` tuple: the per-item results in input order and
        the elapsed generation time in seconds.

    Raises:
        ValueError: If ``args.backend`` is not one of the supported backends.
    """
    with open(args.data_path, "r", encoding="utf-8") as f:
        arguments = [{"name": line.strip()} for line in f]
    arguments = arguments[: args.num_jsons]

    states = [None] * len(arguments)

    # Select backend
    if args.backend == "outlines":
        call_generate = partial(get_call_generate(args), temperature=0)

        def get_one_answer(i):
            states[i] = character_gen(**arguments[i], generate=call_generate)

    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            args.model_path,
            n_gpu_layers=-1,
            n_ctx=args.n_ctx,
        )

        def get_one_answer(i):
            lm = model + character_maker(**arguments[i])
            states[i] = lm

    elif args.backend == "lmql":
        import asyncio

        import lmql

        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
        call_generate = partial(
            call_generate_lmql,
            model=model,
            max_tokens=256,
            regex=character_regex,
        )

        async def get_one_answer_async(i):
            states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.perf_counter()

    if args.backend != "lmql":
        _run_indexed(get_one_answer, len(arguments), args.parallel)
    else:

        async def run_batch(indices):
            # gather() is awaited inside a coroutine so it always sees a
            # running loop.
            await asyncio.gather(*(get_one_answer_async(i) for i in indices))

        # Create and install the loop explicitly: asyncio.get_event_loop() is
        # deprecated when no loop is running.  The loop is intentionally left
        # open, matching the lifetime the previous implementation gave it.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Requests are issued in batches of ``args.parallel``; each batch must
        # finish before the next one starts.
        for start in tqdm(range(0, len(arguments), args.parallel)):
            stop = min(start + args.parallel, len(arguments))
            loop.run_until_complete(run_batch(range(start, stop)))

    latency = time.perf_counter() - tic

    return states, latency


def bench_city_doc(args):
    """Run the city-extraction workload and time it.

    Reads records from the JSONL file at ``args.data_path``, keeps the
    ``"document"`` field of the first ``args.num_jsons`` records and generates
    one record per document with the backend named by ``args.backend``.

    Args:
        args: Parsed command-line namespace.  Reads ``data_path``,
            ``num_jsons``, ``backend`` and ``parallel``, plus ``model_path``
            and ``n_ctx`` for the guidance backend.

    Returns:
        A ``(states, latency)`` tuple: the per-item results in input order and
        the elapsed generation time in seconds.

    Raises:
        ValueError: If ``args.backend`` is neither ``outlines`` nor
            ``guidance``.
    """
    arguments = [{"document": line["document"]} for line in read_jsonl(args.data_path)]
    arguments = arguments[: args.num_jsons]

    states = [None] * len(arguments)

    # Select backend
    if args.backend == "outlines":
        call_generate = partial(get_call_generate(args), temperature=0)

        def get_one_answer(i):
            states[i] = city_gen(**arguments[i], generate=call_generate)

    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            args.model_path,
            n_gpu_layers=-1,
            n_ctx=args.n_ctx,
        )

        def get_one_answer(i):
            lm = model + city_maker(**arguments[i])
            states[i] = lm

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.perf_counter()
    _run_indexed(get_one_answer, len(arguments), args.parallel)
    latency = time.perf_counter() - tic

    return states, latency


def main(args):
    """Run the selected workload, print its latency and record the result.

    Args:
        args: Parsed command-line namespace.  ``mode`` selects the workload;
            ``data_path`` is used when set, otherwise the per-mode default
            file name is filled in.

    Raises:
        ValueError: If ``args.mode`` is not ``"character"`` or ``"city"``.
    """
    if args.mode == "character":
        bench = bench_character
    elif args.mode == "city":
        bench = bench_city_doc
    else:
        raise ValueError(f"Invalid mode: {args.mode}")

    # Honour an explicit --data-path; only fall back to the per-mode default
    # when none was supplied.
    if getattr(args, "data_path", None) is None:
        args.data_path = _DEFAULT_DATA_PATHS[args.mode]

    states, latency = bench(args)

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_jump_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "mode": args.mode,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str)
    parser.add_argument("--num-jsons", type=int, default=50)
    parser.add_argument(
        "--mode", type=str, default="character", choices=["character", "city"]
    )
    args = add_common_other_args_and_parse(parser)
    main(args)