This commit is contained in:
hailin 2025-08-26 18:24:06 +08:00
parent c9a00b6208
commit 7706fcf842
1 changed files with 53 additions and 23 deletions

View File

@ -237,6 +237,12 @@ class SFTDataCollator:
# raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
# f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
if not features:
raise RuntimeError(
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
f"Check dataset sharding/streaming."
)
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
@ -269,11 +275,11 @@ class SFTDataCollator:
flush=True
)
# 额外严苛校验:防止空 batch 继续往下走
if not features:
raise RuntimeError(
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
f"Check dataset sharding/streaming."
)
# if not features:
# raise RuntimeError(
# f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
# f"Check dataset sharding/streaming."
# )
# >>> DEBUG END
return {
@ -459,13 +465,6 @@ def main():
yield ex
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
def has_one_sample(stream):
it = iter(stream)
try:
next(it); return 1
except StopIteration:
return 0
# ====== 一致性探针(与上面保持同逻辑)=====
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
if world_size > 1 and len(files) >= world_size:
@ -478,24 +477,55 @@ def main():
yield ex
probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
def has_at_least(stream, n: int):
it = iter(stream)
for _ in range(n):
try:
next(it)
except StopIteration:
return 0
return 1
local_ok = has_one_sample(probe_stream)
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
# t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
t = torch.tensor(
local_ok,
device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
)
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
print(
f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
dist.barrier()
sys.exit(2)
else:
if local_ok == 0:
print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
print(
f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
sys.exit(2)
# if dist.is_available() and dist.is_initialized():
# # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
# t = torch.tensor(
# local_ok,
# device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
# )
# dist.all_reduce(t, op=dist.ReduceOp.MIN)
# if t.item() == 0:
# if is_main_process():
# print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
# dist.barrier()
# sys.exit(2)
# else:
# if local_ok == 0:
# print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
# ---- Eval 构造:优先使用 --eval_data_glob否则才用 eval_ratio 抽样 ----
eval_dataset: Optional[Dataset] = None
@ -552,9 +582,9 @@ def main():
r = len(eval_dataset) % global_bs
if r != 0:
need = global_bs - r
# 你的 eval_dataset 是上面自定义的 ListDataset带 .items
eval_dataset.items += eval_dataset.items[:need]
pad_need = global_bs - r
eval_dataset.items += eval_dataset.items[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)} "
f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",