import argparse
import json
import time

import numpy as np
from datasets import load_dataset

# Local helper modules shipped alongside this benchmark script.
import answer_extraction
import eval_utils

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text


@sgl.function
def reasoning_gen(s, question: str):
    # Note: "\\boxed{}" must be escaped; a bare "\b" is a backspace character.
    s += sgl.user(
        question
        + "\nPlease reason step by step, and put your final answer within \\boxed{}."
    )
    s += sgl.assistant(sgl.gen("answer"))


def convert_dataset(path: str, question_key: str, answer_key: str, num_tries: int):
    """Flatten the dataset into parallel question/answer lists, repeating each
    question num_tries consecutive times so per-question statistics can be
    computed later with a simple reshape."""
    raw_dataset = load_dataset(path)
    questions = []
    answers = []
    for data in raw_dataset["train"]:
        question = data[question_key]
        answer = data[answer_key]
        for _ in range(num_tries):
            questions.append({"question": question})
            answers.append({"answer": answer})
    return questions, answers


def main(args):
    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Get dataset
    questions, answers = convert_dataset(
        args.data_path, args.question_key, args.answer_key, args.num_tries
    )

    # Run requests
    tic = time.perf_counter()
    states = reasoning_gen.run_batch(
        questions,
        num_threads=args.parallel,
        progress_bar=True,
        temperature=0.6,
        max_new_tokens=32768,
        top_p=0.95,
    )
    latency = time.perf_counter() - tic

    # Extract results and record outcomes in a list
    outcomes = []
    for i, state in enumerate(states):
        try:
            pred_answer = answer_extraction.extract_math_answer(
                questions[i]["question"], state["answer"], "limo"
            )
            gt_answer = str(answers[i]["answer"])
            pred_answer = (
                pred_answer[-1] if isinstance(pred_answer, list) else pred_answer
            )
            is_correct = 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
        except Exception as e:
            print(f"Error extracting answer: {e}")
            is_correct = 0
        outcomes.append(is_correct)

    # Calculate overall accuracy
    overall_accuracy = np.mean(outcomes)
    print(f"Overall Accuracy: {overall_accuracy}")

    # Calculate mean standard error over questions if num_tries >= 2
    if args.num_tries > 1:
        # Rows are questions, columns are the repeated tries for that question
        # (valid because convert_dataset repeats each question consecutively).
        outcomes_np = np.array(outcomes).reshape(-1, args.num_tries)
        # Sample standard deviation per question (ddof=1)
        std_per_question = np.std(outcomes_np, axis=1, ddof=1)
        # Standard error for each question: std / sqrt(num_tries)
        se_per_question = std_per_question / np.sqrt(args.num_tries)
        mean_se = se_per_question.mean()
        print(f"Mean Standard Error of Accuracy across questions: {mean_se}")
    else:
        mean_se = None
        print("Not enough samples per question to compute standard error.")

    # Calculate output throughput
    num_output_tokens = sum(
        s.get_meta_info("answer")["completion_tokens"] for s in states
    )
    output_throughput = num_output_tokens / latency
    print(f"Output throughput: {output_throughput} token/s")

    # Dump results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "limo",
            "backend": args.backend,
            "latency": round(latency, 3),
            "overall_accuracy": round(overall_accuracy, 3),
            "mean_se_accuracy": round(mean_se, 3) if mean_se is not None else None,
            "num_requests": len(questions),
            "other": {
                "num_questions": len(questions),
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="GAIR/LIMO")
    parser.add_argument("--question-key", type=str, default="question")
    parser.add_argument("--answer-key", type=str, default="answer")
    parser.add_argument("--num-tries", type=int, default=1)
    # add_common_sglang_args_and_parse adds the shared SGLang flags
    # (e.g. --parallel, --backend, --result-file) and returns the parsed args,
    # so a separate parser.parse_args() call is redundant.
    args = add_common_sglang_args_and_parse(parser)
    main(args)
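
# Example invocation (a sketch, not from the source: the filename
# "bench_sglang.py" is illustrative, and it assumes an SGLang server is
# already running at the default endpoint expected by select_sglang_backend;
# flags beyond those defined above come from add_common_sglang_args_and_parse
# and may vary across SGLang versions):
#
#   python bench_sglang.py --data-path GAIR/LIMO --num-tries 4 --parallel 16
#
# With --num-tries > 1, each question is sampled multiple times, so the script
# reports both overall accuracy and the mean per-question standard error.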