# Source: embed-bge-m3/FlagEmbedding/research/llm_embedder/evaluation/eval_icl.py
# (original listing: 356 lines, 13 KiB, Python)

import os
import re
import json
import random
import logging
import datasets
import numpy as np
from tqdm import tqdm
from datetime import timedelta
from typing import List, Optional
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict
from collections import defaultdict
from functools import partial
from transformers import DataCollatorWithPadding
from src.lm import LM, LMArgs, GenerationArgs
from src.retrieval import RetrievalArgs
from src.utils.util import makedirs, load_json, FileLogger
from .eval_retrieval import main as retrieval_main
from .icl_utils import flat_options, perplexity_to_choice, compute_scores, _llm_generation_func, _llm_perplexity_func
logger = logging.getLogger(__name__)

# Evaluation tasks grouped into categories. Each task maps to its scoring
# setting: ``method`` is 'perplexity' (options are scored with the LM's
# negative log-likelihoods and mapped to a choice) or 'generation' (free-form
# decoding), and ``metric`` names the score passed to ``compute_scores``.
# The reserved key ``cate_name`` holds the category's display label.
CQA = {
    "arc_c": {'method': 'perplexity', 'metric': 'acc'},
    "arc_e": {'method': 'perplexity', 'metric': 'acc'},
    "natural_questions": {'method': 'generation', 'metric': 'em'},
    "cate_name": 'CQA',
}
Commonsense = {
    "copa": {'method': 'perplexity', 'metric': 'acc'},
    "hellaswag": {'method': 'perplexity', 'metric': 'acc'},
    "piqa": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Commonsense',
}
Coreference = {
    "winogrande": {'method': 'perplexity', 'metric': 'acc'},
    "wsc": {'method': 'perplexity', 'metric': 'acc'},
    "wsc273": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Coreference',
}
Paraphrase = {
    "mrpc": {'method': 'perplexity', 'metric': 'acc'},
    "paws": {'method': 'perplexity', 'metric': 'acc'},
    "qqp": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Paraphrase',
}
NLI = {
    "rte": {'method': 'perplexity', 'metric': 'acc'},
    "snli": {'method': 'perplexity', 'metric': 'acc'},
    "mnli_m": {'method': 'perplexity', 'metric': 'acc'},
    "mnli_mm": {'method': 'perplexity', 'metric': 'acc'},
    "qnli": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'NLI',
}
ReadingComp = {
    "multirc": {'method': 'perplexity', 'metric': 'f1'},
    "openbookqa": {'method': 'perplexity', 'metric': 'acc'},
    "boolq": {'method': 'perplexity', 'metric': 'acc'},
    "squad_v1": {'method': 'generation', 'metric': 'em'},
    'cate_name': 'ReadingComp',
}
Sentiment = {
    "sentiment140": {'method': 'perplexity', 'metric': 'acc'},
    "sst2": {'method': 'perplexity', 'metric': 'acc'},
    "yelp": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Sentiment',
}
Data2Text = {
    "common_gen": {'method': 'generation', 'metric': 'rl'},
    "e2e_nlg": {'method': 'generation', 'metric': 'rl'},
    "dart": {'method': 'generation', 'metric': 'rl'},
    'cate_name': 'Data2Text',
}
Summarize = {
    "aeslc": {'method': 'generation', 'metric': 'rl'},
    "ag_news": {'method': 'perplexity', 'metric': 'acc'},
    "gigaword": {'method': 'generation', 'metric': 'rl'},
    'cate_name': 'Summarize',
}

# All task categories.
TASK_LIST = [CQA, Commonsense, Coreference, Paraphrase, NLI, ReadingComp, Sentiment, Data2Text, Summarize]

# Reverse lookup: task name -> category label. A comprehension keeps the
# iteration temporaries out of the module namespace.
task2cat = {
    task_name: category["cate_name"]
    for category in TASK_LIST
    for task_name in category
    if task_name != "cate_name"
}
@dataclass
class ICLArgs(LMArgs, RetrievalArgs):
    """Command-line arguments for in-context-learning evaluation.

    Adds ICL-specific options on top of the LM (``LMArgs``) and retrieval
    (``RetrievalArgs``) argument sets; parsed by ``HfArgumentParser`` in ``main``.
    """
    # Root directory; main() appends retriever- and model-specific sub-directories.
    output_dir: str = field(
        default="data/results/icl/",
        metadata={'help': 'Directory for saving embeddings and results.'}
    )
    eval_data: str = field(
        default="llm-embedder:icl/icl/test.json",
        metadata={'help': 'Path to the file containing both retrieved keys and answers.'}
    )
    task_names: Optional[List[str]] = field(
        default=None,
        metadata={'help': 'List of tasks to evaluate.'}
    )
    load_prev_result: bool = field(
        default=False,
        metadata={'help': 'Load existing results in output_dir?'}
    )
    # Forwarded as ``max_input_tokens`` to the prompt-building transforms.
    context_max_length: int = field(
        default=1024,
        metadata={'help': 'Maximum number of input tokens fed to the LM per example.'},
    )
    few_shot: int = field(
        default=8,
        metadata={'help': 'How many few shot train samples?'},
    )
    corpus: str = field(
        default="llm-embedder:icl/icl/corpus.json",
        metadata={'help': 'Corpus path for retrieval.'}
    )
    key_template: str = field(
        default="{contents}",
        metadata={'help': 'How to concatenate columns in the corpus to form one key?'}
    )
    # NOTE(review): not read anywhere in this file; kept for CLI compatibility.
    metrics: List[str] = field(
        default_factory=list,
    )
    log_path: str = field(
        default="data/results/icl/icl.log",
        metadata={'help': 'Path to the file for logging.'}
    )
@dataclass
class GenerationArgs(GenerationArgs):
    """Generation settings used by this evaluation script.

    Subclasses the imported ``GenerationArgs`` (the base expression is
    evaluated before this statement rebinds the name) and sets the default
    of ``max_new_tokens`` to 64.
    """
    max_new_tokens: int = field(default=64, metadata={'help': 'Maximum new tokens to generate.'})
# Two or more consecutive ASCII spaces.
_MULTI_SPACE = re.compile(r" {2,}")


def remove_double_space(string):
    """Return ``string`` with every run of two or more spaces collapsed to one.

    Only the ASCII space character is affected; tabs and newlines are kept.
    """
    return _MULTI_SPACE.sub(" ", string)
def _select_same_task_examples(candidate_idxs, passage_dataset, task_name, example_num):
    """Return up to ``example_num`` retrieved demonstrations, preferring the query's task.

    A candidate counts as "same task" when the first four characters of its
    ``task_name`` equal those of ``task_name``. Same-task candidates are kept
    in retrieval order; if there are fewer than ``example_num`` of them, the
    earliest other-task candidates fill the remaining slots.
    """
    same_task = []
    other_task = []
    for x in candidate_idxs:
        icl_e = passage_dataset[int(x)]
        if icl_e['task_name'][:4] == task_name[:4]:
            same_task.append(remove_double_space(icl_e['contents']))
            if len(same_task) >= example_num:
                break
        elif len(other_task) < example_num:
            other_task.append(remove_double_space(icl_e['contents']))
    return (same_task + other_task)[:example_num]


def _assign_same_task_random(task_data, passage_dataset, example_num):
    """Replace every example's ``few_shot`` with random corpus rows of the same task.

    Queries of ``mnli_m``/``mnli_mm`` draw from the corpus task ``mnli``. The
    sample size is clamped per task, so a task with few (or no) corpus rows
    does not reduce the number of demonstrations used for other tasks.
    """
    task_name_2_idx = defaultdict(list)
    for i, name in enumerate(tqdm(passage_dataset["task_name"], desc="Collecting Task Indices")):
        task_name_2_idx[name].append(i)

    for task_name, task_examples in tqdm(task_data.items(), desc="Collecting Same-Task-Random Examples"):
        corpus_task_name = "mnli" if task_name in ("mnli_m", "mnli_mm") else task_name
        task_indices = task_name_2_idx[corpus_task_name]
        # Per-task local: never overwrite ``example_num`` itself, or the clamp
        # would carry over to every task processed afterwards.
        num_samples = min(example_num, len(task_indices))
        for example in task_examples:
            sampled = passage_dataset[random.sample(task_indices, num_samples)]["contents"]
            example["few_shot"] = [remove_double_space(content) for content in sampled]


def load_test_data(knn_inxs,
                   test_data,
                   corpus_data,
                   filter_diff_task: bool = False,
                   example_num=8,
                   same_task_random=False,
                   ):
    """Load evaluation queries grouped by task and attach few-shot demonstrations.

    Args:
        knn_inxs: Per-query sequences of corpus row indices, aligned with the
            rows of ``test_data`` (presumably retriever output), or ``None``
            for no retrieved demonstrations.
        test_data: Path of the JSON evaluation file; rows are read for
            ``query``, ``answers``, ``task_name`` and optionally ``options``.
        corpus_data: Path of the JSON corpus file; rows are read for
            ``contents`` and ``task_name``.
        filter_diff_task: Prefer retrieved demonstrations from the query's own
            task (4-character task-name prefix match), topping up with other
            retrieved demonstrations when there are not enough.
        example_num: Maximum number of demonstrations per query.
        same_task_random: Overwrite the demonstrations with rows sampled at
            random from the corpus entries of the same task.

    Returns:
        ``defaultdict(list)`` mapping task name to a list of dicts with keys
        ``query``, ``answers``, ``few_shot`` and, when present, ``options``.
        Runs of spaces are collapsed in all text fields.
    """
    dataset = datasets.load_dataset('json', data_files=test_data)['train']
    passage_dataset = datasets.load_dataset('json', data_files=corpus_data)['train']

    task_data = defaultdict(list)
    for i, e in enumerate(tqdm(dataset, desc="Organizing Data")):
        if knn_inxs is None:
            few_shot = []
        elif filter_diff_task:
            few_shot = _select_same_task_examples(knn_inxs[i], passage_dataset, e['task_name'], example_num)
        else:
            few_shot = [remove_double_space(passage_dataset[int(x)]['contents']) for x in knn_inxs[i][:example_num]]

        data = {
            "query": remove_double_space(e['query']),
            "answers": [remove_double_space(x) for x in e['answers']],
            "few_shot": few_shot,
        }
        if 'options' in e:
            data['options'] = e['options']
        task_data[e['task_name']].append(data)

    if same_task_random:
        _assign_same_task_random(task_data, passage_dataset, example_num)
    return task_data
def _build_task_dataset(test_data, setting, tokenizer, args):
    """Build the lazily tokenized ``datasets.Dataset`` for one task.

    Perplexity tasks are expanded with ``flat_options`` and encoded by
    ``_llm_perplexity_func``; generation tasks are encoded as-is by
    ``_llm_generation_func``. Both get the same few-shot count, input-token
    budget and llama-instruction flag from ``args``.
    """
    if setting['method'] == 'perplexity':
        dataset = datasets.Dataset.from_list(flat_options(test_data))
        encode_func = _llm_perplexity_func
    else:
        dataset = datasets.Dataset.from_list(test_data)
        encode_func = _llm_generation_func
    dataset.set_transform(
        partial(
            encode_func,
            tokenizer=tokenizer,
            example_num=args.few_shot,
            max_input_tokens=args.context_max_length,
            add_llama_inst=args.add_llama_inst,
        )
    )
    return dataset


def _predict_task(llm, dataloader, test_data, setting, tokenizer, add_llama_inst):
    """Run the LM over ``dataloader`` and return predictions for ``test_data``.

    Perplexity tasks turn the LM's negative log-likelihoods into choices via
    ``perplexity_to_choice``. Generation tasks decode until the tokenizer's EOS
    (when ``add_llama_inst``) or the newline token, and strip the outputs.
    """
    if setting['method'] == 'perplexity':
        nlls = llm.compute_nlls(dataloader)
        return perplexity_to_choice(test_data, nlls)
    if add_llama_inst:
        eos_token_id = tokenizer.eos_token_id
    else:
        # Last id of the encoded newline; presumably some tokenizers prepend a
        # prefix-space token. TODO confirm for the tokenizers in use.
        eos_token_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
    predictions = llm.generate(dataloader, eos_token_id=eos_token_id)
    return [x.strip() for x in predictions]


def main():
    """Evaluate an LM on the ICL task suite with (optionally) retrieved demonstrations.

    Parses ``ICLArgs``/``GenerationArgs``, retrieves demonstrations unless
    ``retrieval_method == "no"``, scores every selected task (reusing cached
    per-task files when ``load_prev_result`` is set), writes one JSONL file
    per task under ``output_dir`` (first line: metrics, then one line per
    example with its ``output``), and logs per-category and overall averages
    to ``log_path`` from process 0.
    """
    parser = HfArgumentParser([ICLArgs, GenerationArgs])
    args, generation_args = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])

    # Namespace results by retriever: the query-encoder path for dense
    # retrieval, otherwise the retrieval method name.
    if args.retrieval_method == "dense":
        output_dir = os.path.join(args.output_dir, args.query_encoder.strip(os.sep).replace(os.sep, "--"))
    else:
        output_dir = os.path.join(args.output_dir, args.retrieval_method)
    args.output_dir = output_dir

    # Retrieval sees the retriever-level output_dir (the LM sub-directory is
    # appended only afterwards).
    if args.retrieval_method != "no":
        _, preds, _ = retrieval_main(args=args, accelerator=accelerator, log=False)
    else:
        preds = None

    llm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator,
        generation_args=asdict(generation_args)
    )
    tokenizer = llm.tokenizer

    args.output_dir = os.path.join(args.output_dir, args.model_name_or_path.strip(os.sep).replace(os.sep, "--"))

    task_data = load_test_data(
        preds,
        test_data=args.eval_data,
        corpus_data=args.corpus,
        example_num=args.few_shot,
        same_task_random=args.retrieval_method == "same-task-random",
    )

    # Total number of tasks this run reports (cached or freshly evaluated),
    # used in the progress header instead of a hard-coded count.
    num_tasks = sum(
        1
        for task_cate in TASK_LIST
        for task_name in task_cate
        if task_name != 'cate_name' and (args.task_names is None or task_name in args.task_names)
    )

    all_results = []
    metrics = {}
    for task_cate in TASK_LIST:
        task_results = []
        for task_name, setting in task_cate.items():
            if task_name == 'cate_name':
                continue
            # skip tasks that are not specified
            if args.task_names is not None and task_name not in args.task_names:
                continue

            save_path = os.path.join(args.output_dir, f'{task_name}.json')
            if args.load_prev_result and os.path.exists(save_path):
                # the first line of a saved result file holds the metrics
                result = load_json(save_path, lines=True)[0]
                score = result['metric_value'][setting['metric']]
                task_results.append(score)
                all_results.append(score)
                if accelerator.process_index == 0:
                    logger.info(f"loading existing results from {save_path}...")
                    print(result)
                continue

            test_data = task_data[task_name]
            if accelerator.process_index == 0:
                print(f"------{task_name} ({len(all_results) + 1} / {num_tasks})------")

            # accuracy is only configured for option-scoring (perplexity) tasks
            if setting['metric'] == 'acc':
                assert setting['method'] == 'perplexity'

            dataset = _build_task_dataset(test_data, setting, tokenizer, args)
            dataloader = DataLoader(
                dataset,
                batch_size=args.lm_batch_size,
                collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
                pin_memory=True,
            )
            dataloader = accelerator.prepare(dataloader)

            predictions = _predict_task(llm, dataloader, test_data, setting, tokenizer, args.add_llama_inst)

            # 'em' receives every reference answer; other metrics only the first.
            if setting['metric'] in ['em']:
                labels = [x['answers'] for x in test_data]
            else:
                labels = [x['answers'][0] for x in test_data]

            metric_value = compute_scores(setting['metric'], predictions, labels)
            result = {'task_name': task_name, 'setting': setting, 'metric_value': metric_value}
            if accelerator.process_index == 0:
                print(result)
                with open(makedirs(save_path), 'w') as f:
                    f.write(json.dumps(result, ensure_ascii=False) + "\n")
                    for i, sample in enumerate(test_data):
                        sample["output"] = predictions[i]
                        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

            score = result['metric_value'][setting['metric']]
            task_results.append(score)
            all_results.append(score)

        if task_results:
            metrics[task_cate['cate_name']] = np.mean(task_results)

    # np.mean([]) would log NaN (with a RuntimeWarning) when nothing ran.
    if all_results:
        metrics['avg'] = np.mean(all_results)

    # Only process 0 logs, so only it needs the logger (and its directory).
    if accelerator.process_index == 0:
        file_logger = FileLogger(makedirs(args.log_path))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()