import os
import re
import json
import random
import logging
import datasets
import numpy as np
from tqdm import tqdm
from datetime import timedelta
from typing import List, Optional
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, DataCollatorWithPadding
from dataclasses import dataclass, field, asdict
from collections import defaultdict
from functools import partial

from src.lm import LM, LMArgs, GenerationArgs
from src.retrieval import RetrievalArgs
from src.utils.util import makedirs, load_json, FileLogger
from .eval_retrieval import main as retrieval_main
from .icl_utils import flat_options, perplexity_to_choice, compute_scores, _llm_generation_func, _llm_perplexity_func

logger = logging.getLogger(__name__)

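# Task registry: each category dict maps task name -> evaluation setting.
#   method: "perplexity" ranks the candidate options by LM negative log-likelihood,
#           "generation" decodes a free-form answer from the few-shot prompt.
#   metric: "acc" accuracy, "em" exact match, "f1" F1, "rl" ROUGE-L.
# The extra "cate_name" entry names the category itself.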
CQA = {
    "arc_c": {'method': 'perplexity', 'metric': 'acc'},
    "arc_e": {'method': 'perplexity', 'metric': 'acc'},
    "natural_questions": {'method': 'generation', 'metric': 'em'},
    'cate_name': 'CQA',
}
Commonsense = {
    "copa": {'method': 'perplexity', 'metric': 'acc'},
    "hellaswag": {'method': 'perplexity', 'metric': 'acc'},
    "piqa": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Commonsense',
}
Coreference = {
    "winogrande": {'method': 'perplexity', 'metric': 'acc'},
    "wsc": {'method': 'perplexity', 'metric': 'acc'},
    "wsc273": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Coreference',
}
Paraphrase = {
    "mrpc": {'method': 'perplexity', 'metric': 'acc'},
    "paws": {'method': 'perplexity', 'metric': 'acc'},
    "qqp": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Paraphrase',
}
NLI = {
    "rte": {'method': 'perplexity', 'metric': 'acc'},
    "snli": {'method': 'perplexity', 'metric': 'acc'},
    "mnli_m": {'method': 'perplexity', 'metric': 'acc'},
    "mnli_mm": {'method': 'perplexity', 'metric': 'acc'},
    "qnli": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'NLI',
}
ReadingComp = {
    "multirc": {'method': 'perplexity', 'metric': 'f1'},
    "openbookqa": {'method': 'perplexity', 'metric': 'acc'},
    "boolq": {'method': 'perplexity', 'metric': 'acc'},
    "squad_v1": {'method': 'generation', 'metric': 'em'},
    'cate_name': 'ReadingComp',
}
Sentiment = {
    "sentiment140": {'method': 'perplexity', 'metric': 'acc'},
    "sst2": {'method': 'perplexity', 'metric': 'acc'},
    "yelp": {'method': 'perplexity', 'metric': 'acc'},
    'cate_name': 'Sentiment',
}
Data2Text = {
    "common_gen": {'method': 'generation', 'metric': 'rl'},
    "e2e_nlg": {'method': 'generation', 'metric': 'rl'},
    "dart": {'method': 'generation', 'metric': 'rl'},
    'cate_name': 'Data2Text',
}
Summarize = {
    "aeslc": {'method': 'generation', 'metric': 'rl'},
    "ag_news": {'method': 'perplexity', 'metric': 'acc'},
    "gigaword": {'method': 'generation', 'metric': 'rl'},
    'cate_name': 'Summarize',
}

TASK_LIST = [CQA, Commonsense, Coreference, Paraphrase, NLI, ReadingComp, Sentiment, Data2Text, Summarize]

task2cat = {}
for category in TASK_LIST:
    cat_name = category["cate_name"]
    for key, value in category.items():
        if key == "cate_name":
            continue
        task2cat[key] = cat_name

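# task2cat maps each task name to its category, e.g. task2cat["copa"] == "Commonsense".
# TASK_LIST covers 30 tasks in total, which is where the "/ 30" progress counter
# in main() comes from.
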
@dataclass
class ICLArgs(LMArgs, RetrievalArgs):
    output_dir: str = field(
        default="data/results/icl/",
        metadata={'help': 'Directory for saving embeddings and results.'}
    )
    eval_data: str = field(
        default="llm-embedder:icl/icl/test.json",
        metadata={'help': 'Path to the file containing both retrieved keys and answers.'}
    )
    task_names: Optional[List[str]] = field(
        default=None,
        metadata={'help': 'List of tasks to evaluate.'}
    )
    load_prev_result: bool = field(
        default=False,
        metadata={'help': 'Load existing results in output_dir?'}
    )

    context_max_length: int = field(
        default=1024,
        metadata={'help': 'Maximum number of tokens in the few-shot prompt.'},
    )
    few_shot: int = field(
        default=8,
        metadata={'help': 'How many few-shot demonstrations to prepend to each query?'},
    )

    corpus: str = field(
        default="llm-embedder:icl/icl/corpus.json",
        metadata={'help': 'Corpus path for retrieval.'}
    )
    key_template: str = field(
        default="{contents}",
        metadata={'help': 'How to concatenate columns in the corpus to form one key?'}
    )
    metrics: List[str] = field(
        default_factory=lambda: [],
    )

    log_path: str = field(
        default="data/results/icl/icl.log",
        metadata={'help': 'Path to the file for logging.'}
    )

@dataclass
class GenerationArgs(GenerationArgs):
    max_new_tokens: int = field(
        default=64,
        metadata={'help': 'Maximum new tokens to generate.'}
    )

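# Collapse runs of two or more spaces into a single space,
# e.g. remove_double_space("Answer:  yes") == "Answer: yes".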
def remove_double_space(string):
    return re.sub("[ ]{2,}", " ", string)

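# load_test_data groups the evaluation examples by task and attaches few-shot
# demonstrations to each of them. The returned structure looks like (illustrative):
#   {
#       "copa": [
#           {"query": "...", "answers": ["..."], "few_shot": ["...", ...], "options": [...]},
#           ...
#       ],
#       ...
#   }
# Demonstrations come from the retrieved neighbors in `knn_inxs` when available,
# otherwise they are left empty; `same_task_random=True` overwrites them with
# random examples drawn from the same task in the corpus.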
def load_test_data(knn_inxs,
                   test_data,
                   corpus_data,
                   filter_diff_task: bool = False,
                   example_num=8,
                   same_task_random=False,
                   ):
    dataset = datasets.load_dataset('json', data_files=test_data)['train']
    passage_dataset = datasets.load_dataset('json', data_files=corpus_data)['train']

    task_data = defaultdict(list)
    for i, e in enumerate(tqdm(dataset, desc="Organizing Data")):
        query = remove_double_space(e['query'])
        answers = [remove_double_space(x) for x in e['answers']]
        if knn_inxs is not None:
            if filter_diff_task:
                few_shot = []
                rest_passage = []
                for x in knn_inxs[i]:
                    icl_e = passage_dataset[int(x)]
                    # prefer neighbors from the same task (matched on the name prefix)
                    if icl_e['task_name'][:4] == e['task_name'][:4]:
                        few_shot.append(remove_double_space(icl_e['contents']))
                        if len(few_shot) >= example_num:
                            break
                    else:
                        if len(rest_passage) < example_num:
                            rest_passage.append(remove_double_space(icl_e['contents']))

                # back-fill with cross-task neighbors when there are not enough
                # same-task ones, then cap at example_num demonstrations
                if len(few_shot) < example_num:
                    few_shot.extend(rest_passage)
                few_shot = few_shot[:example_num]

            else:
                # if task2cat[e['task_name']] == 'Coreference':
                #     candidates = random.sample(knn_inxs[i][:20], example_num)
                # else:
                #     candidates = knn_inxs[i][:example_num]
                candidates = knn_inxs[i][:example_num]
                few_shot = [remove_double_space(passage_dataset[int(x)]['contents']) for x in candidates]
        else:
            few_shot = []
        data = {"query": query, "answers": answers, "few_shot": few_shot}
        if 'options' in e:
            data['options'] = e['options']
        task_data[e['task_name']].append(data)

    if same_task_random:
        task_name_2_idx = defaultdict(list)
        for i, example in enumerate(tqdm(passage_dataset, desc="Collecting Task Indices")):
            task_name_2_idx[example["task_name"]].append(i)

        for task_name, task_examples in tqdm(task_data.items(), desc="Collecting Same-Task-Random Examples"):
            if task_name in ["mnli_m", "mnli_mm"]:
                corpus_task_name = "mnli"
            else:
                corpus_task_name = task_name

            for i, _ in enumerate(task_examples):
                task_indices = task_name_2_idx[corpus_task_name]
                example_num = min(example_num, len(task_indices))
                # get examples of the same task
                few_shot = [remove_double_space(content) for content in passage_dataset[random.sample(task_indices, example_num)]["contents"]]
                task_data[task_name][i]["few_shot"] = few_shot

    return task_data

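# main() runs the full in-context learning evaluation:
#   1. optionally retrieve nearest-neighbor demonstrations for every test query
#      via retrieval_main (skipped when args.retrieval_method == "no");
#   2. build few-shot prompts per task with load_test_data (with "same-task-random"
#      the retrieved neighbors are replaced by random same-task examples);
#   3. score each task either by option perplexity or by free-form generation;
#   4. average the per-task metrics within each category and log the summary.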
def main():
    parser = HfArgumentParser([ICLArgs, GenerationArgs])
    args, generation_args = parser.parse_args_into_dataclasses()
    accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])

    if args.retrieval_method == "dense":
        output_dir = os.path.join(args.output_dir, args.query_encoder.strip(os.sep).replace(os.sep, "--"))
    else:
        output_dir = os.path.join(args.output_dir, args.retrieval_method)
    args.output_dir = output_dir

    if args.retrieval_method != "no":
        _, preds, _ = retrieval_main(args=args, accelerator=accelerator, log=False)
    else:
        preds = None

    llm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator,
        generation_args=asdict(generation_args)
    )

    tokenizer = llm.tokenizer

    args.output_dir = os.path.join(args.output_dir, args.model_name_or_path.strip(os.sep).replace(os.sep, "--"))

    task_data = load_test_data(preds, test_data=args.eval_data, corpus_data=args.corpus, example_num=args.few_shot, same_task_random=args.retrieval_method == "same-task-random")

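    # Evaluate every task of every category. Per-task results are cached in
    # {args.output_dir}/{task_name}.json: the first line stores the metric
    # summary, the remaining lines store the per-sample predictions. Only the
    # rank-0 process writes files and prints progress.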
    all_results = []
    metrics = {}
    for task_cate in [CQA, Commonsense, Coreference, Paraphrase, NLI, ReadingComp, Sentiment, Data2Text, Summarize]:
        task_results = []
        for task_name, setting in task_cate.items():
            if task_name == 'cate_name':
                continue
            # skip tasks that are not specified
            if args.task_names is not None and task_name not in args.task_names:
                continue

            save_path = os.path.join(args.output_dir, f'{task_name}.json')

            if args.load_prev_result and os.path.exists(save_path):
                # the first line is the metric
                result = load_json(save_path, lines=True)[0]
                task_results.append(result['metric_value'][setting['metric']])
                all_results.append(result['metric_value'][setting['metric']])
                if accelerator.process_index == 0:
                    logger.info(f"loading existing results from {save_path}...")
                    print(result)
                continue

            test_data = task_data[task_name]
            if accelerator.process_index == 0:
                print(f"------{task_name} ({len(all_results) + 1} / 30)------")

            if setting['metric'] == 'acc':
                assert setting['method'] == 'perplexity'
            if setting['method'] == 'perplexity':
                flat_data = flat_options(test_data)
                dataset = datasets.Dataset.from_list(flat_data)
                dataset.set_transform(
                    partial(
                        _llm_perplexity_func,
                        tokenizer=tokenizer,
                        example_num=args.few_shot,
                        max_input_tokens=args.context_max_length,
                        add_llama_inst=args.add_llama_inst,
                    )
                )
            else:
                dataset = datasets.Dataset.from_list(test_data)
                dataset.set_transform(
                    partial(
                        _llm_generation_func,
                        tokenizer=tokenizer,
                        example_num=args.few_shot,
                        max_input_tokens=args.context_max_length,
                        add_llama_inst=args.add_llama_inst,
                    )
                )

            data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=args.lm_batch_size,
                collate_fn=data_collator,
                pin_memory=True,
            )
            dataloader = accelerator.prepare(dataloader)

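            # Scoring: for perplexity tasks, compute_nlls returns one negative
            # log-likelihood per (query, option) pair and perplexity_to_choice
            # maps them back to a predicted option per query; for generation
            # tasks, decoding stops at EOS when the Llama instruction template
            # is used, otherwise at the first newline token.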
            if setting['method'] == 'perplexity':
                predictions = llm.compute_nlls(dataloader)
                predictions = perplexity_to_choice(test_data, predictions)
            else:
                if args.add_llama_inst:
                    eos_token_id = tokenizer.eos_token_id
                else:
                    eos_token_id = tokenizer.encode("\n", add_special_tokens=False)[-1]

                predictions = llm.generate(dataloader, eos_token_id=eos_token_id)
                predictions = [x.strip() for x in predictions]

            if setting['metric'] in ['em']:
                labels = [x['answers'] for x in test_data]
            else:
                labels = [x['answers'][0] for x in test_data]

            metric_value = compute_scores(setting['metric'], predictions, labels)

            result = {'task_name': task_name, 'setting': setting, 'metric_value': metric_value}
            if accelerator.process_index == 0:
                print(result)
                with open(makedirs(save_path), 'w') as f:
                    f.write(json.dumps(result, ensure_ascii=False) + "\n")
                    for i, sample in enumerate(test_data):
                        sample["output"] = predictions[i]
                        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

            task_results.append(result['metric_value'][setting['metric']])
            all_results.append(result['metric_value'][setting['metric']])

        if len(task_results):
            metrics[task_cate['cate_name']] = np.mean(task_results)

    metrics['avg'] = np.mean(all_results)

    file_logger = FileLogger(makedirs(args.log_path))
    if accelerator.process_index == 0:
        file_logger.log(metrics, Args=asdict(args))

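# Illustrative launch command (a sketch only: the module path is a placeholder and
# the model/retrieval flags come from the inherited LMArgs/RetrievalArgs):
#   python -m <package>.eval_icl \
#       --retrieval_method dense --few_shot 8 --task_names copa rte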
if __name__ == "__main__":
    main()