# Adapted from https://github.com/openai/simple-evals/

"""
Measuring Massive Multitask Language Understanding
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300
"""

import random
import re
from typing import Optional

import pandas

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    ANSWER_PATTERN_MULTICHOICE,
    HTML_JINJA,
    Eval,
    EvalResult,
    SamplerBase,
    SingleEvalResult,
    format_multichoice_question,
)

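# Map each of the 57 MMLU subjects to one of four macro categories
# (stem, humanities, social_sciences, other) used for per-category metrics.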
subject2category = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}


class MMLUEval(Eval):
    """Evaluate a sampler on MMLU multiple-choice questions loaded from a CSV file."""

    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            # Subsample with a fixed seed so repeated runs evaluate the same subset.
            examples = random.Random(0).sample(examples, num_examples)
        self.examples = examples
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(row), role="user"
                )
            ]
            response_text = sampler(prompt_messages)
            # Pull the model's letter choice out of the response, if present.
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == row["Answer"] else 0.0
            # Render an HTML transcript of the exchange for later inspection.
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            # Attribute the score to the subject's macro category.
            category = subject2category.get(row["Subject"], "other")
            return SingleEvalResult(
                html=html, score=score, metrics={category: score}, convo=convo
            )

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
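
# Example usage (a minimal sketch, not part of the original module). The CSV
# path, sampler arguments, and thread count below are assumptions; any
# SamplerBase implementation from simple_eval_common can be used instead.
#
#     from sglang.test.simple_eval_common import ChatCompletionSampler
#
#     sampler = ChatCompletionSampler(base_url="http://localhost:30000/v1")  # assumed server URL
#     mmlu = MMLUEval("mmlu.csv", num_examples=128, num_threads=8)
#     result = mmlu(sampler)
#     print(result.score, result.metrics)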