# Fine-tuning

The previous section covered how to construct the training and test data. In this tutorial, we will fine-tune the model on that data.

## Installation

To fine-tune BGE models with FlagEmbedding, install the package with the `finetune` extra:

In [None]:
%pip install -U "FlagEmbedding[finetune]"

## Fine-tune

Below are the arguments for fine-tuning:

The following arguments are for model:
- `model_name_or_path`: The model checkpoint for initialization.
- `config_name`: Pretrained config name or path if not the same as model_name.
- `tokenizer_name`: Pretrained tokenizer name or path if not the same as model_name.
- `cache_dir`: Directory in which to cache the downloaded pre-trained models.
- `trust_remote_code`: Whether to trust and execute custom code shipped with the model repository.
- `token`: The token to use when accessing the model.

The following arguments are for data:
- `train_data`: One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data. Argument type: multiple.
- `cache_path`: Where do you want to store the cached data.
- `train_group_size`: The number of passages per query in a training group: one positive plus `train_group_size - 1` negatives.
- `query_max_len`: The maximum total input sequence length after tokenization for query. Sequences longer than this will be truncated.
- `passage_max_len`: The maximum total input sequence length after tokenization for passage. Sequences longer than this will be truncated.
- `pad_to_multiple_of`: If set will pad the sequence to be a multiple of the provided value.
- `max_example_num_per_dataset`: The max number of examples for each dataset.
- `query_instruction_for_retrieval`: Instruction for query.
- `query_instruction_format`: Format for query instruction.
- `knowledge_distillation`: Use knowledge distillation when `pos_scores: List[float]` and `neg_scores: List[float]` are in features of training data.
- `passage_instruction_for_retrieval`: Instruction for passage.
- `passage_instruction_format`: Format for passage instruction.
- `shuffle_ratio`: The ratio of shuffling the text.
- `same_dataset_within_batch`: All samples in the same batch come from the same dataset.
- `small_threshold`: The threshold for small datasets. All small datasets in the same directory will be merged into one dataset.
- `drop_threshold`: The threshold for dropping merged small dataset. If the number of examples in the merged small dataset is less than this threshold, it will be dropped.
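
To make the required fields concrete, here is a minimal sketch of one training record and a sanity check for it. The example texts and the helper `check_record` are hypothetical and not part of FlagEmbedding. The sketch assumes one JSON object per line in the training file, and that `pos_scores` / `neg_scores` (used only with `knowledge_distillation`) hold one score per text.

```python
import json

# Hypothetical example record; the texts are made up for illustration.
record = {
    "query": "What is the capital of France?",
    "pos": ["Paris is the capital and largest city of France."],
    "neg": [
        "Berlin is the capital of Germany.",
        "Lyon is a large city in the east of France.",
    ],
}


def check_record(rec):
    """Return a list of problems found in one training record (empty if none)."""
    problems = []
    if not isinstance(rec.get("query"), str):
        problems.append("query must be a str")
    for key in ("pos", "neg"):
        value = rec.get(key)
        if not (isinstance(value, list) and all(isinstance(t, str) for t in value)):
            problems.append(f"{key} must be a list of str")
    # Assumption: score lists, when present, align one-to-one with the texts.
    for texts_key, scores_key in (("pos", "pos_scores"), ("neg", "neg_scores")):
        if scores_key in rec and len(rec[scores_key]) != len(rec.get(texts_key) or []):
            problems.append(f"{scores_key} must have one score per {texts_key} text")
    return problems


line = json.dumps(record, ensure_ascii=False)  # one JSON object per line
problems = check_record(json.loads(line))
print(problems)
```

Running a check like this over every line of the training file before launching a multi-GPU job can catch malformed records early.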

The following extra arguments are for training:
- `negatives_cross_device`: Share negatives across devices.
- `temperature`: Temperature used for similarity score.
- `fix_position_embedding`: Freeze the parameters of position embeddings.
- `sentence_pooling_method`: The pooling method. Available options: cls, mean, last_token. Default: cls.
- `normalize_embeddings`: Whether to normalize the embeddings.
- `sub_batch_size`: Sub batch size for training.
- `kd_loss_type`: The loss type for knowledge distillation. Available options: kl_div, m3_kd_loss. Default: kl_div.
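
To illustrate how `temperature` and `normalize_embeddings` interact, the following is a minimal sketch of a contrastive (InfoNCE-style) loss for a single query with one positive and several negatives. It is an illustration under simplifying assumptions, not FlagEmbedding's actual implementation; the function names and the toy two-dimensional vectors are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length, so the dot product becomes cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def contrastive_loss(query, passages, temperature):
    """Cross-entropy over similarity scores; passages[0] is the positive."""
    q = l2_normalize(query)
    logits = [dot(q, l2_normalize(p)) / temperature for p in passages]
    # Numerically stable log-sum-exp.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[0]


# Toy embeddings: the first passage is the positive, the rest are negatives.
query = [1.0, 0.0]
passages = [[0.9, 0.1], [0.2, 0.8], [-0.5, 0.5]]

loss_sharp = contrastive_loss(query, passages, temperature=0.02)
loss_soft = contrastive_loss(query, passages, temperature=1.0)
print(loss_sharp, loss_soft)
```

With normalized embeddings the similarities lie in [-1, 1], so dividing by a small temperature such as 0.02 stretches the gaps between scores. When the positive already ranks first, the loss drops toward zero much faster than it does at a temperature of 1.0.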

In [1]:
%%bash
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.encoder_only.base \
    --model_name_or_path BAAI/bge-large-en-v1.5 \
    --cache_dir ./cache/model \
    --train_data ./ft_data/training.json \
    --cache_path ./cache/data \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --query_instruction_for_retrieval 'Represent this sentence for searching relevant passages: ' \
    --query_instruction_format '{}{}' \
    --knowledge_distillation False \
    --output_dir ./test_encoder_only_base_bge-large-en-v1.5 \
    --overwrite_output_dir \
    --learning_rate 1e-5 \
    --fp16 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --warmup_ratio 0.1 \
    --gradient_checkpointing \
    --deepspeed config/ds_stage0.json \
    --logging_steps 1 \
    --save_steps 1000 \
    --negatives_cross_device \
    --temperature 0.02 \
    --sentence_pooling_method cls \
    --normalize_embeddings True \
    --kd_loss_type kl_div

W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] 
W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] *****************************************
W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1223 06:27:06.807000 1362426 site-packages/torch/distributed/run.py:793] *****************************************


[2024-12-23 06:27:31,423] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-12-23 06:27:31,424] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-12-23 06:27:40,529] [INFO] [comm.py:652:init_distributed] cdb=None
[2024-12-23 06:27:40,529] [INFO] [comm.py:652:init_distributed] cdb=None
[2024-12-23 06:27:40,529] [INFO] [comm.py:683:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl


12/23/2024 06:27:40 - INFO - FlagEmbedding.abc.finetune.embedder.AbsRunner -   Training/evaluation parameters AbsEmbedderTrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=True,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=config/ds_stage0.json,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_c

[1734935704.354551] [job-40fb0ce3-8bfb-46ea-b409-0a2e2a1a3163-master-0:1362491:f]        vfs_fuse.c:281  UCX  ERROR inotify_add_watch(/tmp) failed: No space left on device
[1734935704.383634] [job-40fb0ce3-8bfb-46ea-b409-0a2e2a1a3163-master-0:1362492:f]        vfs_fuse.c:281  UCX  ERROR inotify_add_watch(/tmp) failed: No space left on device


Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/fused_adam/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module fused_adam...


Time to load fused_adam op: 1.1966907978057861 seconds


Loading extension module fused_adam...


Time to load fused_adam op: 1.2037739753723145 seconds


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


{'loss': 0.0124, 'grad_norm': 1.0943871958089542, 'learning_rate': 0.0, 'epoch': 0.0}
{'loss': 0.1189, 'grad_norm': 9.971958134471109, 'learning_rate': 1.2049342512977792e-06, 'epoch': 0.0}
{'loss': 0.0067, 'grad_norm': 0.676847884003986, 'learning_rate': 1.9097756041415023e-06, 'epoch': 0.0}
{'loss': 1.5215, 'grad_norm': 40.51544573089919, 'learning_rate': 2.4098685025955585e-06, 'epoch': 0.0}
{'loss': 0.0111, 'grad_norm': 0.8537607081175989, 'learning_rate': 2.7977706905803826e-06, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.1699944264536089, 'learning_rate': 3.1147098554392813e-06, 'epoch': 0.0}
{'loss': 0.0003, 'grad_norm': 0.026271846378513198, 'learning_rate': 3.3826781011366144e-06, 'epoch': 0.0}
{'loss': 0.0039, 'grad_norm': 0.3161338881928349, 'learning_rate': 3.614802753893337e-06, 'epoch': 0.01}
{'loss': 0.0351, 'grad_norm': 2.335078256835444, 'learning_rate': 3.8195512082830046e-06, 'epoch': 0.01}
{'loss': 0.1005, 'grad_norm': 10.32570731855295, 'learning_rate': 4.0027049

 13%|█▎        | 417/3150 [01:45<10:51,  4.20it/s]

{'loss': 0.0052, 'grad_norm': 0.5387894588544504, 'learning_rate': 9.643738977072311e-06, 'epoch': 0.26}
{'loss': 0.002, 'grad_norm': 0.20979235778898053, 'learning_rate': 9.64021164021164e-06, 'epoch': 0.27}
{'loss': 0.0002, 'grad_norm': 0.026038436142895877, 'learning_rate': 9.63668430335097e-06, 'epoch': 0.27}
{'loss': 0.0, 'grad_norm': 0.00018212249686265307, 'learning_rate': 9.6331569664903e-06, 'epoch': 0.27}
{'loss': 0.0083, 'grad_norm': 1.033955002999129, 'learning_rate': 9.62962962962963e-06, 'epoch': 0.27}
{'loss': 0.0023, 'grad_norm': 0.343699549093858, 'learning_rate': 9.62610229276896e-06, 'epoch': 0.27}
{'loss': 0.0, 'grad_norm': 0.0010749272909065962, 'learning_rate': 9.622574955908291e-06, 'epoch': 0.27}
{'loss': 0.0, 'grad_norm': 0.00010975655595019302, 'learning_rate': 9.61904761904762e-06, 'epoch': 0.27}
{'loss': 0.0051, 'grad_norm': 0.4788360612721627, 'learning_rate': 9.61552028218695e-06, 'epoch': 0.27}
{'loss': 0.0, 'grad_norm': 0.00011672187140924894, 'learning_

 26%|██▌       | 818/3150 [03:29<10:50,  3.58it/s]

{'loss': 0.0002, 'grad_norm': 0.01366528763197738, 'learning_rate': 8.229276895943562e-06, 'epoch': 0.52}
{'loss': 0.0, 'grad_norm': 9.54747062331347e-05, 'learning_rate': 8.225749559082893e-06, 'epoch': 0.52}
{'loss': 0.0, 'grad_norm': 0.00011539470773831022, 'learning_rate': 8.222222222222222e-06, 'epoch': 0.52}
{'loss': 0.2018, 'grad_norm': 16.709705680113448, 'learning_rate': 8.218694885361552e-06, 'epoch': 0.52}
{'loss': 0.0, 'grad_norm': 0.0007446771029235906, 'learning_rate': 8.215167548500883e-06, 'epoch': 0.52}
{'loss': 0.543, 'grad_norm': 22.15912234003999, 'learning_rate': 8.211640211640213e-06, 'epoch': 0.52}
{'loss': 0.0003, 'grad_norm': 0.030122672840349505, 'learning_rate': 8.208112874779542e-06, 'epoch': 0.52}
{'loss': 0.0002, 'grad_norm': 0.013163602206137692, 'learning_rate': 8.204585537918873e-06, 'epoch': 0.52}
{'loss': 0.0027, 'grad_norm': 0.18348203466131782, 'learning_rate': 8.201058201058202e-06, 'epoch': 0.52}
{'loss': 0.0001, 'grad_norm': 0.011142931175322368,

 32%|███▏      | 1000/3150 [04:16<08:47,  4.08it/s]12/23/2024 06:39:23 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to ./test_encoder_only_base_bge-large-en-v1.5/checkpoint-1000


{'loss': 0.0001, 'grad_norm': 0.0052094104905307205, 'learning_rate': 7.583774250440918e-06, 'epoch': 0.64}
{'loss': 0.0004, 'grad_norm': 0.05414591780232195, 'learning_rate': 7.580246913580247e-06, 'epoch': 0.64}
{'loss': 0.0, 'grad_norm': 0.005033967507836883, 'learning_rate': 7.576719576719578e-06, 'epoch': 0.64}
{'loss': 0.0002, 'grad_norm': 0.01698784361595978, 'learning_rate': 7.573192239858908e-06, 'epoch': 0.64}
{'loss': 0.0, 'grad_norm': 0.00047723063982967767, 'learning_rate': 7.569664902998237e-06, 'epoch': 0.64}
{'loss': 0.0006, 'grad_norm': 0.0427643550196247, 'learning_rate': 7.566137566137567e-06, 'epoch': 0.64}
{'loss': 0.0189, 'grad_norm': 2.0302958668418953, 'learning_rate': 7.562610229276897e-06, 'epoch': 0.64}
{'loss': 0.0001, 'grad_norm': 0.00556046268974225, 'learning_rate': 7.5590828924162264e-06, 'epoch': 0.64}
{'loss': 0.0, 'grad_norm': 0.0005143339470081945, 'learning_rate': 7.555555555555556e-06, 'epoch': 0.64}
{'loss': 0.0001, 'grad_norm': 0.0107442057404229

 44%|████▎     | 1377/3150 [05:55<08:04,  3.66it/s]  

{'loss': 0.0159, 'grad_norm': 1.7536696378495038, 'learning_rate': 6.2610229276895955e-06, 'epoch': 0.87}
{'loss': 0.0004, 'grad_norm': 0.03909537600833843, 'learning_rate': 6.257495590828925e-06, 'epoch': 0.88}
{'loss': 0.0, 'grad_norm': 1.3072261291430532e-05, 'learning_rate': 6.253968253968254e-06, 'epoch': 0.88}
{'loss': 0.0137, 'grad_norm': 0.9366626178635848, 'learning_rate': 6.250440917107584e-06, 'epoch': 0.88}
{'loss': 0.0006, 'grad_norm': 0.06750650731592978, 'learning_rate': 6.2469135802469135e-06, 'epoch': 0.88}
{'loss': 0.0066, 'grad_norm': 0.6520149178816838, 'learning_rate': 6.243386243386243e-06, 'epoch': 0.88}
{'loss': 0.0001, 'grad_norm': 0.007519813360458526, 'learning_rate': 6.239858906525573e-06, 'epoch': 0.88}
{'loss': 0.3818, 'grad_norm': 22.781509879347606, 'learning_rate': 6.236331569664904e-06, 'epoch': 0.88}
{'loss': 0.0, 'grad_norm': 0.0004221153474469201, 'learning_rate': 6.232804232804234e-06, 'epoch': 0.88}
{'loss': 0.0109, 'grad_norm': 1.572822121648162,

 55%|█████▌    | 1745/3150 [07:21<04:42,  4.97it/s]

{'loss': 0.0009, 'grad_norm': 0.09674130508777008, 'learning_rate': 4.966490299823634e-06, 'epoch': 1.11}
{'loss': 0.0, 'grad_norm': 0.0059433610812439225, 'learning_rate': 4.962962962962964e-06, 'epoch': 1.11}
{'loss': 0.0, 'grad_norm': 0.0023787512831038777, 'learning_rate': 4.959435626102293e-06, 'epoch': 1.11}
{'loss': 0.0001, 'grad_norm': 0.0053428588825208895, 'learning_rate': 4.955908289241623e-06, 'epoch': 1.11}
{'loss': 0.0006, 'grad_norm': 0.06979952889194821, 'learning_rate': 4.952380952380953e-06, 'epoch': 1.11}
{'loss': 0.0015, 'grad_norm': 0.12311688959209235, 'learning_rate': 4.9488536155202825e-06, 'epoch': 1.11}
{'loss': 0.0, 'grad_norm': 0.005659354743994171, 'learning_rate': 4.945326278659612e-06, 'epoch': 1.11}
{'loss': 0.1051, 'grad_norm': 12.496530280160519, 'learning_rate': 4.941798941798942e-06, 'epoch': 1.11}
{'loss': 0.0, 'grad_norm': 7.150547607013999e-05, 'learning_rate': 4.938271604938272e-06, 'epoch': 1.11}
{'loss': 0.0013, 'grad_norm': 0.11405076724287291

 63%|██████▎   | 2000/3150 [08:21<04:30,  4.25it/s]12/23/2024 06:43:28 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to ./test_encoder_only_base_bge-large-en-v1.5/checkpoint-2000


{'loss': 0.0001, 'grad_norm': 0.012578935167627541, 'learning_rate': 4.063492063492064e-06, 'epoch': 1.27}
{'loss': 0.0, 'grad_norm': 4.0970670215411106e-05, 'learning_rate': 4.059964726631394e-06, 'epoch': 1.27}
{'loss': 0.0001, 'grad_norm': 0.011658719653620064, 'learning_rate': 4.0564373897707236e-06, 'epoch': 1.27}
{'loss': 0.0, 'grad_norm': 0.0008131945372193306, 'learning_rate': 4.052910052910053e-06, 'epoch': 1.27}
{'loss': 0.0, 'grad_norm': 0.00026534978358113776, 'learning_rate': 4.049382716049383e-06, 'epoch': 1.27}
{'loss': 0.0, 'grad_norm': 1.1366631129487001e-05, 'learning_rate': 4.045855379188713e-06, 'epoch': 1.27}
{'loss': 0.0116, 'grad_norm': 1.3234954028653214, 'learning_rate': 4.042328042328042e-06, 'epoch': 1.27}
{'loss': 0.0004, 'grad_norm': 0.05145979726251188, 'learning_rate': 4.038800705467372e-06, 'epoch': 1.27}
{'loss': 0.0, 'grad_norm': 0.0003372150780671462, 'learning_rate': 4.035273368606703e-06, 'epoch': 1.28}
{'loss': 0.0, 'grad_norm': 0.00165260511842168

 75%|███████▍  | 2352/3150 [09:50<02:43,  4.89it/s]

{'loss': 0.0, 'grad_norm': 0.001067494846590443, 'learning_rate': 2.8218694885361552e-06, 'epoch': 1.49}
{'loss': 0.0, 'grad_norm': 0.0019456256210602489, 'learning_rate': 2.818342151675485e-06, 'epoch': 1.49}
{'loss': 0.003, 'grad_norm': 0.3920454900412361, 'learning_rate': 2.814814814814815e-06, 'epoch': 1.5}
{'loss': 0.0, 'grad_norm': 1.4884702153316084e-05, 'learning_rate': 2.811287477954145e-06, 'epoch': 1.5}
{'loss': 0.0, 'grad_norm': 0.000446312017993157, 'learning_rate': 2.8077601410934745e-06, 'epoch': 1.5}
{'loss': 0.0, 'grad_norm': 0.0011505404281340935, 'learning_rate': 2.8042328042328042e-06, 'epoch': 1.5}
{'loss': 1.3574, 'grad_norm': 22.403333742837667, 'learning_rate': 2.800705467372134e-06, 'epoch': 1.5}
{'loss': 0.0055, 'grad_norm': 0.5809690268068924, 'learning_rate': 2.797178130511464e-06, 'epoch': 1.5}
{'loss': 0.0, 'grad_norm': 0.0002727443025917454, 'learning_rate': 2.7936507936507938e-06, 'epoch': 1.5}
{'loss': 0.0, 'grad_norm': 0.00023945617647247515, 'learning

 86%|████████▌ | 2696/3150 [11:12<01:56,  3.88it/s]

{'loss': 0.0, 'grad_norm': 0.002211572384720586, 'learning_rate': 1.6084656084656086e-06, 'epoch': 1.71}
{'loss': 0.0086, 'grad_norm': 0.68364602025328, 'learning_rate': 1.6049382716049383e-06, 'epoch': 1.71}
{'loss': 0.0003, 'grad_norm': 0.026921721577691494, 'learning_rate': 1.6014109347442683e-06, 'epoch': 1.71}
{'loss': 0.0, 'grad_norm': 0.0001044867510325982, 'learning_rate': 1.597883597883598e-06, 'epoch': 1.71}
{'loss': 0.0, 'grad_norm': 0.00025342561809199815, 'learning_rate': 1.5943562610229279e-06, 'epoch': 1.71}
{'loss': 0.0, 'grad_norm': 2.4368396991437897e-05, 'learning_rate': 1.5908289241622576e-06, 'epoch': 1.72}
{'loss': 0.0, 'grad_norm': 5.531408833197814e-06, 'learning_rate': 1.5873015873015873e-06, 'epoch': 1.72}
{'loss': 0.0, 'grad_norm': 9.209012136005157e-05, 'learning_rate': 1.5837742504409172e-06, 'epoch': 1.72}
{'loss': 0.0001, 'grad_norm': 0.015541857089792681, 'learning_rate': 1.580246913580247e-06, 'epoch': 1.72}
{'loss': 0.0, 'grad_norm': 0.0006011956716473

 95%|█████████▌| 3000/3150 [12:29<00:34,  4.40it/s]12/23/2024 06:47:36 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to ./test_encoder_only_base_bge-large-en-v1.5/checkpoint-3000


{'loss': 0.0684, 'grad_norm': 7.017930091803245, 'learning_rate': 5.361552028218695e-07, 'epoch': 1.91}
{'loss': 0.0, 'grad_norm': 0.000511339045166218, 'learning_rate': 5.326278659611994e-07, 'epoch': 1.91}
{'loss': 0.0, 'grad_norm': 0.0005950045300509824, 'learning_rate': 5.291005291005291e-07, 'epoch': 1.91}
{'loss': 0.0, 'grad_norm': 0.003924217893271839, 'learning_rate': 5.255731922398589e-07, 'epoch': 1.91}
{'loss': 0.0, 'grad_norm': 0.004307797663678439, 'learning_rate': 5.220458553791887e-07, 'epoch': 1.91}
{'loss': 0.0, 'grad_norm': 2.603897655768594e-05, 'learning_rate': 5.185185185185186e-07, 'epoch': 1.91}
{'loss': 0.0, 'grad_norm': 5.483371389463175e-06, 'learning_rate': 5.149911816578484e-07, 'epoch': 1.91}
{'loss': 0.4844, 'grad_norm': 20.3684123804579, 'learning_rate': 5.114638447971781e-07, 'epoch': 1.91}
{'loss': 0.0001, 'grad_norm': 0.01247325742979955, 'learning_rate': 5.07936507936508e-07, 'epoch': 1.91}
{'loss': 0.0004, 'grad_norm': 0.05214513400326202, 'learning_

100%|██████████| 3150/3150 [13:10<00:00,  3.72it/s]12/23/2024 06:48:17 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to ./test_encoder_only_base_bge-large-en-v1.5/checkpoint-3150


{'train_runtime': 799.0537, 'train_samples_per_second': 15.769, 'train_steps_per_second': 3.942, 'train_loss': 0.04348497095562163, 'epoch': 2.0}


100%|██████████| 3150/3150 [13:19<00:00,  3.94it/s]
12/23/2024 06:48:26 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to ./test_encoder_only_base_bge-large-en-v1.5
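
The numbers in the log can be cross-checked against the command-line arguments. The sketch below recomputes the dataset size implied by the 3150 total steps and reproduces a few logged learning rates. The schedule used here (logarithmic warmup followed by linear decay) is an assumption inferred from the logged values, since the contents of `config/ds_stage0.json` are not shown. The 0-based iteration indexing and the helper `lr_at` are likewise assumptions made for illustration.

```python
import math

# Values taken from the command and the log output above.
num_gpus = 2            # --nproc_per_node
per_device_batch = 2    # --per_device_train_batch_size
epochs = 2              # --num_train_epochs
base_lr = 1e-5          # --learning_rate
warmup_ratio = 0.1      # --warmup_ratio
total_steps = 3150      # from the progress bar

global_batch = num_gpus * per_device_batch
samples_per_epoch = total_steps * global_batch // epochs
warmup_steps = round(total_steps * warmup_ratio)

# Cross-check with the final summary: samples/second * runtime / epochs.
samples_from_throughput = 15.769 * 799.0537 / epochs


def lr_at(iteration):
    """Learning rate at a 0-based scheduler iteration (inferred schedule)."""
    if iteration < warmup_steps:
        # Assumed logarithmic warmup.
        return base_lr * math.log(iteration + 1) / math.log(warmup_steps)
    # Assumed linear decay to zero at the last step.
    return base_lr * (total_steps - iteration) / (total_steps - warmup_steps)


print(samples_per_epoch, round(samples_from_throughput), warmup_steps)
print(lr_at(1))    # logged: 1.2049342512977792e-06
print(lr_at(2))    # logged: 1.9097756041415023e-06
print(lr_at(416))  # logged: 9.643738977072311e-06
```

Under these assumptions, both estimates of the dataset size agree to within rounding, and the learning rate peaks at `1e-5` after the first 10% of steps before decaying linearly.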
