# sglang0.4.5.post1/python/sglang/test/simple_eval_common.py
#
# 467 lines
# 12 KiB
# Python

# Adapted from https://github.com/openai/simple-evals/
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Optional, Tuple
import httpx
import jinja2
import numpy as np
import openai
import requests
from openai import OpenAI
from tqdm import tqdm
# Generic assistant system prompt.
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
# ChatGPT-style system prompt that also pins a knowledge cutoff and a
# fixed "current date" (adjacent literals are joined at compile time).
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    "\nKnowledge cutoff: 2023-12"
    "\nCurrent date: 2024-04-01"
)
# One chat message, e.g. {"role": "user", "content": "..."}.
Message = Dict[str, Any]
# An ordered chat transcript.
MessageList = List[Message]
class SamplerBase:
    """
    Interface for a model that turns a chat transcript into a reply.

    Instances serve both as the model under evaluation and as helper
    models used while grading. Subclasses must implement ``__call__``.
    """

    def __call__(self, message_list: MessageList) -> str:
        """Return the reply text for ``message_list``."""
        raise NotImplementedError()
@dataclass
class EvalResult:
    """
    Aggregated outcome of an evaluation run, typically over many samples.
    """

    # Headline metric; None when no sample contributed a score.
    score: Optional[float]
    # Additional named metrics (e.g. "name:std" variants).
    metrics: Optional[Dict[str, float]]
    # Per-sample HTML snippets, each expected to be valid HTML.
    htmls: List[str]
    # The sampled conversations, one per evaluated sample.
    convos: List[MessageList]
@dataclass
class SingleEvalResult:
    """
    Outcome of grading a single sample.
    """

    # Score for this sample; None leaves it out of the aggregated "score".
    score: Optional[float]
    # Extra per-sample metrics, aggregated across samples by name.
    metrics: Dict[str, float] = field(default_factory=dict)
    # Optional HTML rendering of this sample for the report.
    html: Optional[str] = None
    # The sampled conversation, when recorded.
    convo: Optional[MessageList] = None
class Eval:
    """
    Interface for a benchmark: calling it with a sampler runs the evaluation.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Evaluate ``sampler`` and return the aggregated result."""
        raise NotImplementedError()
class LargerHttpxClient(httpx.Client):
    """
    httpx client with a 3600-second timeout and a pool of up to 3600
    connections (all of which may be kept alive).

    ChatCompletionSampler passes one of these to its OpenAI client.
    """

    def __init__(self):
        super().__init__(
            timeout=httpx.Timeout(3600),
            limits=httpx.Limits(
                max_keepalive_connections=3600,
                max_connections=3600,
            ),
        )
class ChatCompletionSampler(SamplerBase):
    """
    Sample from an OpenAI-compatible chat completion API.

    Requests are sent through an ``OpenAI`` client backed by a
    ``LargerHttpxClient``. Failures other than ``openai.BadRequestError`` are
    retried indefinitely with exponential backoff capped at ``max_backoff``.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        system_message: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 2048,
        max_backoff: float = 60.0,
    ):
        """
        Args:
            base_url: Endpoint forwarded to ``OpenAI(base_url=...)``.
            model: Model name. When None, the id of the first entry returned by
                ``client.models.list()`` is used (a request made here, at
                construction time).
            system_message: Optional system prompt prepended to every call.
            temperature: Sampling temperature for each request.
            max_tokens: Maximum number of tokens requested per completion.
            max_backoff: Upper bound, in seconds, on the sleep between retries.
        """
        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
        if model is None:
            model = self.client.models.list().data[0].id
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_backoff = max_backoff
        self.image_format = "url"

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        """
        Wrap an already-encoded image as an ``image_url`` content part whose URL
        is a ``data:image/<format>;<encoding>,<image>`` data URL.

        ``fovea`` is accepted but currently unused.
        """
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        """Wrap ``text`` as a ``text`` content part."""
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        """Build a chat message dict with the given role and content."""
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        """
        Send ``message_list`` (prefixed by the system prompt, if set) to the
        chat completion endpoint and return the reply text.

        Returns "" when the server rejects the request with
        ``openai.BadRequestError`` or when the reply has no content. Any other
        exception is printed and the request is retried forever, sleeping
        ``2**trial`` seconds, capped at ``self.max_backoff``, between attempts.
        """
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                # ``content`` can be None; normalize so callers always get a str.
                return response.choices[0].message.content or ""
            # NOTE: MMMU is known to trigger a BadRequestError on one sample;
            # such requests are not retried because they would fail again.
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                # Exponential backoff, capped so a long outage does not turn
                # into sleeps of hours or days between attempts.
                exception_backoff = min(2**trial, self.max_backoff)
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
# Prompt for a four-option (A-D) multiple-choice question. Filled with
# str.format using the keys Question, A, B, C and D (see
# format_multichoice_question); the model is asked to end with "Answer: $LETTER".
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
# Case-insensitive regex; group 1 captures the chosen letter A-D after "Answer:".
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
# Case-insensitive regex; group 1 captures the rest of the line after "Answer:".
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
# Few-shot grading prompt asking a model whether two answer expressions are
# equivalent under only trivial simplification. Filled with %-formatting using
# the keys "expression1" and "expression2" (see check_equality); the model is
# told to reply with only "Yes" or "No". Raw string so LaTeX backslashes in
# the examples are kept verbatim.
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Jinja2 template for one example in the HTML report. Render context:
# prompt_messages (list of messages), next_message (the sampled reply),
# correct_answer, extracted_answer and score. Messages are rendered through
# the message_to_html template global and inserted unescaped via ``| safe``.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
def format_multichoice_question(row):
    """
    Build a multiple-choice prompt from ``row``.

    ``row`` is unpacked as keyword arguments into QUERY_TEMPLATE_MULTICHOICE,
    so it must provide at least the keys Question, A, B, C and D.
    """
    template = QUERY_TEMPLATE_MULTICHOICE
    return template.format(**row)
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """
    Ask ``sampler`` whether ``expr1`` and ``expr2`` are equivalent answers.

    EQUALITY_TEMPLATE is filled with both expressions and sent as a single
    user message. Returns True only when the reply, lower-cased and stripped
    of surrounding whitespace, is exactly "yes".
    """
    substitutions = {"expression1": expr1, "expression2": expr2}
    prompt = EQUALITY_TEMPLATE % substitutions
    reply = sampler([{"content": prompt, "role": "user"}])
    normalized = reply.lower().strip()
    return normalized == "yes"
def _compute_stat(values: list, stat: str):
if stat == "mean":
return np.mean(values)
elif stat == "std":
return np.std(values)
elif stat == "min":
return np.min(values)
elif stat == "max":
return np.max(values)
else:
raise ValueError(f"Unknown {stat =}")
def aggregate_results(
    single_eval_results: List[SingleEvalResult],
    default_stats: Tuple[str] = ("mean", "std"),
    name2stats: Optional[Dict[str, Tuple[str]]] = None,
) -> EvalResult:
    """
    Combine per-sample results into a single EvalResult.

    Every metric name (plus "score", collected only from samples whose score
    is not None) is summarized with the statistics listed in ``name2stats``
    for that name, falling back to ``default_stats``. The "mean" statistic is
    stored under the bare metric name; any other statistic under
    "<name>:<stat>". The summarized "score" becomes the result's top-line
    score (None if no sample had one). HTML snippets and conversations are
    collected from every sample in order.
    """
    stats_by_name = name2stats or {}
    values_by_name = defaultdict(list)
    html_list = []
    convo_list = []
    for result in single_eval_results:
        for metric_name, metric_value in result.metrics.items():
            values_by_name[metric_name].append(metric_value)
        if result.score is not None:
            values_by_name["score"].append(result.score)
        html_list.append(result.html)
        convo_list.append(result.convo)

    summary = {}
    for metric_name, collected in values_by_name.items():
        for stat in stats_by_name.get(metric_name, default_stats):
            key = metric_name if stat == "mean" else f"{metric_name}:{stat}"
            summary[key] = _compute_stat(collected, stat)

    top_line = summary.pop("score", None)
    return EvalResult(
        score=top_line,
        metrics=summary,
        htmls=html_list,
        convos=convo_list,
    )
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
    """
    Apply ``f`` to each element of ``xs`` and return the results in input
    order, showing a tqdm progress bar.

    Work runs on a ThreadPool of ``min(num_threads, len(xs))`` threads, or
    sequentially in the calling thread when the ``debug`` environment
    variable is set. An empty ``xs`` returns ``[]`` immediately, since a
    ThreadPool cannot be created with zero workers.
    """
    # len() rather than truthiness so array-like inputs also work here.
    if len(xs) == 0:
        return []
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    with ThreadPool(min(num_threads, len(xs))) as pool:
        return list(tqdm(pool.imap(f, xs), total=len(xs)))
# Shared Jinja2 environment for the report templates in this module.
# Templates are compiled from strings (BaseLoader), undefined variables raise
# instead of rendering silently (StrictUndefined), and autoescaping is chosen
# by select_autoescape for the "html"/"xml" extensions.
jinja_env = jinja2.Environment(
    autoescape=jinja2.select_autoescape(["html", "xml"]),
    undefined=jinja2.StrictUndefined,
    loader=jinja2.BaseLoader(),
)
# Jinja2 template for one chat message, rendered by message_to_html with
# ``role`` (also used as a CSS class), ``content`` (shown inside <pre>) and an
# optional ``variant`` label.
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
def message_to_html(message: Message) -> str:
    """
    Render one chat message as an HTML ``<div>`` snippet.

    ``message`` must contain "role" and "content"; an optional "variant"
    entry is shown next to the role when present and truthy.
    """
    template = jinja_env.from_string(_message_template)
    role, content = message["role"], message["content"]
    return template.render(
        role=role,
        content=content,
        variant=message.get("variant", None),
    )
# Expose message_to_html to every template compiled from jinja_env
# (HTML_JINJA calls it to render each message).
jinja_env.globals.update(message_to_html=message_to_html)
# Standalone HTML page for an evaluation report. Render context: ``score``
# (shown rounded to 3 decimals), ``metrics`` (the metrics table is emitted
# only when this is non-empty) and ``htmls`` (per-example snippets inserted
# unescaped via ``| safe``). The CSS styles the snippets produced from
# _message_template.
_report_template = """<!DOCTYPE html>
<html>
<head>
<style>
.message {
padding: 8px 16px;
margin-bottom: 8px;
border-radius: 4px;
}
.message.user {
background-color: #B2DFDB;
color: #00695C;
}
.message.assistant {
background-color: #B39DDB;
color: #4527A0;
}
.message.system {
background-color: #EEEEEE;
color: #212121;
}
.role {
font-weight: bold;
margin-bottom: 4px;
}
.variant {
color: #795548;
}
table, th, td {
border: 1px solid black;
}
pre {
white-space: pre-wrap;
}
</style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
<tr>
<td><b>Score</b></td>
<td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
<td>{{ name }}</td>
<td>{{ value }}</td>
</tr>
{% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""
def make_report(eval_result: EvalResult) -> str:
    """
    Render a complete, standalone HTML page for ``eval_result``.

    The page contains the metrics table (when ``eval_result.metrics`` is
    non-empty) followed by every per-example HTML snippet.
    """
    template = jinja_env.from_string(_report_template)
    context = dict(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )
    return template.render(**context)
def make_report_from_example_htmls(htmls: List[str]):
    """
    Render a standalone HTML page containing only the given example snippets.

    ``metrics`` is passed as an empty dict, so no metrics table is emitted.
    """
    report = jinja_env.from_string(_report_template)
    return report.render(htmls=htmls, metrics={}, score=None)
def download_dataset(path, url):
    """
    Stream ``url`` into the local file ``path``, showing a progress bar.

    If anything fails after ``path`` has been opened for writing, the
    partially written file is removed. A truncated file is therefore never
    left at ``path``, where a later existence check could mistake it for a
    complete download. A pre-existing ``path`` is not touched if the request
    fails before the file is opened.

    Raises:
        Exception: when the HTTP request or transfer fails; the underlying
            ``requests.RequestException`` is chained as ``__cause__``.
    """
    print(f"Downloading dataset {path} from {url}")
    file_created = False
    completed = False
    try:
        # The ``with`` block closes the streamed response (releasing the
        # connection) on both success and error.
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            # content-length may be missing; 0 is then passed as the total.
            total_size = int(response.headers.get("content-length", 0))
            block_size = 8192
            with open(path, "wb") as f:
                file_created = True
                with tqdm(
                    desc="Downloading",
                    total=total_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress_bar:
                    for data in response.iter_content(block_size):
                        size = f.write(data)
                        progress_bar.update(size)
        completed = True
    except requests.RequestException as e:
        raise Exception(f"Failed to download dataset: {e}") from e
    finally:
        if file_created and not completed:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup; do not mask the original error.
                pass
    print(f"Dataset downloaded and saved to {path}")
def set_ulimit(target_soft_limit=65535):
    """
    Best-effort raise of the soft RLIMIT_NOFILE (maximum open file descriptors).

    The soft limit is only ever raised, never lowered, and never beyond the
    current hard limit, which an unprivileged process cannot exceed. Failures
    are printed and otherwise ignored.

    Args:
        target_soft_limit: Desired soft limit. Clamped to the hard limit when
            the hard limit is finite.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)
    # RLIM_INFINITY can be negative (e.g. -1 on Linux), so an unlimited soft
    # limit would otherwise compare as "smaller" than the target and be lowered.
    if current_soft == resource.RLIM_INFINITY:
        return
    if current_hard != resource.RLIM_INFINITY:
        target_soft_limit = min(target_soft_limit, current_hard)
    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")