# Adapted from https://github.com/openai/simple-evals/
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Optional, Tuple

import httpx
import jinja2
import numpy as np
import openai
import requests
from openai import OpenAI
from tqdm import tqdm

OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)

Message = Dict[str, Any]  # keys: role, content
MessageList = List[Message]


class SamplerBase:
    """
    Base class for defining a sampling model, which can be evaluated
    or used as part of the grading process.
    """

    def __call__(self, message_list: MessageList) -> str:
        raise NotImplementedError()


@dataclass
class EvalResult:
    """
    Result of running an evaluation (usually consisting of many samples).
    """

    score: Optional[float]  # top-line metric
    metrics: Optional[Dict[str, float]]  # other metrics
    htmls: List[str]  # strings of valid HTML
    convos: List[MessageList]  # sampled conversations


@dataclass
class SingleEvalResult:
    """
    Result of evaluating a single sample.
    """

    score: Optional[float]
    metrics: Dict[str, float] = field(default_factory=dict)
    html: Optional[str] = None
    convo: Optional[MessageList] = None  # sampled conversation


class Eval:
    """
    Base class for defining an evaluation.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError()


class LargerHttpxClient(httpx.Client):
    """httpx client with a long timeout and generous connection limits."""

    def __init__(self):
        timeout_config = httpx.Timeout(3600)
        limits = httpx.Limits(
            max_keepalive_connections=3600,
            max_connections=3600,
        )
        super().__init__(timeout=timeout_config, limits=limits)


class ChatCompletionSampler(SamplerBase):
    """
    Sample from OpenAI's chat completion API (or any OpenAI-compatible
    endpoint selected via base_url).
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        system_message: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 2048,
    ):
        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())

        if model is None:
            model = self.client.models.list().data[0].id

        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
            # NOTE: BadRequestError is triggered once for MMMU; please uncomment
            # this handler if you are rerunning MMMU.
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                exception_backoff = 2**trial  # exponential backoff
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # an unknown error should raise an exception
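# Usage sketch (illustrative, not part of the adapted module): the sampler can point
# at any OpenAI-compatible endpoint. The base_url below is a hypothetical local server
# and OPENAI_API_KEY is assumed to be set in the environment.
#
#   sampler = ChatCompletionSampler(
#       base_url="http://localhost:30000/v1",  # hypothetical endpoint
#       system_message=OPENAI_SYSTEM_MESSAGE_API,
#       temperature=0.0,
#       max_tokens=2048,
#   )
#   reply = sampler([{"role": "user", "content": "What is 2 + 2?"}])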
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"

EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK

Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()

HTML_JINJA = """
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
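# Example (sketch, not part of the adapted module): extracting a final answer from a
# sampled response with the regexes above. `response_text` and `correct_answer` are
# hypothetical placeholders for a sampler's output and the dataset label.
#
#   import re
#   match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
#   extracted_answer = match.group(1) if match else None
#   score = 1.0 if extracted_answer == correct_answer else 0.0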
""" def format_multichoice_question(row): return QUERY_TEMPLATE_MULTICHOICE.format(**row) def check_equality(sampler: SamplerBase, expr1: str, expr2: str): prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2} response = sampler([dict(content=prompt, role="user")]) return response.lower().strip() == "yes" def _compute_stat(values: list, stat: str): if stat == "mean": return np.mean(values) elif stat == "std": return np.std(values) elif stat == "min": return np.min(values) elif stat == "max": return np.max(values) else: raise ValueError(f"Unknown {stat =}") def aggregate_results( single_eval_results: List[SingleEvalResult], default_stats: Tuple[str] = ("mean", "std"), name2stats: Optional[Dict[str, Tuple[str]]] = None, ) -> EvalResult: """ Aggregate results from multiple evaluations into a single EvalResult. """ name2stats = name2stats or {} name2values = defaultdict(list) htmls = [] convos = [] for single_eval_result in single_eval_results: for name, value in single_eval_result.metrics.items(): name2values[name].append(value) if single_eval_result.score is not None: name2values["score"].append(single_eval_result.score) htmls.append(single_eval_result.html) convos.append(single_eval_result.convo) final_metrics = {} for name, values in name2values.items(): stats = name2stats.get(name, default_stats) for stat in stats: key = name if stat == "mean" else f"{name}:{stat}" final_metrics[key] = _compute_stat(values, stat) return EvalResult( score=final_metrics.pop("score", None), metrics=final_metrics, htmls=htmls, convos=convos, ) def map_with_progress(f: callable, xs: List[Any], num_threads: int): """ Apply f to each element of xs, using a ThreadPool, and show progress. """ if os.getenv("debug"): return list(map(f, tqdm(xs, total=len(xs)))) else: with ThreadPool(min(num_threads, len(xs))) as pool: return list(tqdm(pool.imap(f, xs), total=len(xs))) jinja_env = jinja2.Environment( loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined, autoescape=jinja2.select_autoescape(["html", "xml"]), ) _message_template = """ """ def message_to_html(message: Message) -> str: """ Generate HTML snippet (inside a| Metric | Value |
<tr>
    <td>Score</td>
    <td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
    <td>{{ name }}</td>
    <td>{{ value }}</td>
</tr>
{% endfor %}
</table>
"""
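# Usage sketch (illustrative, assumed workflow): an Eval subclass typically maps a
# per-sample grading function over its examples, aggregates the per-sample results,
# and renders the metrics table above. `rows` and `grade_row` are hypothetical.
#
#   def grade_row(row) -> SingleEvalResult:
#       ...  # sample from the model, extract and score the answer, build HTML via HTML_JINJA
#
#   results = map_with_progress(grade_row, rows, num_threads=32)
#   eval_result = aggregate_results(results)
#   report_html = jinja_env.from_string(_report_template).render(
#       score=eval_result.score, metrics=eval_result.metrics
#   )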