# Adapted from https://github.com/openai/simple-evals/

import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Optional, Tuple

import httpx
import jinja2
import numpy as np
import openai
import requests
from openai import OpenAI
from tqdm import tqdm

OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)


Message = Dict[str, Any]  # keys: role, content
MessageList = List[Message]


class SamplerBase:
    """
    Base class for defining a sampling model, which can be evaluated
    or used as part of the grading process.
    """

    def __call__(self, message_list: MessageList) -> str:
        raise NotImplementedError()
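
# A minimal sketch of a custom sampler (illustrative; EchoSampler is not part
# of this module). A sampler just maps a MessageList to a response string:
#
#     class EchoSampler(SamplerBase):
#         def __call__(self, message_list: MessageList) -> str:
#             return message_list[-1]["content"]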


@dataclass
class EvalResult:
    """
    Result of running an evaluation (usually consisting of many samples)
    """

    score: Optional[float]  # top-line metric
    metrics: Optional[Dict[str, float]]  # other metrics
    htmls: List[str]  # strings of valid HTML
    convos: List[MessageList]  # sampled conversations


@dataclass
class SingleEvalResult:
    """
    Result of evaluating a single sample
    """

    score: Optional[float]
    metrics: Dict[str, float] = field(default_factory=dict)
    html: Optional[str] = None
    convo: Optional[MessageList] = None  # sampled conversation


class Eval:
    """
    Base class for defining an evaluation.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError()
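
# Illustrative sketch of defining an evaluation (hypothetical class; actual
# evals subclass Eval elsewhere). An Eval runs a sampler over its samples and
# aggregates per-sample results via aggregate_results, defined further down:
#
#     class TrivialEval(Eval):
#         def __call__(self, sampler: SamplerBase) -> EvalResult:
#             response = sampler([{"role": "user", "content": "Say OK"}])
#             single = SingleEvalResult(score=float("OK" in response))
#             return aggregate_results([single])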


class LargerHttpxClient(httpx.Client):
    """An httpx.Client with a long (1 hour) timeout and high connection limits."""

    def __init__(self):
        timeout_config = httpx.Timeout(3600)
        limits = httpx.Limits(
            max_keepalive_connections=3600,
            max_connections=3600,
        )
        super().__init__(timeout=timeout_config, limits=limits)


class ChatCompletionSampler(SamplerBase):
    """
    Sample from OpenAI's chat completion API
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        system_message: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 2048,
    ):
        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())

        # If no model is given, default to the first model served by the endpoint.
        if model is None:
            model = self.client.models.list().data[0].id

        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
            # NOTE: BadRequestError is triggered once for MMMU; return an
            # empty response instead of raising so the run can continue.
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                exception_backoff = 2**trial  # exponential backoff
                print(
                    f"Rate limit exception, retrying trial {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # Other exceptions are retried indefinitely with exponential backoff.
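
# Example usage (a sketch; the endpoint URL is hypothetical and must point to
# an OpenAI-compatible server):
#
#     sampler = ChatCompletionSampler(
#         base_url="http://127.0.0.1:30000/v1",
#         system_message=OPENAI_SYSTEM_MESSAGE_API,
#     )
#     answer = sampler([{"role": "user", "content": "What is 2+2?"}])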


QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()


ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
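
# Example (illustrative): extracting the chosen letter from a model response.
#
#     import re
#     match = re.search(ANSWER_PATTERN_MULTICHOICE, "Some reasoning.\nAnswer: C")
#     extracted = match.group(1) if match else None  # -> "C"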


EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications.

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()


HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)
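
# Example (hypothetical row): the row must provide the "Question", "A", "B",
# "C", and "D" keys consumed by QUERY_TEMPLATE_MULTICHOICE.
#
#     row = {"Question": "What is 2+2?", "A": "3", "B": "4", "C": "5", "D": "6"}
#     prompt = format_multichoice_question(row)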


def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    response = sampler([dict(content=prompt, role="user")])
    return response.lower().strip() == "yes"
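
# Example usage (a sketch; assumes `grader` is a configured sampler serving
# as the equality judge):
#
#     grader = ChatCompletionSampler()
#     is_same = check_equality(grader, "3/2", "1.5")  # -> True if judged "Yes"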


def _compute_stat(values: list, stat: str):
    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    else:
        raise ValueError(f"Unknown {stat=}")


def aggregate_results(
    single_eval_results: List[SingleEvalResult],
    default_stats: Tuple[str, ...] = ("mean", "std"),
    name2stats: Optional[Dict[str, Tuple[str, ...]]] = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )
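
# Sketch of the aggregation behavior (values illustrative): two samples with
# scores 1.0 and 0.0 aggregate to score=0.5 plus a "score:std" metric of 0.5.
#
#     results = [SingleEvalResult(score=1.0), SingleEvalResult(score=0.0)]
#     aggregate_results(results).score  # -> 0.5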


def map_with_progress(f: callable, xs: List[Any], num_threads: int):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.
    """
    # Setting the `debug` environment variable forces serial execution,
    # which makes stack traces easier to read.
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    else:
        with ThreadPool(min(num_threads, len(xs))) as pool:
            return list(tqdm(pool.imap(f, xs), total=len(xs)))
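
# Example (illustrative): square a list of numbers on 4 threads. With the
# `debug` environment variable set, the same call runs serially.
#
#     map_with_progress(lambda x: x * x, [1, 2, 3], num_threads=4)  # -> [1, 4, 9]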


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)
_message_template = """
<div class="message {{ role }}">
    <div class="role">
        {{ role }}
        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
        <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """
    Generate HTML snippet (inside a <div>) for a message.
    """
    return jinja_env.from_string(_message_template).render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant", None),
    )


jinja_env.globals["message_to_html"] = message_to_html
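
# Example (illustrative; output elided):
#
#     message_to_html({"role": "user", "content": "hello"})
#     # -> '<div class="message user">...<pre>hello</pre>...</div>'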


_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message {
                padding: 8px 16px;
                margin-bottom: 8px;
                border-radius: 4px;
            }
            .message.user {
                background-color: #B2DFDB;
                color: #00695C;
            }
            .message.assistant {
                background-color: #B39DDB;
                color: #4527A0;
            }
            .message.system {
                background-color: #EEEEEE;
                color: #212121;
            }
            .role {
                font-weight: bold;
                margin-bottom: 4px;
            }
            .variant {
                color: #795548;
            }
            table, th, td {
                border: 1px solid black;
            }
            pre {
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """
    Create a standalone HTML report from an EvalResult.
    """
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: List[str]):
    """
    Create a standalone HTML report from a list of example HTMLs.
    """
    return jinja_env.from_string(_report_template).render(
        score=None, metrics={}, htmls=htmls
    )


def download_dataset(path, url):
    """Download the file at `url` to `path`, streaming with a progress bar."""
    print(f"Downloading dataset {path} from {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192

        with open(path, "wb") as f, tqdm(
            desc="Downloading",
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                progress_bar.update(size)

        print(f"Dataset downloaded and saved to {path}")
    except requests.RequestException as e:
        raise Exception(f"Failed to download dataset: {e}")
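
# Example usage (the URL and path are hypothetical):
#
#     download_dataset("/tmp/dataset.csv", "https://example.com/dataset.csv")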


def set_ulimit(target_soft_limit=65535):
    """Raise the soft limit on open file descriptors if it is below the target."""
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Failed to set RLIMIT_NOFILE: {e}")
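
# Typically called once at process startup, before opening many concurrent
# HTTP connections (a sketch; 65535 is just the default target):
#
#     set_ulimit()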