diff --git a/train_sft_ds.py b/train_sft_ds.py
index 1b0022d..4962edc 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -237,6 +237,12 @@ class SFTDataCollator:
         # raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
         #                    f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
 
+        if not features:
+            raise RuntimeError(
+                f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+                f"Check dataset sharding/streaming."
+            )
+
         def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids  = [_to_list(f["input_ids"]) for f in features]
         attn_masks = [_to_list(f["attention_mask"]) for f in features]
@@ -269,11 +275,11 @@ class SFTDataCollator:
                 flush=True
             )
             # Extra strict check: stop an empty batch from going any further
-        if not features:
-            raise RuntimeError(
-                f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
-                f"Check dataset sharding/streaming."
-            )
+        # if not features:
+        #     raise RuntimeError(
+        #         f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+        #         f"Check dataset sharding/streaming."
+        #     )
         # >>> DEBUG END
 
         return {
@@ -459,13 +465,6 @@ def main():
                 yield ex
         train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    def has_one_sample(stream):
-        it = iter(stream)
-        try:
-            next(it); return 1
-        except StopIteration:
-            return 0
-
     # ====== Consistency probe (same logic as above) ======
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
     if world_size > 1 and len(files) >= world_size:
@@ -478,24 +477,55 @@ def main():
                 yield ex
         probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
 
+    def has_at_least(stream, n: int):
+        it = iter(stream)
+        for _ in range(n):
+            try:
+                next(it)
+            except StopIteration:
+                return 0
+        return 1
+
-    local_ok = has_one_sample(probe_stream)
+    need = max(1, args.gradient_accumulation_steps)
+
+    local_ok = has_at_least(probe_stream, need)
     if dist.is_available() and dist.is_initialized():
-        # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
-        t = torch.tensor(
-            local_ok,
-            device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
-        )
+        t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
         dist.all_reduce(t, op=dist.ReduceOp.MIN)
         if t.item() == 0:
             if is_main_process():
-                print("[FATAL] At least one rank has no samples. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
+                print(
+                    f"[FATAL] At least one rank cannot supply {need} micro-batches within one optimizer step (GA={need}). "
+                    f"Reduce GA or enlarge/clean the data; this run will not start.",
+                    flush=True
+                )
             dist.barrier()
             sys.exit(2)
     else:
         if local_ok == 0:
-            print("[FATAL] No samples on this machine; exiting.", flush=True); sys.exit(2)
+            print(
+                f"[FATAL] This machine cannot supply {need} micro-batches within one optimizer step (GA={need}). "
+                f"Reduce GA or enlarge/clean the data; this run will not start.",
+                flush=True
+            )
+            sys.exit(2)
+
+    # if dist.is_available() and dist.is_initialized():
+    #     # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+    #     t = torch.tensor(
+    #         local_ok,
+    #         device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
+    #     )
+    #     dist.all_reduce(t, op=dist.ReduceOp.MIN)
+    #     if t.item() == 0:
+    #         if is_main_process():
+    #             print("[FATAL] At least one rank has no samples. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
+    #         dist.barrier()
+    #         sys.exit(2)
+    # else:
+    #     if local_ok == 0:
+    #         print("[FATAL] No samples on this machine; exiting.", flush=True); sys.exit(2)
 
     # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
     eval_dataset: Optional[Dataset] = None
@@ -552,9 +582,9 @@ def main():
         r = len(eval_dataset) % global_bs
         if r != 0:
-            need = global_bs - r
-            # your eval_dataset is the custom ListDataset defined above, with .items
-            eval_dataset.items += eval_dataset.items[:need]
+            pad_need = global_bs - r
+            eval_dataset.items += eval_dataset.items[:pad_need]
+
             if is_main_process():
                 print(f"[eval] padded eval set to {len(eval_dataset)} "
                       f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
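
Note on the first two hunks: the empty-batch guard is hoisted above the first use of features, so a starved rank fails with a rank-tagged error instead of an opaque IndexError deep in padding, and the late duplicate check is commented out. A minimal sketch of that ordering, assuming an illustrative collator rather than the one in train_sft_ds.py:

import os
from typing import Any, Dict, List

import torch


class GuardedCollator:
    """Toy collator: validate the batch before touching its contents."""

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Guard first: an empty list must fail loudly, tagged with the rank,
        # before any features[0] / padding logic can raise something opaque.
        if not features:
            raise RuntimeError(
                f"[FATAL][RANK={os.environ.get('RANK', '?')}] Empty batch reached collator."
            )
        ids = [torch.as_tensor(f["input_ids"]) for f in features]
        return {"input_ids": torch.nn.utils.rnn.pad_sequence(ids, batch_first=True)}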
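On the probe hunks: each rank reports 1 only if its stream can yield `need` micro-batches (one optimizer step's worth under gradient accumulation), and all_reduce with ReduceOp.MIN leaves 0 on every rank iff any rank reported 0, so all ranks agree to abort together instead of one of them hanging in a later collective. A self-contained sketch of that consensus pattern, with a single-process fallback; everything outside has_at_least (the toy shard, the SHARD_LEN variable) is illustrative:

import os
import sys

import torch
import torch.distributed as dist


def has_at_least(stream, n: int) -> int:
    # Consume up to n items; 1 means this rank can feed n micro-batches.
    it = iter(stream)
    for _ in range(n):
        try:
            next(it)
        except StopIteration:
            return 0
    return 1


def all_ranks_ok(local_ok: int) -> bool:
    # MIN-reduce: the result is 0 on every rank iff any rank reported 0,
    # so every rank takes the same exit path.
    if dist.is_available() and dist.is_initialized():
        t = torch.tensor(local_ok)
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        return bool(t.item())
    return bool(local_ok)


if __name__ == "__main__":
    # Toy run: a range stands in for this rank's data shard.
    shard = range(int(os.environ.get("SHARD_LEN", "4")))
    if not all_ranks_ok(has_at_least(shard, n=2)):
        sys.exit(2)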
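The last hunk keeps distributed evaluation from producing empty per-rank batches: the eval set is grown to the next multiple of global_bs = world_size * per_device_eval_batch_size by re-appending its leading items. A sketch of just that arithmetic under assumed names (ListDataset with an .items list mirrors the patch; the while loop generalizes the patch's single pass so eval sets smaller than one global batch also come out even):

# Worked instance: 103 samples, world_size=4, per_device_eval_batch_size=8
# -> global_bs=32, 103 % 32 = 7, pad_need=25, padded length 128.
from typing import Any, List


class ListDataset:
    """Stand-in for the list-backed eval dataset with an .items field."""

    def __init__(self, items: List[Any]):
        self.items = items

    def __len__(self) -> int:
        return len(self.items)


def pad_to_global_bs(ds: ListDataset, world_size: int, per_device_eval_bs: int) -> ListDataset:
    if len(ds) == 0:
        raise ValueError("cannot pad an empty eval set")
    global_bs = world_size * per_device_eval_bs
    while len(ds) % global_bs != 0:
        pad_need = global_bs - len(ds) % global_bs
        # Duplicating leading items keeps every rank's batch full; the
        # duplicates slightly bias eval metrics, which is the trade-off.
        ds.items += ds.items[:pad_need]
    return ds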