
# Training

There are two stages in training:

- Pretrain
  - 1B tokens from RedPajama with auto-regressive language modeling
  - An EOS token is appended to each document, and documents are not packed together (see the data-preparation sketch after this list)
  - 20K maximum context length
- Finetune
  - 5K samples from LongAlpaca, 2K samples from BookSum, 16K synthetic long-context QA data generated with GPT-3.5, and 5K samples from the pretraining data
  - 20K maximum context length
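
The pretraining data preparation boils down to the three points above. The snippet below is a minimal sketch for illustration only, not the repository's actual data loader: the tokenizer name, the file path, and the use of `datasets.load_dataset` are assumptions, while the EOS handling, the no-packing rule, and the 2.4K/20K length bounds mirror the `--min_length`/`--max_length` arguments in the commands below.

```python
# Minimal, illustrative sketch of the pretraining data preparation (NOT the repo's loader).
# Assumptions: a RedPajama-style JSON file with a "text" field and the Mistral tokenizer.
from datasets import load_dataset
from transformers import AutoTokenizer

MIN_LENGTH = 2400   # --min_length: skip documents shorter than this (in tokens)
MAX_LENGTH = 20000  # --max_length: cap every training sequence at 20K tokens

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
dataset = load_dataset("json", data_files="redpajama/train.json", split="train")

def prepare(example):
    # One document per training sequence: no packing of multiple documents.
    input_ids = tokenizer(example["text"])["input_ids"][: MAX_LENGTH - 1]
    input_ids.append(tokenizer.eos_token_id)  # explicit EOS at the end of each document
    return {"input_ids": input_ids, "labels": input_ids.copy(), "length": len(input_ids)}

dataset = dataset.map(prepare, remove_columns=dataset.column_names)
dataset = dataset.filter(lambda x: x["length"] >= MIN_LENGTH)
```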

## Prerequisite

Make sure you have created the environment and downloaded the data according to the README. The commands below also reference a `$DDP` variable; it is assumed to hold any extra `torchrun` distributed arguments (for example `--master_port`, or multi-node settings) and can be left empty for a plain single-node run.

## Mistral

### Pretrain

```bash
output_name=beacon-mistral-pretrain

torchrun --nproc_per_node 8 $DDP -m main.train \
--output_dir data/outputs/$output_name \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--train_data long-llm:redpajama/train.json \
--min_length 2400 \
--max_length 20000 \
--group_by_stride strict \
--enable_beacon \
--beacon_window 2048 \
--beacon_stride 2048 \
--beacon_attn full-coverage \
--beacon_attend_prev True \
--beacon_sink_size 0 \
--beacon_ratio 2 4 8 16 32 \
--beacon_ratio_mix step-random \
--beacon_param q k v \
--beacon_pos interleave \
--attn_impl flash_attention_2 \
--gradient_checkpointing \
--use_reentrant False \
--save_only_model \
--save_strategy epoch \
--evaluation_strategy steps \
--num_train_epochs 1 \
--logging_steps 50 \
--bf16 \
--deepspeed data/deepspeed/stage2.json
```

### Finetune

```bash
output_name=beacon-mistral-finetune

torchrun --nproc_per_node 8 $DDP -m main.train \
--output_dir data/outputs/$output_name \
--model_name_or_path data/outputs/beacon-mistral-pretrain/* \
--train_data long-llm:gpt/one_detail_book.train.16K.json long-llm:gpt/one_detail_paper.train.16K.json long-llm:longalpaca/train.json long-llm:booksum/train.16K.json long-llm:needle/train.16K.json long-llm:redpajama/train.json[5000] \
--max_length 20000 \
--min_length 7200 \
--group_by_stride strict \
--enable_beacon \
--beacon_window 2048 \
--beacon_stride 2048 \
--beacon_attn full-coverage \
--beacon_attend_prev True \
--beacon_sink_size 0 \
--beacon_ratio 2 4 8 \
--beacon_ratio_mix step-random \
--beacon_param q k v \
--beacon_pos interleave \
--attn_impl flash_attention_2 \
--learning_rate 1e-5 \
--gradient_checkpointing \
--use_reentrant False \
--save_only_model \
--num_train_epochs 1 \
--save_strategy epoch \
--logging_steps 50 \
--bf16 \
--deepspeed data/deepspeed/stage2.json \
--chat_template mistral
```

## Llama-3

NOTE: according to our experiments, Llama-3 requires an attention sink, hence `--beacon_sink_size 1` in the commands below.

### Pretrain

```bash
output_name=beacon-llama3-pretrain

torchrun --nproc_per_node 8 $DDP -m main.train \
--output_dir data/outputs/$output_name \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--train_data long-llm:redpajama/train.json \
--min_length 2400 \
--max_length 20000 \
--group_by_stride strict \
--enable_beacon \
--beacon_window 1024 \
--beacon_stride 1024 \
--beacon_attn full-coverage \
--beacon_attend_prev True \
--beacon_sink_size 1 \
--beacon_ratio 2 4 8 16 32 \
--beacon_ratio_mix step-random \
--beacon_param q k v \
--beacon_pos interleave \
--attn_impl flash_attention_2 \
--gradient_checkpointing \
--use_reentrant False \
--save_only_model \
--save_strategy epoch \
--evaluation_strategy steps \
--num_train_epochs 1 \
--logging_steps 50 \
--bf16 \
--deepspeed data/deepspeed/stage2.json
```

### Finetune

```bash
output_name=beacon-llama3-finetune

torchrun --nproc_per_node 8 $DDP -m main.train \
--output_dir data/outputs/$output_name \
--model_name_or_path data/outputs/beacon-llama3-pretrain/* \
--train_data long-llm:gpt/one_detail_book.train.16K.json long-llm:gpt/one_detail_paper.train.16K.json long-llm:longalpaca/train.json long-llm:booksum/train.16K.json long-llm:needle/train.16K.json long-llm:redpajama/train.json[5000] \
--max_length 20000 \
--min_length 7200 \
--group_by_stride strict \
--enable_beacon \
--beacon_window 1024 \
--beacon_stride 1024 \
--beacon_attn full-coverage \
--beacon_attend_prev True \
--beacon_sink_size 1 \
--beacon_ratio 2 4 8 \
--beacon_ratio_mix step-random \
--beacon_param q k v \
--beacon_pos interleave \
--attn_impl flash_attention_2 \
--learning_rate 1e-5 \
--gradient_checkpointing \
--use_reentrant False \
--save_only_model \
--num_train_epochs 1 \
--save_strategy epoch \
--logging_steps 50 \
--bf16 \
--deepspeed data/deepspeed/stage2.json \
--chat_template llama-3
```

## Qwen-2

### Pretrain

```bash
output_name=beacon-qwen2-pretrain

torchrun --nproc_per_node 8 $DDP -m main.train \
--output_dir data/outputs/$output_name \
--model_name_or_path Qwen/Qwen2-7B-Instruct \
--train_data long-llm:redpajama/train.json \
--min_length 2400 \
--max_length 20000 \
--group_by_stride strict \
--enable_beacon \
--beacon_window 2048 \
--beacon_stride 2048 \
--beacon_attn full-coverage \
--beacon_attend_prev True \
--beacon_sink_size 0 \
--beacon_ratio 2 4 8 16 32 \
--beacon_ratio_mix step-random \
--beacon_param q k v \
--beacon_pos interleave \
--attn_impl flash_attention_2 \
--gradient_checkpointing \
--use_reentrant False \
--save_only_model \
--save_strategy epoch \
--evaluation_strategy steps \
--num_train_epochs 1 \
--logging_steps 50 \
--bf16 \
--deepspeed data/deepspeed/stage2.json
```

### Finetune

```bash
output_name=beacon-qwen2-finetune

torchrun --nproc_per_node 8 $DDP -m main.train \
--output_dir data/outputs/$output_name \
--model_name_or_path data/outputs/beacon-qwen2-pretrain/* \
--train_data long-llm:gpt/one_detail_book.train.16K.json long-llm:gpt/one_detail_paper.train.16K.json long-llm:longalpaca/train.json long-llm:booksum/train.16K.json long-llm:needle/train.16K.json long-llm:redpajama/train.json[5000] \
--max_length 20000 \
--min_length 7200 \
--group_by_stride strict \
--enable_beacon \
--beacon_window 2048 \
--beacon_stride 2048 \
--beacon_attn full-coverage \
--beacon_attend_prev True \
--beacon_sink_size 0 \
--beacon_ratio 2 4 8 \
--beacon_ratio_mix step-random \
--beacon_param q k v \
--beacon_pos interleave \
--attn_impl flash_attention_2 \
--learning_rate 1e-5 \
--gradient_checkpointing \
--use_reentrant False \
--save_only_model \
--num_train_epochs 1 \
--save_strategy epoch \
--logging_steps 50 \
--bf16 \
--deepspeed data/deepspeed/stage2.json \
--chat_template qwen
```
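
Once finetuning has finished, a quick generation sanity check can be run on the saved checkpoint. The snippet below is a hedged sketch, not part of the repository: it assumes the output directory can be loaded directly through `transformers` with `trust_remote_code=True` (as the released activation-beacon checkpoints on Hugging Face can); the checkpoint path and the prompt are placeholders.

```python
# Hedged sanity-check sketch; the checkpoint path below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "data/outputs/beacon-qwen2-finetune/checkpoint-xxx"  # replace with your actual checkpoint

tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
).cuda().eval()

messages = [{"role": "user", "content": "Summarize the following text: ..."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).cuda()

with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))
```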