embed-bge-m3/FlagEmbedding/research/Long_LLM/longllm_qlora/main/train.py

import torch
import logging
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from src import (
    Data,
    DefaultDataCollator,
    ModelArgs,
    FileLogger,
    get_model_and_tokenizer,
    makedirs,
    format_numel_str,
)
from src.args import TrainingArgs
from src.metrics import Metric
from src.trainer import LLMTrainer
logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser([ModelArgs, TrainingArgs])
    model_args, training_args = parser.parse_args_into_dataclasses()

    tokenizer = get_model_and_tokenizer(model_args, return_tokenizer_only=True, evaluation_mode=False)
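
    # Only the tokenizer comes from get_model_and_tokenizer; the model itself is
    # loaded below through unsloth for memory-efficient long-context training.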
    # NOTE: must be imported here, otherwise an invalid device error is raised
    from unsloth import FastLanguageModel

    if model_args.load_in_4_bit:
        device_map = None
    else:
        device_map = {"": "cuda"}
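
    # Load the base model through unsloth, optionally in 4-bit (QLoRA), with the configured RoPE theta.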
    model, _ = FastLanguageModel.from_pretrained(
        model_name=model_args.model_name_or_path,
        max_seq_length=model_args.max_length,
        dtype=torch.bfloat16,
        device_map=device_map,
        load_in_4bit=model_args.load_in_4_bit,
        # token="hf_...",  # required for gated models such as meta-llama/Llama-2-7b-hf
        token=model_args.access_token,
        cache_dir=model_args.model_cache_dir,
        rope_theta=model_args.rope_theta,
    )
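
    # Attach LoRA adapters via unsloth's PEFT integration.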
    model = FastLanguageModel.get_peft_model(
        model,
        r=training_args.lora_rank,  # any number > 0; suggested values are 8, 16, 32, 64, 128
        target_modules=training_args.lora_targets,
        modules_to_save=training_args.lora_extra_params,
        lora_dropout=0,  # any value is supported, but 0 is optimized
        bias="none",  # any value is supported, but "none" is optimized
        # "unsloth" gradient checkpointing uses 30% less VRAM and fits 2x larger batch sizes
        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
        random_state=3407,
        use_rslora=False,  # rank-stabilized LoRA is supported
        loftq_config=None,  # LoftQ is supported as well
    )

    print(model.config)
    logger.info(f"Trainable model parameters: {format_numel_str(sum(p.numel() for p in model.parameters() if p.requires_grad))}")
    with training_args.main_process_first():
        train_dataset = Data.prepare_train_data(
            model_args.train_data,
            tokenizer=tokenizer,
            max_length=model_args.max_length,
            min_length=training_args.min_length,
            chat_template=model_args.chat_template,
            seed=training_args.seed,
            cache_dir=model_args.dataset_cache_dir,
        )
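
    # Prepare the evaluation data the same way; eval_data may be left unset.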
    with training_args.main_process_first():
        if is_deepspeed_zero3_enabled() and training_args.eval_method != "perplexity":
            logger.warning("In DeepSpeed ZeRO-3, evaluation with generation may hang because the number of forward passes can differ across devices.")
        eval_dataset = Data.prepare_eval_data(
            model_args.eval_data,
            tokenizer=tokenizer,
            max_length=training_args.eval_max_length,
            min_length=training_args.eval_min_length,
            chat_template=model_args.chat_template,
            seed=training_args.seed,
            cache_dir=model_args.dataset_cache_dir,
        )
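
    # Assemble the trainer; metric outputs are only saved when eval_data is provided.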
    trainer = LLMTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        model_args=model_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DefaultDataCollator(tokenizer),
        file_logger=FileLogger(makedirs(training_args.log_path)),
        compute_metrics=Metric.get_metric_fn(
            metrics=training_args.metrics,
            save_path=Metric.get_save_path(
                model_args.eval_data,
                training_args.output_dir,
            ) if model_args.eval_data is not None else None,
        ),
    )
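
    # Train when a training set is available; otherwise run evaluation only.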
    if train_dataset is not None:
        trainer.train()
    elif eval_dataset is not None:
        trainer.evaluate()


if __name__ == "__main__":
    main()