import os
from dataclasses import dataclass, field
from typing import Optional, List, Union

from transformers.training_args import TrainingArguments


@dataclass
class BaseArgs:
    model_cache_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'Default path to save language models.'}
    )
    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'Default path to save huggingface datasets.'}
    )
    data_root: str = field(
        default="/data/llm-embedder",
        metadata={'help': 'The base directory storing all data used for training and evaluation. If specified, make sure all train_data, eval_data, and corpus are paths relative to data_root!'},
    )
    train_data: Optional[List[str]] = field(
        default=None,
        metadata={'help': 'Training json file or glob to match a list of files.'},
    )
    eval_data: Optional[str] = field(
        default=None,
        metadata={'help': 'Evaluation json file.'},
    )
    corpus: Optional[str] = field(
        default=None,
        metadata={'help': 'Corpus jsonl file.'}
    )
    key_template: str = field(
        default="{title} {text}",
        metadata={'help': 'How to concatenate columns in the corpus to form one key?'}
    )
    metrics: List[str] = field(
        default_factory=lambda: ["mrr", "recall", "ndcg"],
        metadata={'help': 'List of metrics.'}
    )
    cutoffs: List[int] = field(
        default_factory=lambda: [1, 5, 10, 100],
        metadata={'help': 'Cutoffs to evaluate retrieval metrics.'}
    )
    filter_answers: bool = field(
        default=False,
        metadata={'help': 'Remove negatives that contain the desired answer when collating negatives?'}
    )
    max_neg_num: int = field(
        default=100,
        metadata={'help': 'Maximum number of negatives to mine.'}
    )
    load_result: bool = field(
        default=False,
        metadata={'help': 'Load retrieval results directly?'}
    )
    save_result: bool = field(
        default=True,
        metadata={'help': 'Save retrieval results?'}
    )
    save_name: Optional[str] = field(
        default=None,
        metadata={'help': 'Name suffix of the json file when saving the collated retrieval results.'}
    )
    save_to_output: bool = field(
        default=False,
        metadata={'help': 'Save the result/key/negative to output_dir? If not true, they will be saved next to the eval_data.'}
    )

    def resolve_path(self, path):
        """Resolve any path starting with 'llm-embedder:' to a real path under data_root."""
        pattern = "llm-embedder:"
        # resolve relative data paths when necessary
        if isinstance(path, list):
            for i, x in enumerate(path):
                if x.startswith(pattern):
                    path[i] = os.path.join(self.data_root, x.replace(pattern, ""))
        else:
            if path.startswith(pattern):
                path = os.path.join(self.data_root, path.replace(pattern, ""))
        return path

    def __post_init__(self):
        if self.train_data is not None:
            self.train_data = self.resolve_path(self.train_data)
        if self.eval_data is not None:
            self.eval_data = self.resolve_path(self.eval_data)
        if self.corpus is not None:
            self.corpus = self.resolve_path(self.corpus)


@dataclass
class DenseRetrievalArgs(BaseArgs):
    query_encoder: str = field(
        default="BAAI/bge-base-en",
        metadata={'help': 'Path to encoder model or model identifier from huggingface.co/models.'}
    )
    key_encoder: str = field(
        default="BAAI/bge-base-en",
        metadata={'help': 'Path to encoder model or model identifier from huggingface.co/models.'}
    )
    add_instruction: bool = field(
        default=True,
        metadata={'help': 'Add instruction for each task?'}
    )
    version: str = field(
        default="bge",
        metadata={'help': 'Version for configs.'}
    )
    query_max_length: int = field(
        default=256,
        metadata={'help': 'Max query length.'}
    )
    key_max_length: int = field(
        default=256,
        metadata={'help': 'Max key length.'}
    )
    truncation_side: str = field(
        default="right",
        metadata={'help': 'Which side to truncate?'}
    )
    pooling_method: List[str] = field(
        default_factory=lambda: ["cls"],
        metadata={'help': 'Pooling methods to aggregate token embeddings for a sequence embedding. {cls, mean, dense, decoder}'}
    )
    tie_encoders: bool = field(
        default=True,
        metadata={'help': 'Tie query encoder and key encoder? If True, the query_encoder is used for both.'}
    )
    dense_metric: str = field(
        default="cos",
        metadata={'help': 'What type of metric for dense retrieval? ip, l2, or cos.'}
    )
    faiss_index_factory: str = field(
        default="Flat",
        metadata={'help': 'Index factory string for faiss.'}
    )
    hits: int = field(
        default=200,
        metadata={'help': 'How many keys to retrieve?'}
    )
    batch_size: int = field(
        default=1000,
        metadata={'help': 'Batch size for indexing and retrieval.'}
    )
    load_encode: bool = field(
        default=False,
        metadata={'help': 'Load cached embeddings?'}
    )
    save_encode: bool = field(
        default=False,
        metadata={'help': 'Save embeddings?'}
    )
    load_index: bool = field(
        default=False,
        metadata={'help': 'Load cached index?'}
    )
    save_index: bool = field(
        default=False,
        metadata={'help': 'Save index?'}
    )
    embedding_name: str = field(
        default="embeddings",
        metadata={'help': 'Name used when saving embeddings (also used for the faiss index name).'}
    )
    dtype: str = field(
        default="fp16",
        metadata={'help': 'Data type for retriever.'}
    )
    cpu: bool = field(
        default=False,
        metadata={'help': 'Use cpu?'}
    )


@dataclass
class BM25Args(BaseArgs):
    anserini_dir: str = field(
        default='/share/peitian/Apps/anserini',
        metadata={'help': 'Anserini installation directory.'}
    )
    k1: float = field(
        default=0.82,
        metadata={'help': 'BM25 k1.'}
    )
    b: float = field(
        default=0.68,
        metadata={'help': 'BM25 b.'}
    )
    storeDocvectors: bool = field(
        default=False,
        metadata={'help': 'Store document vector? Useful when you want to inspect the word-level statistics (tf-idf) after index construction.'}
    )
    hits: int = field(
        default=200,
        metadata={'help': 'How many keys to retrieve?'}
    )
    language: str = field(
        default="en",
        metadata={'help': 'Language.'}
    )
    threads: int = field(
        default=32,
        metadata={'help': 'Indexing/Searching thread number.'}
    )
    load_index: bool = field(
        default=False,
        metadata={'help': 'Load index?'}
    )
    load_collection: bool = field(
        default=False,
        metadata={'help': 'Load collection?'}
    )


@dataclass
class RankerArgs(BaseArgs):
    ranker: str = field(
        default="BAAI/bge-base-en",
        metadata={'help': 'Ranker name or path.'}
    )
    ranker_method: str = field(
        default="cross-encoder",
        metadata={'help': 'What kind of ranker to use? {cross: cross encoder}'}
    )
    dtype: str = field(
        default="fp16",
        metadata={'help': 'Data type for ranker.'}
    )
    query_max_length: int = field(
        default=256,
        metadata={'help': 'Max query length.'}
    )
    key_max_length: int = field(
        default=256,
        metadata={'help': 'Max key length.'}
    )
    add_instruction: bool = field(
        default=False,
        metadata={'help': 'Add instruction for each task?'}
    )
    version: str = field(
        default="bge",
        metadata={'help': 'Version for configs.'}
    )
    hits: Optional[int] = field(
        default=None,
        metadata={'help': 'How many top reranked keys to keep?'}
    )
    batch_size: int = field(
        default=4,
        metadata={'help': 'Batch size for indexing and retrieval.'}
    )
    cpu: bool = field(
        default=False,
        metadata={'help': 'Use cpu?'}
    )


@dataclass
class RetrievalArgs(DenseRetrievalArgs, BM25Args):
    retrieval_method: str = field(
        default="dense",
        metadata={'help': 'How to retrieve? {dense, bm25, random, no}'}
    )


@dataclass
class RetrievalTrainingArgs(TrainingArguments):
    output_dir: str = field(
        default='data/outputs/',
        metadata={'help': 'The output directory where the model predictions and checkpoints will be written.'},
    )
    eval_method: str = field(
        default="retrieval",
        metadata={'help': 'How to evaluate?'},
    )
    use_train_config: bool = field(
        default=False,
        metadata={'help': 'Use training config from TASK_CONFIG to override arguments?'}
    )
    inbatch_same_dataset: Optional[str] = field(
        default=None,
        metadata={'help': 'Whether and how to use samples from the same task in each batch (across devices). {epoch, random}'}
    )
    negative_cross_device: bool = field(
        default=True,
        metadata={'help': 'Gather negatives from all devices when distributed training?'}
    )
    cos_temperature: float = field(
        default=0.01,
        metadata={'help': 'Temperature used for cosine dense metric.'}
    )
    teacher_temperature: float = field(
        default=1.,
        metadata={'help': 'Temperature applied to teacher scores in distillation.'}
    )
    student_temperature: float = field(
        default=1.,
        metadata={'help': 'Temperature applied to student scores in distillation.'}
    )
    contrastive_weight: float = field(
        default=0.2,
        metadata={'help': 'Weight for contrastive loss.'}
    )
    distill_weight: float = field(
        default=1.0,
        metadata={'help': 'Weight for distillation loss.'}
    )
    stable_distill: bool = field(
        default=False,
        metadata={'help': 'Sort distillation.'}
    )
    max_sample_num: Optional[int] = field(
        default=None,
        metadata={'help': 'Maximum number of samples in the training dataset.'}
    )
    train_group_size: int = field(
        default=8,
        metadata={'help': 'How many keys in a batch?'}
    )
    select_positive: str = field(
        default="first",
        metadata={'help': 'How to select the positive key from a set of positives?'}
    )
    select_negative: str = field(
        default="random",
        metadata={'help': 'How to select the negative keys from a set of negatives?'}
    )
    teacher_scores_margin: Optional[float] = field(
        default=None,
        metadata={'help': 'Minimum margin in teacher_scores. Samples with a smaller margin will be removed from training.'}
    )
    teacher_scores_min: Optional[float] = field(
        default=None,
        metadata={'help': 'Minimum teacher_scores. Samples whose highest score is lower than this will be removed from training.'}
    )
    per_device_train_batch_size: int = field(
        default=16,
        metadata={'help': 'Train batch size.'},
    )
    learning_rate: float = field(
        default=5e-6,
        metadata={'help': 'Learning rate.'},
    )
    warmup_ratio: float = field(
        default=0.1,
        metadata={'help': 'Warmup ratio for linear scheduler.'},
    )
    weight_decay: float = field(
        default=0.01,
        metadata={'help': 'Weight decay in AdamW.'},
    )
    fp16: bool = field(
        default=True,
        metadata={'help': 'Use fp16 training?'}
    )
    ddp_find_unused_parameters: bool = field(
        default=False,
        metadata={'help': 'Find unused parameters in torch DDP?'},
    )
    remove_unused_columns: bool = field(
        default=False,
        metadata={'help': 'Remove columns that are not registered in the forward function of the model?'},
    )
    evaluation_strategy: str = field(
        default='steps',
        metadata={'help': 'Evaluation strategy.'},
    )
    save_steps: int = field(
        default=2000,
        metadata={'help': 'Saving frequency.'},
    )
    logging_steps: int = field(
        default=100,
        metadata={'help': 'Logging frequency according to logging strategy.'},
    )
    early_exit_steps: Optional[int] = field(
        default=None,
        metadata={'help': 'After how many steps to exit the training loop.'},
    )
    report_to: str = field(
        default="none",
        metadata={'help': 'The list of integrations to report the results and logs to.'}
    )
    log_path: str = field(
        default="data/results/performance.log",
        metadata={'help': 'Path of the file for logging evaluation results.'}
    )

    # NOTE: newer versions of transformers forbid modifying the arguments after initialization;
    # override __setattr__ to bypass this restriction.
    def __setattr__(self, name, value):
        super(TrainingArguments, self).__setattr__(name, value)

    def __post_init__(self):
        super().__post_init__()
        # for convenience
        # self.eval_steps = self.save_steps
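
if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module): dataclasses
    # with `field(metadata={'help': ...})` like the ones above are typically parsed with
    # transformers' HfArgumentParser, which turns each field into a command-line option.
    from transformers import HfArgumentParser

    parser = HfArgumentParser([RetrievalArgs, RetrievalTrainingArgs])
    retrieval_args, training_args = parser.parse_args_into_dataclasses()
    print(retrieval_args.retrieval_method, training_args.output_dir)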