import argparse
import json
import re
import time

import numpy as np

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text


# Split the haystack into four chunks, expand them in parallel forks (each with
# its own position-id offset), then concatenate the forks back before asking
# the retrieval question.
@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
    s += prefix + "\n"

    contexts = [body_0, body_1, body_2, body_3]
    position_ids_offset = [i * 1000 for i in range(len(contexts))]

    forks = s.fork(len(contexts), position_ids_offset)
    forks += lambda i: contexts[i] + "\n"
    forks.join(mode="concate_and_append")

    s += "\n" + suffix
    s += sgl.gen("answer", max_tokens=16)


def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
    arguments = []
    labels = []
    sum_src_indices = []
    sum_dst_indices = []

    for i in range(len(src_indices)):
        for j in range(len(dst_percents)):
            src_index = src_indices[i]
            dst_percent = dst_percents[j]

            # Keep only queries whose entire link chain fits inside the
            # truncated context (all linked lines and the query itself must
            # come before src_index).
            query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
            query_indices = [
                q
                for q in query_indices
                if all(link <= src_index for link in line_obj["links"][q])
                and q < src_index
            ]
            dst_index = query_indices[
                min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
            ]

            label = line_obj["values"][dst_index]
            body = line_obj["lines"][: src_index + 1]
            suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])

            # Split the body into four roughly equal parts for the forks.
            body_part_len = len(body) // 4

            arguments.append(
                {
                    "prefix": line_obj["prefix"],
                    "body_0": "\n".join(body[:body_part_len]),
                    "body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
                    "body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
                    "body_3": "\n".join(body[3 * body_part_len :]),
                    "suffix": suffix,
                }
            )
            labels.append(label)
            sum_src_indices.append(src_index)
            sum_dst_indices.append(dst_index)

    # Select backend
    backend = select_sglang_backend(args)

    tic = time.time()
    states = line_retrieval.run_batch(
        arguments,
        temperature=0,
        backend=backend,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    corrects = []
    for i in range(len(arguments)):
        output = states[i]["answer"]
        prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
        label = labels[i]

        # Try all numbers in the output: take the first one that matches the
        # label, otherwise the last number found. If there are no numbers,
        # compare the raw output.
        findall = re.findall(r"\d+", output)
        if not findall:
            response_number = output
        else:
            for response_number in findall:
                if response_number == label:
                    break

        correct = response_number == label
        corrects.append(correct)

        # Log results
        summary = (
            f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
            f"Prompt len: {prompt_len}, "
            f"Correct: {correct}, "
            f"Label: {label}, Predicted: {response_number}, "
        )
        print(summary)

    accuracy = np.mean(corrects)
    print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "line_retrieval",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": len(arguments),
            "other": {
                "num_questions": len(arguments),
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


def main(args):
    with open(args.data_path, "r") as fin:
        line_obj = json.load(fin)
    num_hoops = args.num_hoops

    # For each source index, spread the queries evenly over the eligible
    # destination lines via evenly spaced percentiles in [0, 1).
    for src_index in args.src_index:
        src_indices = [src_index]
        num_queries = args.num_queries_per_src
        dst_percents = [i * (1 / num_queries) for i in range(num_queries)]
        eval_model(args, line_obj, num_hoops, src_indices, dst_percents)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
    parser.add_argument("--src-index", type=int, nargs="+", default=[100])
    parser.add_argument("--num-queries-per-src", type=int, default=10)
    parser.add_argument("--num-hoops", type=int, default=1)
    args = add_common_sglang_args_and_parse(parser)
    main(args)