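"""Training entry point: parse model/training arguments, load the model and
tokenizer, optionally restrict training to beacon parameters or apply LoRA,
prepare the train/eval datasets, and run the ActivationBeaconTrainer."""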
import logging

from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled

from src import (
    Data,
    DefaultDataCollator,
    ModelArgs,
    FileLogger,
    get_model_and_tokenizer,
    makedirs,
    format_numel_str
)
from src.args import TrainingArgs
from src.metrics import Metric
from src.trainer import ActivationBeaconTrainer

logger = logging.getLogger(__name__)


def main():
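    # parse the ModelArgs and TrainingArgs dataclasses from the command line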
    parser = HfArgumentParser([ModelArgs, TrainingArgs])
    model_args, training_args = parser.parse_args_into_dataclasses()
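
    # load the model and tokenizer on GPU in training mode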
    model, tokenizer = get_model_and_tokenizer(model_args, device="cuda", evaluation_mode=False)
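
    # when only the beacon parameters are trained, freeze every parameter whose name does not contain "beacon"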
    if model_args.enable_beacon and training_args.only_train_beacon:
        for name, param in model.named_parameters():
            if "beacon" not in name:
                param.requires_grad_(False)
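
    # optionally wrap the model with LoRA adapters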
    if training_args.lora_tune:
        from peft import (
            LoraConfig,
            get_peft_model,
        )
        # copied from LongLoRA
        config = LoraConfig(
            r=training_args.lora_rank,
            lora_alpha=training_args.lora_alpha,
            target_modules=training_args.lora_targets,
            modules_to_save=training_args.lora_extra_params,
            lora_dropout=training_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, config)
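
    # report the number of trainable parameters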
    logger.info(f"Trainable Model params: {format_numel_str(sum(p.numel() for p in model.parameters() if p.requires_grad))}")
    with training_args.main_process_first():
        train_dataset = Data.prepare_train_data(
            model_args.train_data,
            tokenizer=tokenizer,
            max_length=model_args.max_length,
            min_length=training_args.min_length,
            chat_template=model_args.chat_template,
            seed=training_args.seed,
            cache_dir=model_args.dataset_cache_dir,
        )
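
    # prepare the evaluation data, also on the main process first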
    with training_args.main_process_first():
        if is_deepspeed_zero3_enabled() and training_args.eval_method != "perplexity":
            logger.warning("In DeepSpeed ZeRO-3, evaluation with generation may hang because the number of forward passes can differ across devices.")
        eval_dataset = Data.prepare_eval_data(
            model_args.eval_data,
            tokenizer=tokenizer,
            max_length=training_args.eval_max_length,
            min_length=training_args.eval_min_length,
            chat_template=model_args.chat_template,
            seed=training_args.seed,
            cache_dir=model_args.dataset_cache_dir,
        )
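
    # build the trainer; metric results are only saved when eval data is provided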
    trainer = ActivationBeaconTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        model_args=model_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DefaultDataCollator(tokenizer),
        file_logger=FileLogger(makedirs(training_args.log_path)),
        compute_metrics=Metric.get_metric_fn(
            metrics=training_args.metrics,
            save_path=Metric.get_save_path(
                model_args.eval_data,
                training_args.output_dir
            ) if model_args.eval_data is not None else None
        )
    )
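
    # train when training data is available; otherwise just evaluate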
    if train_dataset is not None:
        trainer.train()
    elif eval_dataset is not None:
        trainer.evaluate()


if __name__ == "__main__":
    main()