# embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/eval_generation.py
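"""Evaluate free-form generation.

Prepares the evaluation dataset, runs batched generation under HuggingFace
Accelerate, then (on the main process) computes the requested metrics
(e.g. ROUGE) and logs the results together with the run configuration.
"""
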
import os
import torch
from typing import List, Optional
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from torch.utils.data import DataLoader
from dataclasses import dataclass, field, asdict
from src.data import Data
from src.metrics import Metric
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, evaluate_generation, split_file_dir_name_ext

logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    eval_data: Optional[str] = field(
        default=None,
        metadata={'help': 'Evaluation json data.'}
    )
    output_dir: str = field(
        default="data/results/generation/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )
    min_length: int = field(
        default=0,
        metadata={'help': 'How many tokens at minimum for evaluation?'}
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={'help': 'How many tokens at maximum for evaluation?'}
    )
    seed: int = field(
        default=42,
        metadata={'help': 'Random seed.'}
    )
    max_num: Optional[int] = field(
        default=None,
        metadata={'help': 'Max number of instances to evaluate.'}
    )
    metrics: List[str] = field(
        default_factory=lambda: ["save_result"],
        metadata={'help': 'List of metrics. {rouge, save_result}'}
    )


@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    with accelerator.main_process_first():
        dataset = Data.prepare_eval_data(
            args.eval_data,
            tokenizer=tokenizer,
            max_length=args.max_length,
            min_length=args.min_length,
            chat_template=args.chat_template,
            seed=args.seed,
            max_eval_num=args.max_num,
            cache_dir=args.dataset_cache_dir,
        )
    # get labels (the target generation result) and drop them from the
    # model inputs so the collator does not try to batch them
    labels = dataset["labels"]
    dataset = dataset.remove_columns(["labels"])

    data_collator = DefaultDataCollator(tokenizer=tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        # pin memory only when running on GPU
        pin_memory=not args.cpu,
    )
    # NOTE: prepare the dataloader so batches move to the right device automatically
    dataloader = accelerator.prepare(dataloader)
    save_path = Metric.get_save_path(
        args.eval_data,
        os.path.join(args.output_dir, args.result_dir) if args.result_dir is not None else args.output_dir
    )
    compute_metrics_fn = Metric.get_metric_fn(
        metrics=args.metrics,
        save_path=save_path
    )

    indices, outputs = evaluate_generation(
        model,
        dataloader,
        accelerator=accelerator,
        tokenizer=tokenizer,
    )

    # only the main process computes metrics and writes results
    if accelerator.process_index == 0:
        metrics = compute_metrics_fn(outputs, labels, indices=indices)

        # save the run configuration next to the generation results
        config_save_path = os.path.join(split_file_dir_name_ext(save_path)[0], "config.json")
        args.save(config_save_path)

        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()
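
# A minimal sketch of how this script might be launched. The data path,
# process count, and result_dir values below are illustrative assumptions;
# the flags themselves come from Args above (and ModelArgs):
#
#   accelerate launch --num_processes 2 eval_generation.py \
#       --eval_data data/eval/sample.json \
#       --max_length 4096 \
#       --max_num 100 \
#       --metrics rouge save_result \
#       --result_dir demo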