diff --git a/train_sft_ds.py b/train_sft_ds.py
index 892eb66..6fdeb92 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import os
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 import glob
 import socket
 import argparse
@@ -158,35 +159,36 @@ class QwenChatSFTDataset(IterableDataset):
 
 # ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
-    def __init__(self, tokenizer: AutoTokenizer):
+    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
         self.tok = tokenizer
-        assert self.tok.pad_token_id is not None, "tokenizer.pad_token must not be None; main() falls back to eos_token"
+        self.pad_to_length = pad_to_length
+        assert self.tok.pad_token_id is not None
 
-    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        # Pad variable-length samples to the longest length in the batch; pad labels with -100
-        def _to_list(x):
-            return x.tolist() if isinstance(x, torch.Tensor) else list(x)
-        input_ids = [ _to_list(f["input_ids"]) for f in features ]
-        attn_masks = [ _to_list(f["attention_mask"]) for f in features ]
-        labels_list = [ _to_list(f["labels"]) for f in features ]
+    def __call__(self, features):
+        def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
+        input_ids = [_to_list(f["input_ids"]) for f in features]
+        attn_masks = [_to_list(f["attention_mask"]) for f in features]
+        labels_list = [_to_list(f["labels"]) for f in features]
+
+        max_len_in_batch = max(len(x) for x in input_ids)
+        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
 
-        max_len = max(len(x) for x in input_ids)
         pad_id = self.tok.pad_token_id
-
         batch_inp, batch_attn, batch_lab = [], [], []
         for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
-            pad_len = max_len - len(inp)
+            pad_len = target_len - len(inp)
+            if pad_len < 0:
+                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
+                pad_len = 0
             batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
-
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
             "labels": torch.stack(batch_lab, dim=0),
         }
 
-
 # ----------------- Arguments -----------------
 def parse_args():
     ap = argparse.ArgumentParser()
@@ -227,18 +229,27 @@ def main():
     args = parse_args()
     set_seed(args.seed)
 
-    # Tokenizer/Model
+    # 1) Set the tokenizer's pad token first
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+    if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token  # used for padding
+
+    # Optional: reduces warnings
+    tokenizer.model_max_length = args.seq_len
 
+    # 2) Then load the model
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
         torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )
+
+    # 3) Finally, align the model's pad_token_id
+    model.config.pad_token_id = tokenizer.pad_token_id
     model.config.use_cache = False  # disable cache during training
+
     if args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
@@ -336,7 +347,7 @@ def main():
     if len(eval_samples) > 0:
         eval_dataset = ListDataset(eval_samples)
 
-    data_collator = SFTDataCollator(tokenizer)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
@@ -376,6 +387,8 @@ def main():
         deepspeed=args.deepspeed,
         dataloader_drop_last=True,
         dataloader_num_workers=0,
+        dataloader_prefetch_factor=None,
+        dataloader_pin_memory=False,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
@@ -392,7 +405,7 @@ def main():
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         data_collator=data_collator
     )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
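
A minimal sketch (not part of the patch) of how the fixed-length collator is expected to behave, assuming SFTDataCollator is importable from train_sft_ds; the stub object below stands in for a real tokenizer, and pad_token_id=0 is only an illustrative value:

    import torch
    from types import SimpleNamespace
    from train_sft_ds import SFTDataCollator

    # Stub in place of the tokenizer; the collator only reads pad_token_id.
    tok = SimpleNamespace(pad_token_id=0)
    collator = SFTDataCollator(tok, pad_to_length=8)

    features = [
        {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [-100, 6, 7]},
        {"input_ids": list(range(1, 11)), "attention_mask": [1] * 10, "labels": list(range(1, 11))},
    ]
    batch = collator(features)

    # Every tensor comes out at exactly pad_to_length: the short sample is padded
    # (input_ids with pad_token_id, attention_mask with 0, labels with -100) and
    # the long sample is truncated to the first 8 tokens.
    assert batch["input_ids"].shape == (2, 8)
    assert batch["labels"][0, -1].item() == -100

Padding every batch to args.seq_len keeps tensor shapes constant across steps, which is presumably the motivation for pad_to_length here: static shapes avoid shape-dependent recompilation and allocator churn when training under DeepSpeed.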