# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) ZhipuAI, Inc. and its affiliates.
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import re
import requests
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

from evalscope.utils import get_logger
from evalscope.utils.io_utils import jsonl_to_list

logger = get_logger()
"""
|
|
This script is used to evaluate results of predictions for the LongWriter model.
|
|
Refer to https://github.com/THUDM/LongWriter for more details.
|
|
|
|
EvalLength:
|
|
Evaluate the length of the generated responses.
|
|
Metrics:
|
|
score_l: The average score of the length evaluation.
|
|
|
|
EvalQuality:
|
|
Evaluate the quality of the generated responses by using Judge Model.
|
|
Metrics:
|
|
score_q: The average score of the quality evaluation.
|
|
"""
|
|
|
|
|
|
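
# NOTE: Each line of the prediction file (`pred_path`) is expected to be a JSON object.
# The length stage reads the 'length' and 'response_length' fields, while the quality
# stage reads the 'prompt' and 'response' fields (see the inline example in
# `EvalLength.eval` below).
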
class EvalLength:

    EVAL_L = 'eval_length'

    def __init__(self, model: str, pred_path: str, output_dir: str):
        self.model = model
        self.pred_path = pred_path
        self.output_dir = output_dir

        self.model_id_path = self.model.strip(os.sep).replace(os.sep, '__')

    @staticmethod
    def score(x, y):
        # x: required length, y: actual response length.
        # Over-long responses (y > x) are penalized more gently than short ones.
        if y > x:
            return 100 * max(0, 1. - (y / x - 1) / 3)
        else:
            return 100 * max(0, 1. - (x / y - 1) / 2)
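
    # Worked example for the score above (derived from the formula itself, not from
    # upstream documentation): for a required length x=100 and a response length
    # y=103, y > x, so score = 100 * max(0, 1 - (103 / 100 - 1) / 3) = 99.0.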

    def eval(self, dump_res: bool = True):
        # example = {"prompt": "Write an outline for a short 100-word blog post about xxx",
        #            "type": "Community Forum", "length": 100, "response_length": 103,
        #            "response": "I. Introduction A. xxx"}
        predictions = [json.loads(line) for line in open(self.pred_path, encoding='utf-8')]
        x, y, scores = [], [], []

        for pred in tqdm(predictions, total=len(predictions), desc='[Processing eval_l]'):
            x.append(pred['length'])
            y.append(pred['response_length'])
            scores.append(self.score(pred['length'], pred['response_length']))

        avg_score_l = np.mean(scores)
        logger.info(f'Average score of length evaluation: {avg_score_l:.2f}')

        # Dump to output file
        if dump_res:
            output_res_path = f'{self.output_dir}/{self.model_id_path}/{self.EVAL_L}.jsonl'
            with open(output_res_path, 'w', encoding='utf-8') as f:
                f.write(json.dumps({'score_l': avg_score_l, 'scores': scores}, ensure_ascii=False) + '\n')
            logger.info(f'Successfully dumped evaluation results to {output_res_path}')

        return x, y, scores

    def plot(self, x: list, y: list):
        plt = self.plot_img(x, y)
        output_pic_path = f'{self.output_dir}/{self.model_id_path}/eval_length_scatter.png'
        plt.savefig(output_pic_path)
        logger.info(f'Successfully saved scatter plot to {output_pic_path}')

    @staticmethod
    def plot_img(x: list, y: list):
        # set plt size 6x6
        plt.figure(figsize=(6, 6))
        lmt = 25000
        # plot x, y
        plt.scatter(x, y, s=100, c='r', alpha=0.3)
        # plot x=y
        plt.plot([0, lmt], [0, lmt], 'k--')
        plt.xscale('log')
        plt.yscale('log')
        plt.xlim(50, lmt)
        plt.ylim(50, lmt)
        plt.xlabel('Required Length', fontsize=20, fontweight='bold')
        plt.ylabel('Output Length', fontsize=20, fontweight='bold')
        plt.xticks(fontsize=24)
        plt.yticks(fontsize=24)
        plt.tight_layout()

        return plt


class EvalQuality:

    EVAL_Q = 'eval_quality'
    OPENAI_BASE_URL = 'https://api.openai.com/v1/chat/completions'
    DIMS = ['Relevance', 'Accuracy', 'Coherence', 'Clarity', 'Breadth and Depth', 'Reading Experience']

    def __init__(self,
                 model: str,
                 pred_path: str,
                 output_dir: str,
                 prompt_template_path: str,
                 openai_api_key: str = None,
                 openai_api_base: str = OPENAI_BASE_URL,
                 openai_gpt_model: str = 'gpt-4o-2024-05-13',
                 generation_kwargs: dict = None,
                 proc_num: int = 8):

        self.model = model
        self.openai_api_base = openai_api_base
        self.pred_path = pred_path
        self.output_dir = output_dir
        self.proc_num = proc_num
        self.eval_scores = []

        assert os.path.exists(self.pred_path), f'Prediction file not found: {self.pred_path}'

        # Default: temperature=0.5, max_new_tokens=1024, stop=None
        if generation_kwargs is None:
            self.generation_kwargs = {
                'max_new_tokens': 1024,
                'temperature': 0.5,
                'stop': None,
            }
        else:
            self.generation_kwargs = generation_kwargs

        # The template must contain the `$INST$` and `$RESPONSE$` placeholders,
        # which are filled in by `process_data`.
        self.prompt_template: str = open(prompt_template_path, 'r', encoding='utf-8').read()

        self.model_id_path = self.model.strip(os.sep).replace(os.sep, '__')
        self.output_res_path = f'{self.output_dir}/{self.model_id_path}/{self.EVAL_Q}.jsonl'

        self.openai_api_key: str = openai_api_key
        self.openai_gpt_model = openai_gpt_model
        if not self.openai_api_key:
            logger.error('Please set `OPENAI_API_KEY` in the envs when stage `eval_q` is activated!')

    def get_response_gpt4(self, prompt, temperature=0.5, max_new_tokens=1024, stop=None):
        tries = 0
        # Note: the bound below allows a single attempt; raise it to enable retries.
        while tries < 1:
            tries += 1
            try:
                headers = {
                    'Authorization': 'Bearer {}'.format(self.openai_api_key),
                }
                messages = [
                    {'role': 'user', 'content': prompt},
                ]
                resp = requests.post(self.openai_api_base, json={
                    'model': self.openai_gpt_model,
                    'messages': messages,
                    'temperature': temperature,
                    'max_tokens': max_new_tokens,
                    'stop': stop,
                }, headers=headers, timeout=600)
                if resp.status_code != 200:
                    raise Exception(resp.text)
                resp = resp.json()
                logger.info(f'>>gpt resp: {resp}')
                break
            except KeyboardInterrupt as e:
                raise e
            except Exception as e:
                if 'maximum context length' in str(e):
                    raise e
                elif 'triggering' in str(e):
                    return 'Trigger OpenAI\'s content management policy'
                logger.error("Error Occurs: \"%s\" Retry ..." % (str(e)))
        else:
            logger.error('Max tries. Failed.')
            return 'Max tries. Failed.'
        try:
            return resp['choices'][0]['message']['content']
        except Exception:
            return ''

    @staticmethod
    def extract_info(pattern, text):
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1)
        else:
            return None

    def process_data(self, item):
        # for item in tqdm(items, total=len(items), desc=f'Process of eval_q: '):
        prompt = self.prompt_template.replace('$INST$', item['prompt']).replace('$RESPONSE$', item['response'])
        scores = None
        output = self.get_response_gpt4(prompt, **self.generation_kwargs)
        try:
            # The judge output may wrap the JSON scores in a ```json fenced block.
            if '```json' in output:
                output = self.extract_info(r'```json\n(.*?)\n```', output)
            output = output.replace('\n', '')
            scores = json.loads(output)
            for dim in self.DIMS:
                if dim not in scores:
                    logger.warning(f'Cannot find score for dimension: {dim} in scores {scores}.')
                    scores = None
                    break
        except Exception as e:
            logger.error(f'Error occurs during process data: {str(e)}')

        if scores is None:
            logger.error(f'Failed to extract scores for item: {item}')
        else:
            logger.info(f'>>scores: {scores}')
            item['scores'] = scores

        return item
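
    # Illustrative only: the judge model is expected to return a JSON object keyed by
    # the dimensions in DIMS, apparently scored on a 1-5 scale (inferred from the
    # normalization in `eval` below), e.g.
    #   {"Relevance": 4, "Accuracy": 5, "Coherence": 4, "Clarity": 4,
    #    "Breadth and Depth": 3, "Reading Experience": 4}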

    def eval(self):

        data_all = jsonl_to_list(self.pred_path)
        total = len(data_all)
        assert total > 0, f'No data found in prediction file: {self.pred_path}'

        random.shuffle(data_all)

        # Score items concurrently with the judge model, bounded by `proc_num` workers.
        with ThreadPoolExecutor(max_workers=self.proc_num) as executor:
            self.eval_scores = list(executor.map(self.process_data, data_all))

        # self.process_data(items=data)
        logger.info(f'>>self.eval_scores: {self.eval_scores}')

        total_score = dict()
        for dim in self.DIMS:
            # Items without a valid judge result fall back to 3 (the scale midpoint).
            # scores = [float(score[dim]) if dim in score else 3 for score in self.eval_scores]
            scores = [
                float(item['scores'][dim]) if 'scores' in item and dim in item['scores'] else 3
                for item in self.eval_scores
            ]
            total_score[dim] = ((sum(scores) / len(scores)) - 1) * 25
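
        # Worked example of the normalization above (derived from the formula): judge
        # scores on a 1-5 scale map linearly onto 0-100, so an average of 4.2 for a
        # dimension becomes (4.2 - 1) * 25 = 80.0.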
        total_score['total'] = sum(total_score.values()) / len(total_score)
        logger.info(f'Total score of quality evaluation: {total_score["total"]:.2f}')

        output_res_path: str = f'{self.output_dir}/{self.model_id_path}/{self.EVAL_Q}.jsonl'
        with open(output_res_path, 'w', encoding='utf-8') as fout:
            fout.write(json.dumps(total_score, ensure_ascii=False) + '\n')


def run_eval(model: str,
             pred_path: str,
             output_dir: str,
             prompt_template_path: str,
             openai_api_key: str,
             openai_api_base: str,
             openai_gpt_model: str,
             generation_kwargs: dict,
             proc_num: int,
             stage: list,
             ):
    logger.info(f'Got eval stages: {stage}')

    if 'eval_l' in stage:
        logger.info(f'Processing evaluation of length for model: {model}')
        eval_length = EvalLength(model=model,
                                 pred_path=pred_path,
                                 output_dir=output_dir)
        x, y, _ = eval_length.eval()
        eval_length.plot(x, y)
    else:
        logger.warning('*** Skip `eval_l` stage ***')

    if 'eval_q' in stage:
        logger.info(f'Processing evaluation of quality for model: {model}')
        eval_quality = EvalQuality(model=model,
                                   pred_path=pred_path,
                                   output_dir=output_dir,
                                   prompt_template_path=prompt_template_path,
                                   openai_api_key=openai_api_key,
                                   openai_api_base=openai_api_base,
                                   openai_gpt_model=openai_gpt_model,
                                   generation_kwargs=generation_kwargs,
                                   proc_num=proc_num)
        eval_quality.eval()
    else:
        logger.warning('*** Skip `eval_q` stage ***')
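

if __name__ == '__main__':
    # A minimal, hypothetical usage sketch (not part of the original module): the model
    # name and file paths below are placeholders, and a valid OpenAI API key is only
    # required when the `eval_q` stage is requested.
    run_eval(model='LongWriter-glm4-9b',
             pred_path='outputs/longwriter_predictions.jsonl',
             output_dir='outputs',
             prompt_template_path='judge_prompt_template.txt',
             openai_api_key=os.getenv('OPENAI_API_KEY'),
             openai_api_base=EvalQuality.OPENAI_BASE_URL,
             openai_gpt_model='gpt-4o-2024-05-13',
             generation_kwargs=None,
             proc_num=8,
             stage=['eval_l', 'eval_q'])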