embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/src/trainer.py

import os
import math
import random
from dataclasses import asdict
from typing import Dict, List, Optional

import torch
import datasets
from torch.utils.data import Sampler, Dataset
from transformers.trainer import Trainer, is_datasets_available
from transformers.tokenization_utils import BatchEncoding
from transformers.utils import logging

from .modeling_utils import evaluate_generation, evaluate_perplexity

logger = logging.get_logger(__name__)


class ActivationBeaconTrainer(Trainer):
    def __init__(self, *args, model_args, file_logger, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_args = model_args
        self.file_logger = file_logger

    def compute_loss(self, model, inputs, return_outputs=False):
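        """Compute the training loss, dropping auxiliary fields from the batch
        and resetting the beacon memory before each forward pass."""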
if "retrieval_span" in inputs:
self.model.memory._retrieval_span = inputs['retrieval_span'][0]
inputs.pop("retrieval_span")
inputs.pop("length", None)
inputs.pop("index", None)
# NOTE: produce labels on the fly to save disk space
if inputs["labels"][0] is None:
inputs["labels"] = inputs["input_ids"].clone()
# NOTE: reset memory for each individual input
if hasattr(self.model, "memory"):
self.model.memory.reset()
outputs = super().compute_loss(model, inputs, return_outputs)
if hasattr(self.model, "memory") and hasattr(self.model.memory, "_retrieval_span"):
del self.model.memory._retrieval_span
del self.model.memory._retrieval_condensing_ratios
return outputs

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
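        """Return a StrideGroupedSampler when group_by_stride is enabled,
        otherwise fall back to the default Trainer sampler."""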
        # Build the sampler.
        if self.args.group_by_stride is not None:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = self.train_dataset[self.args.length_column_name]
            else:
                lengths = None

            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None

            return StrideGroupedSampler(
                # NOTE: multiply world size to get the total number of training instances across devices
                batch_size=self.args.train_batch_size * self.args.world_size,
                window=self.model.memory.config.beacon_window,
                stride=self.model.memory.config.beacon_stride,
                group=self.args.group_by_stride,
                sort=self.args.sort_by_stride,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )
        else:
            return super()._get_train_sampler()

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
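        """Save the checkpoint through the parent class, then persist
        model_args alongside it."""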
        outputs = super()._save(output_dir, state_dict)
        # NOTE: also save model_args
        self.model_args.save(os.path.join(output_dir, "model_args.json"))
        return outputs

    @torch.no_grad()
    def evaluate(self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval") -> Dict[str, float]:
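        """Evaluate by perplexity or generation, temporarily swapping in the
        evaluation beacon ratio and restoring the training setting afterwards."""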
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        if eval_dataset is None and self.eval_dataset is None:
            return

        if self.args.eval_method == "generation":
            labels = self.eval_dataset["labels"]
            self.eval_dataset = self.eval_dataset.remove_columns(["labels"])

        dataloader = self.get_eval_dataloader()

        self.model.memory.reset()
        # stash the training-time beacon settings so they can be restored after evaluation
        train_beacon_ratio = self.model.memory.beacon_ratio
        train_beacon_ratio_mix = self.model.memory.beacon_ratio_mix
        self.model.memory.set(
            beacon_ratio=self.args.eval_beacon_ratio,
            beacon_ratio_mix=self.args.eval_beacon_ratio_mix,
        )

        model = self.model.eval()

        if self.args.eval_method == "perplexity":
            perplexity = evaluate_perplexity(model, dataloader, accelerator=self.accelerator)
            metrics = {"perplexity": perplexity}
        elif self.args.eval_method == "generation":
            indices, outputs = evaluate_generation(
                model,
                dataloader,
                accelerator=self.accelerator,
                tokenizer=self.tokenizer,
            )
            metrics = self.compute_metrics(outputs, labels, indices=indices)
        else:
            raise NotImplementedError(f"Eval method {self.args.eval_method} is not implemented!")

        self.model.memory.reset()
        self.model.memory.set(
            beacon_ratio=train_beacon_ratio,
            beacon_ratio_mix=train_beacon_ratio_mix,
        )

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_") and key != "epoch":
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        self.log(metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        self._memory_tracker.stop_and_update_metrics(metrics)

        # log to file
        if self.args.process_index == 0:
            self.file_logger.log(
                metrics=metrics,
                Model_Args=asdict(self.model_args),
                Training_Args=asdict(self.args),
                Global_Steps=self.state.global_step,
            )

        return metrics


class StrideGroupedSampler(Sampler):
    """Group training instances into batches by the number of beacon strides
    needed to process them. In strict mode every instance in a batch has
    exactly the same stride count; relaxed mode allows mixed stride counts
    and keeps the trailing partial batch."""

    def __init__(
        self,
        batch_size: int,
        window: int,
        stride: int,
        group: str,
        sort: Optional[str] = None,
        dataset: Optional[Dataset] = None,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset or lengths must be provided.")
        if group is None:
            raise ValueError("Group cannot be None!")

        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if (
                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
                or model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{model_input_name}' key."
                )
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            logger.info(
                "If lengths is a torch.Tensor, StrideGroupedSampler will be slow. Converting lengths to List[int]..."
            )
            lengths = lengths.tolist()
        indices = list(range(len(lengths)))
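        # A sequence is consumed as one initial window of `window` tokens plus
        # `stride`-token steps, so it needs ceil((length - window) / stride) + 1
        # strides. For example (illustrative numbers, not from the original):
        # window=1024, stride=2048, length=5000 -> ceil(3976 / 2048) + 1 = 3.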
        # get the number of strides for each data instance
        num_strides = []
        for length in lengths:
            num_stride = math.ceil((length - window) / stride) + 1
            num_strides.append(num_stride)
        indice_stride_pairs = list(zip(indices, num_strides))
        # NOTE: shuffle the indices in advance, otherwise the randomness may be lost when all num_strides are equal
        random.shuffle(indice_stride_pairs)
        # sort data according to the number of strides
        indice_stride_pairs = sorted(indice_stride_pairs, key=lambda x: x[1])
        # group data instances with the same number of strides into the same batch
        batches = []
        batch = []
        prev_num_stride = None
        for index, num_stride in indice_stride_pairs:
            if num_stride != prev_num_stride:
                if group == "strict":
                    # in strict mode, all instances in a batch must have the same
                    # number of strides; discard any partial batch left over from
                    # the previous stride count
                    batch.clear()
                elif group == "relaxed":
                    pass
                else:
                    raise ValueError(f"Group method {group} must be strict or relaxed!")

            batch.append(index)
            prev_num_stride = num_stride

            if len(batch) == batch_size:
                batches.append((batch.copy(), num_stride))
                batch.clear()

        if len(batch) and group == "relaxed":
            batches.append((batch.copy(), num_stride))

        if sort is None:
            random.shuffle(batches)
        elif sort == "ascend":
            batches = sorted(batches, key=lambda x: x[1])
        elif sort == "descend":
            batches = sorted(batches, key=lambda x: x[1], reverse=True)
        else:
            raise ValueError(f"Sort method {sort} must be None, ascend, or descend!")

        # flatten the batched indices into a single list
        batches = [x[0] for x in batches]
        self.indices = sum(batches, [])

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        return iter(self.indices)
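

if __name__ == "__main__":
    # Minimal sketch (an editorial addition, not part of the original module):
    # demonstrates how StrideGroupedSampler batches indices by stride count.
    # The window/stride/lengths values below are illustrative assumptions.
    sampler = StrideGroupedSampler(
        batch_size=2,
        window=1024,
        stride=2048,
        group="strict",
        lengths=[1000, 3000, 5000, 2500],  # -> 1, 2, 3, 2 strides respectively
    )
    # In strict mode only full batches with identical stride counts survive,
    # so this yields the two indices whose sequences need 2 strides (1 and 3).
    print(list(sampler))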