# embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/eval_multiturn.py
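"""Evaluate multi-turn conversation perplexity.

Loads a json file of (ShareGPT-style) conversations, applies a chat template,
optionally breaks each conversation into per-turn samples (--breakdown), and
logs perplexity and wall-clock time through a FileLogger, with optional
multi-process evaluation via accelerate.
"""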

import os
import torch
import time
import datasets
from typing import List, Optional
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from torch.utils.data import DataLoader
from dataclasses import dataclass, field, asdict
from functools import partial
from src.data import Data
from src.metrics import Metric
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, evaluate_perplexity, split_file_dir_name_ext, apply_chat_template
logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    eval_data: Optional[str] = field(
        default="long-llm:sharegpt/3-turn.json",
        metadata={'help': 'Evaluation json data.'}
    )
    output_dir: str = field(
        default="data/results/multiturn/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    min_length: int = field(
        default=0,
        metadata={'help': 'How many tokens at minimum for evaluation?'}
    )
    # NOTE: no more than 1536 tokens because gist cannot process more
    max_length: int = field(
        default=100000,
        metadata={'help': 'How many tokens at maximum for evaluation?'}
    )
    num_turn: int = field(
        default=3,
        metadata={'help': 'How many turns?'}
    )
    # whether to break each conversation into per-turn samples (see process_multiturn)
    breakdown: bool = field(
        default=False,
    )
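
# Example invocation (illustrative; HfArgumentParser exposes each Args field
# above as a command-line --flag, and the values below are placeholders):
#
#   python eval_multiturn.py \
#       --eval_data long-llm:sharegpt/3-turn.json \
#       --num_turn 3 \
#       --breakdown \
#       --output_dir data/results/multiturn/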


def process_multiturn(data, indices, tokenizer, chat_template, min_length, max_length, num_turn=None, breakdown=False):
    outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

    if breakdown:
        # score each turn separately over the accumulated conversation prefix
        for i, source in enumerate(data['accum_conversations']):
            # break up the multi-turn conversation
            if num_turn is None:
                num_turn = len(source) // 2
            # skip conversations that do not have enough turns
            if num_turn * 2 > len(source):
                continue

            for j in range(0, 2 * num_turn, 2):
                turn_source = source[j: j + 2]

                encoded = apply_chat_template(
                    chat_template,
                    turn_source,
                    tokenizer=tokenizer,
                    return_labels=True,
                ).encoded

                # skip data whose length does not fall between min_length and max_length
                if min_length is not None and len(encoded["input_ids"]) < min_length:
                    continue
                if max_length is not None and len(encoded["input_ids"]) > max_length:
                    continue

                for k, v in encoded.items():
                    outputs[k].append(v)
                outputs['length'].append(len(encoded['input_ids']))
                # NOTE: the broken-down conversations belong to the same root
                outputs['index'].append(indices[i])

        return outputs

    else:
        for i, source in enumerate(data['conversations']):
            if num_turn is not None:
                source = source[:2 * num_turn]

            encoded = apply_chat_template(
                chat_template,
                source,
                tokenizer=tokenizer,
                return_labels=True,
            ).encoded

            # skip data whose length does not fall between min_length and max_length
            if min_length is not None and len(encoded["input_ids"]) < min_length:
                continue
            if max_length is not None and len(encoded["input_ids"]) > max_length:
                continue

            for k, v in encoded.items():
                outputs[k].append(v)
            outputs['length'].append(len(encoded['input_ids']))
            outputs['index'].append(indices[i])

        return outputs
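
# Illustrative sketch of the two modes above. The message schema is assumed
# from the field names; the actual records come from the ShareGPT json:
#
#   conversation = [
#       {"role": "user", "content": "q1"}, {"role": "assistant", "content": "a1"},
#       {"role": "user", "content": "q2"}, {"role": "assistant", "content": "a2"},
#   ]
#
# breakdown=False: the first 2 * num_turn messages of each record in
#   data['conversations'] are templated and scored as one sample.
# breakdown=True: each consecutive message pair in data['accum_conversations']
#   (presumably a user turn carrying the accumulated context plus the assistant
#   reply, given the field name) becomes its own sample; samples from the same
#   record share one 'index', so per-turn results can be regrouped afterwards.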


@torch.no_grad()
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    with accelerator.main_process_first():
        raw_dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)

        process_fn = partial(
            process_multiturn,
            tokenizer=tokenizer,
            chat_template=args.chat_template,
            max_length=args.max_length,
            min_length=args.min_length,
            num_turn=args.num_turn,
            breakdown=args.breakdown,
        )
        dataset = raw_dataset.map(process_fn, batched=True, num_proc=32, batch_size=10, with_indices=True, remove_columns=raw_dataset.column_names)

    # the labels (the target generation results) were attached by apply_chat_template
    data_collator = DefaultDataCollator(tokenizer=tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        # pin memory only when a GPU is used
        pin_memory=not args.cpu,
    )

    if not args.enable_tp:
        model, dataloader = accelerator.prepare(model, dataloader)
        # NOTE: unwrap because we just use the model for evaluation
        model = accelerator.unwrap_model(model)
    else:
        # NOTE: prepare the dataloader so the data moves to GPU automatically
        dataloader = accelerator.prepare(dataloader)

    accelerator.wait_for_everyone()
    accelerator.print(dataset['index'])

    t1 = time.time()
    perplexity = evaluate_perplexity(model, dataloader, accelerator)
    t2 = time.time()

    t = [t2 - t1]
    if accelerator.num_processes > 1:
        t = accelerator.gather_for_metrics(t)
    t = sum(t)

    metrics = {"perplexity": perplexity, "time": round(t, 4)}
    if accelerator.process_index == 0:
        log_path = os.path.join(args.output_dir, "metrics.log")
        file_logger = FileLogger(makedirs(log_path))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()
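
# Example launches (illustrative; --cpu, --batch_size, and --enable_tp are
# fields of ModelArgs in src, referenced above as args.cpu etc.):
#
#   # single process
#   python eval_multiturn.py --num_turn 3
#
#   # multi-GPU data parallel; per-process wall-clock times are gathered and
#   # summed into the logged "time" metric
#   accelerate launch eval_multiturn.py --num_turn 3 --breakdown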