diff --git a/train_sft_ds.py b/train_sft_ds.py
index c1a4500..297d02f 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -232,11 +232,7 @@ class SFTDataCollator:
         self.pad_to_length = pad_to_length
         assert self.tok.pad_token_id is not None
 
-    def __call__(self, features):
-        # if not features:
-        #     raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
-        #                        f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
-
+    def __call__(self, features):
         if not features:
             raise RuntimeError(
                 f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
@@ -274,14 +270,6 @@ class SFTDataCollator:
                 f"target_len={target_len} first_len={first_len}", flush=True
             )
 
-        # Extra strict check: stop an empty batch from going any further
-        # if not features:
-        #     raise RuntimeError(
-        #         f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
-        #         f"Check dataset sharding/streaming."
-        #     )
-        # >>> DEBUG END
-
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
@@ -322,6 +310,7 @@ def parse_args():
                     help="for deepspeed/torchrun launcher; ignored by user code")
     ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     ap.add_argument("--deepspeed", type=str, default=None)
+
     return ap.parse_args()
 
 
@@ -451,31 +440,39 @@ def main():
             "Also check whether seq_len is too small, causing every sample to be truncated away."
         )
 
-    # ====== Actual training stream ======
-    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    if world_size > 1 and len(files) >= world_size:
-        # Multiple files: shard contiguously by file
-        ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
-        train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
-    else:
-        # Single file or too few files: round-robin samples by index modulo world size
-        def ex_iter2():
-            for i, ex in enumerate(ds_stream2):
-                if i % max(world_size, 1) == rank:
-                    yield ex
-        train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+    # # ====== Actual training stream ======
+    # ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    # if world_size > 1 and len(files) >= world_size:
+    #     # Multiple files: shard contiguously by file
+    #     ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
+    #     train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+    # else:
+    #     # Single file or too few files: round-robin samples by index modulo world size
+    #     def ex_iter2():
+    #         for i, ex in enumerate(ds_stream2):
+    #             if i % max(world_size, 1) == rank:
+    #                 yield ex
+    #     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    # ====== Consistency probe (same logic as above) ======
+    # ====== Actual training stream (no manual sharding; left to Accelerate/Trainer) ======
+    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+
+    # # ====== Consistency probe (same logic as above) ======
+    # ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    # if world_size > 1 and len(files) >= world_size:
+    #     ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
+    #     probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
+    # else:
+    #     def ex_iter2_probe():
+    #         for i, ex in enumerate(ds_stream_probe2):
+    #             if i % max(world_size, 1) == rank:
+    #                 yield ex
+    #     probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+
+    # ====== Consistency probe (un-sharded) ======
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    if world_size > 1 and len(files) >= world_size:
-        ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
-        probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
-    else:
-        def ex_iter2_probe():
-            for i, ex in enumerate(ds_stream_probe2):
-                if i % max(world_size, 1) == rank:
-                    yield ex
-        probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
 
     def has_at_least(stream, n: int):
         it = iter(stream)
@@ -511,22 +508,6 @@ def main():
             )
             sys.exit(2)
 
-    # if dist.is_available() and dist.is_initialized():
-    #     # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
-    #     t = torch.tensor(
-    #         local_ok,
-    #         device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
-    #     )
-    #     dist.all_reduce(t, op=dist.ReduceOp.MIN)
-    #     if t.item() == 0:
-    #         if is_main_process():
-    #             print("[FATAL] At least one rank has no samples. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
-    #         dist.barrier()
-    #         sys.exit(2)
-    # else:
-    #     if local_ok == 0:
-    #         print("[FATAL] No samples on this node, exiting.", flush=True); sys.exit(2)
-
     # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling otherwise ----
     eval_dataset: Optional[Dataset] = None
 
@@ -595,7 +576,6 @@ def main():
         f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
 
     # More robust: don't force-pad to 4096 while still debugging the pipeline
-    # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
@@ -661,14 +641,6 @@ def main():
         data_collator=data_collator
     )
 
-    # trainer = Trainer(
-    #     model=model,
-    #     args=training_args,
-    #     train_dataset=train_stream,
-    #     eval_dataset=eval_dataset,
-    #     processing_class=tokenizer,
-    #     data_collator=data_collator
-    # )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
 
     # No shared disk: check each node's local output_dir for existing checkpoint-*
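Note on the sharding change in this patch: the training and probe streams are now passed to the Trainer un-sharded, and batch distribution across ranks is left to Accelerate/Trainer. Combining a manual `.shard()` / index-modulo filter with the Trainer's own handling of a streaming dataset can drop or duplicate samples per rank, which is one way to hit the empty-batch error the collator now raises. The sketch below is not part of the patch; it only illustrates the two mutually exclusive options, assuming datasets>=2.8, RANK/WORLD_SIZE set by the launcher, and a placeholder file name train.jsonl standing in for the script's `files` glob.

    import os
    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node

    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    files = ["train.jsonl"]  # placeholder path, not the real glob from the script

    # Streaming JSON dataset, same call as in the patch.
    stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

    # Option A (what the patch does): hand `stream` to Trainer un-sharded and let
    # Accelerate/Trainer split work across ranks; do NOT also shard it here.

    # Option B (manual per-rank split, only if you bypass the Trainer's handling):
    # shard in exactly one place, e.g. with the datasets helper below.
    if world_size > 1:
        stream = split_dataset_by_node(stream, rank=rank, world_size=world_size)

If the probe still reports a rank with zero samples after this change, the next thing to verify is the condition already named in the collator's error message: eval size must be at least world_size * per_device_eval_batch_size.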