# embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/src/args.py

import os
import json
from dataclasses import dataclass, field, asdict
from transformers.training_args import TrainingArguments
from typing import Optional, List, Tuple, Union, Dict


@dataclass
class ModelArgs:
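    """Arguments for loading the model and tokenizer, resolving data paths, and
    configuring evaluation, activation beacons, and generation."""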
    model_cache_dir: str = field(
        default=None,
        metadata={'help': 'Default path to save language models.'}
    )
    dataset_cache_dir: str = field(
        default=None,
        metadata={'help': 'Default path to save huggingface datasets.'}
    )
    data_root: str = field(
        default="/data/long-llm",
        metadata={'help': 'The base directory storing all data used for training and evaluation. If specified, make sure all train_data, eval_data, and corpus are paths relative to data_root!'},
    )
    train_data: Optional[List[str]] = field(
        default=None,
        metadata={'help': 'Training json file or glob to match a list of files.'},
    )
    eval_data: Optional[str] = field(
        default=None,
        metadata={'help': 'Evaluation json file.'},
    )
    model_name_or_path: str = field(
        default='Qwen/Qwen2-7B-Instruct',
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models.'}
    )
    padding_side: str = field(
        default="left",
        metadata={'help': 'Tokenizer padding side.'}
    )
    no_use_fast: bool = field(
        default=False,
        metadata={'help': 'Do not use the fast tokenizer?'}
    )
    access_token: Optional[str] = field(
        default=None,
        metadata={'help': 'Huggingface access token.'}
    )
    attn_impl: Optional[str] = field(
        default="flash_attention_2",
        metadata={'help': 'The implementation of attention.'}
    )
    max_length: int = field(
        default=4096,
        metadata={'help': 'How many tokens at maximum for each input.'},
    )
    chat_template: str = field(
        default="hf",
        metadata={'help': 'Instruction template name in fastchat.'}
    )
    max_position_embeddings: Optional[int] = field(
        default=None,
        metadata={'help': 'Maximum position.'},
    )
    mistral_sliding_window: Optional[int] = field(
        default=None,
        metadata={'help': 'Sliding window size in Mistral models.'},
    )
    rope_theta: Optional[float] = field(
        default=None,
        metadata={'help': 'RoPE base (theta).'},
    )
    rope_method: Optional[str] = field(
        default=None,
        metadata={'help': 'How to scale RoPE? {linear, dynamic, yarn}'},
    )
    rope_factor: float = field(
        default=1.,
        metadata={'help': 'RoPE scaling factor.'},
    )
    lora: Optional[str] = field(
        default=None,
        metadata={'help': 'LoRA ID.'},
    )
    lora_unload: bool = field(
        default=True,
        metadata={'help': 'Merge and unload LoRA?'},
    )
    load_in_4_bit: bool = field(
        default=False,
        metadata={'help': 'Load the model in 4-bit precision?'},
    )
    dtype: str = field(
        default="bf16",
        metadata={'help': 'Data type for embeddings.'}
    )
    device_map: Optional[str] = field(
        default=None,
        metadata={'help': 'Device map for loading the model. Set to auto to load across devices.'}
    )
    batch_size: int = field(
        default=1,
        metadata={'help': 'Evaluation batch size.'},
    )
    cpu: bool = field(
        default=False,
        metadata={'help': 'Run on CPU?'}
    )
    enable_tp: bool = field(
        default=False,
        metadata={'help': 'Use tensor parallelism to wrap the model?'}
    )
    enable_vllm: bool = field(
        default=False,
        metadata={'help': 'Use vLLM?'}
    )
    vllm_mem: float = field(
        default=0.9,
        metadata={'help': 'vLLM maximum GPU memory utilization.'}
    )
    vllm_tp: int = field(
        default=1,
        metadata={'help': 'vLLM tensor parallel degree.'}
    )
    vllm_len: Optional[int] = field(
        default=None,
        metadata={'help': 'vLLM maximum sequence length.'}
    )
    vllm_disable_ar: bool = field(
        default=False,
        metadata={'help': 'Disable custom all-reduce in vLLM?'}
    )
    enable_beacon: bool = field(
        default=False,
        metadata={'help': 'Enable activation beacon?'}
    )
    beacon_window: Optional[int] = field(
        default=None,
        metadata={'help': 'The initial sliding window size.'}
    )
    beacon_stride: Optional[int] = field(
        default=None,
        metadata={'help': 'The stride of the sliding window.'}
    )
    beacon_attn: Optional[str] = field(
        default=None,
        metadata={'help': 'How to assign attention masks of beacon tokens? {segmentation, step-expansion, full-coverage}'}
    )
    beacon_ratio: Optional[List[int]] = field(
        default=None,
        metadata={'help': 'Condensing ratios for beacons.'}
    )
    beacon_ratio_mix: Optional[str] = field(
        default=None,
        metadata={'help': 'How to determine the beacon_ratio for each input. {step-random, instance-random, adapt-x}'}
    )
    beacon_param: Optional[List[str]] = field(
        default=None,
        metadata={'help': 'The introduced parameters for beacon. {q, k, v, o}'}
    )
    beacon_embed_init: str = field(
        default="eos",
        metadata={'help': 'Initialize the beacon embedding from the eos/bos embedding.'}
    )
    beacon_sink_size: Optional[int] = field(
        default=None,
        metadata={'help': 'The number of activations that are always kept at the head of the sequence, following StreamingLLM.'}
    )
    beacon_attend_prev: Optional[bool] = field(
        default=None,
        metadata={'help': 'Can beacon tokens attend to previous beacon tokens?'}
    )
    beacon_pos: Optional[str] = field(
        default=None,
        metadata={'help': 'Where to put beacon tokens? {append, interleave}'}
    )
    beacon_parallel_window: Optional[int] = field(
        default=None,
        metadata={'help': 'How many windows to run in parallel?'}
    )
    max_new_tokens: Optional[int] = field(
        default=None,
        metadata={'help': 'How many tokens at maximum to return?'},
    )
    do_sample: Optional[bool] = field(
        default=None,
        metadata={'help': 'Do sampling when decoding?'},
    )
    temperature: Optional[float] = field(
        default=None,
        metadata={'help': 'Sampling temperature.'},
    )
    top_p: Optional[float] = field(
        default=None,
        metadata={'help': "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation."}
    )

    def resolve_path(self, path):
        """Resolve any path prefixed with 'long-llm:' to a path under data_root."""
        pattern = "long-llm:"
        # resolve relative data paths when necessary
        if isinstance(path, list):
            for i, x in enumerate(path):
                if x.startswith(pattern):
                    path[i] = os.path.join(self.data_root, x.replace(pattern, ""))
        else:
            if path.startswith(pattern):
                path = os.path.join(self.data_root, path.replace(pattern, ""))
        return path

    def get_generation_config(self):
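        """Collect only the generation settings that were explicitly set."""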
        generation_config = {}
        if self.max_new_tokens is not None:
            generation_config["max_new_tokens"] = self.max_new_tokens
        if self.do_sample is not None:
            generation_config["do_sample"] = self.do_sample
        if self.temperature is not None:
            generation_config["temperature"] = self.temperature
        if self.top_p is not None:
            generation_config["top_p"] = self.top_p
        return generation_config

    def to_dict(self):
        return asdict(self)

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f)

    def __post_init__(self):
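        """Resolve 'long-llm:'-prefixed paths for train_data and eval_data, and,
        when present on the instance, for output_dir and result_dir."""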
        if self.train_data is not None:
            self.train_data = self.resolve_path(self.train_data)
        if self.eval_data is not None:
            self.eval_data = self.resolve_path(self.eval_data)
        if hasattr(self, "output_dir") and self.output_dir is not None:
            self.output_dir = self.resolve_path(self.output_dir)
        if hasattr(self, "result_dir"):
            if self.result_dir is None:
                if self.lora is not None:
                    name_or_path_components = [x for x in self.lora.split("/") if len(x)][-2:]
                else:
                    name_or_path_components = [x for x in self.model_name_or_path.split("/") if len(x)][-2:]
                self.result_dir = os.path.join(*name_or_path_components)
            else:
                self.result_dir = self.resolve_path(self.result_dir)


@dataclass
class TrainingArgs(TrainingArguments):
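    """Training arguments extending transformers.TrainingArguments with beacon-
    and LoRA-specific options."""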
    # ==============================
    # Common arguments
    # ==============================
    output_dir: str = field(
        default="data/outputs/pretrain",
    )
    per_device_train_batch_size: int = field(
        default=1,
        metadata={'help': 'Train batch size.'}
    )
    per_device_eval_batch_size: int = field(
        default=1,
        metadata={'help': 'Evaluation batch size.'}
    )
    remove_unused_columns: bool = field(
        default=False,
        metadata={'help': 'Remove columns in the dataset that are not registered in the forward function?'}
    )
    ddp_find_unused_parameters: bool = field(
        default=False,
        metadata={'help': 'Find unused parameters in DDP?'}
    )
    # NOTE: essential to keep the computation graph, because we need gradients for beacon tokens
    use_reentrant: Optional[bool] = field(
        default=None,
        metadata={'help': "Use reentrant gradient checkpointing?"}
    )
    report_to: str = field(
        default="none",
        metadata={'help': 'Log results by external tools?'}
    )
    # ==============================
    # Customized arguments
    # ==============================
    min_length: int = field(
        default=0,
        metadata={'help': 'How many tokens at minimum for training?'}
    )
    group_by_stride: Optional[str] = field(
        default=None,
        metadata={'help': 'Group the training data instances by the number of strides in the beacon model. {relaxed, strict}'}
    )
    sort_by_stride: Optional[str] = field(
        default=None,
        metadata={'help': 'Sort the training data instances by the number of strides in the beacon model. {ascend, descend}'}
    )
    only_train_beacon: bool = field(
        default=True,
        metadata={'help': 'Freeze LLM parameters when training beacon parameters?'}
    )
    eval_method: str = field(
        default="perplexity",
        metadata={'help': 'How to evaluate during training? {perplexity, generation}'}
    )
    eval_max_length: int = field(
        default=4096,
        metadata={'help': 'How many tokens at maximum for each input in evaluation.'},
    )
    eval_min_length: int = field(
        default=512,
        metadata={'help': 'How many tokens at minimum for each input in evaluation.'},
    )
    eval_beacon_ratio: List[int] = field(
        default_factory=lambda: [32],
        metadata={'help': 'Condensing ratios for beacons in evaluation.'}
    )
    eval_beacon_ratio_mix: str = field(
        default="adapt-1024",
        metadata={'help': 'How to determine the beacon_ratio for each input. {step-random, instance-random, adapt-x}'}
    )
    max_eval_num: Optional[int] = field(
        default=None,
        metadata={'help': 'How many samples at maximum for validation?'}
    )
    lora_tune: bool = field(
        default=False,
        metadata={"help": "Use LoRA fine-tuning?"},
    )
    lora_rank: int = field(
        default=32,
        metadata={'help': 'LoRA rank.'}
    )
    lora_alpha: int = field(
        default=16,
        metadata={'help': 'LoRA scaling factor.'}
    )
    lora_dropout: float = field(
        default=0.,
        metadata={'help': 'LoRA dropout p.'}
    )
    lora_targets: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"],
        metadata={"help": "Module name patterns to add LoRA."},
    )
    lora_extra_params: List[str] = field(
        default_factory=lambda: ["embed_tokens", "norm"],
        metadata={"help": "Extra trainable parameters besides the LoRA weights when using low-rank training."},
    )
    metrics: List[str] = field(
        default_factory=lambda: [],
        metadata={'help': 'List of metrics. {rouge, save_result}'}
    )
    log_path: str = field(
        default="data/outputs/metrics.log",
        metadata={'help': 'Log file path.'}
    )

    def __post_init__(self):
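        """Forward the use_reentrant flag to gradient_checkpointing_kwargs before
        the parent post-init runs."""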
        if self.use_reentrant is not None:
            self.gradient_checkpointing_kwargs = {"use_reentrant": self.use_reentrant}
        return super().__post_init__()
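

if __name__ == "__main__":
    # Illustrative sketch only, not part of the original module: these dataclasses
    # are designed to be parsed with transformers' HfArgumentParser, and the real
    # training/evaluation entry points live elsewhere in the repo. Running
    # `python args.py --help` lists every option defined above.
    from transformers import HfArgumentParser

    parser = HfArgumentParser([ModelArgs, TrainingArgs])
    model_args, training_args = parser.parse_args_into_dataclasses()
    print(json.dumps(model_args.to_dict(), indent=2))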