This commit is contained in:
parent
c9a00b6208
commit
7706fcf842
|
|
@ -237,6 +237,12 @@ class SFTDataCollator:
|
|||
# raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
|
||||
# f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
|
||||
|
||||
if not features:
|
||||
raise RuntimeError(
|
||||
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
|
||||
f"Check dataset sharding/streaming."
|
||||
)
|
||||
|
||||
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
|
||||
input_ids = [_to_list(f["input_ids"]) for f in features]
|
||||
attn_masks = [_to_list(f["attention_mask"]) for f in features]
|
||||
|
|
@ -269,11 +275,11 @@ class SFTDataCollator:
|
|||
flush=True
|
||||
)
|
||||
# 额外严苛校验:防止空 batch 继续往下走
|
||||
if not features:
|
||||
raise RuntimeError(
|
||||
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
|
||||
f"Check dataset sharding/streaming."
|
||||
)
|
||||
# if not features:
|
||||
# raise RuntimeError(
|
||||
# f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
|
||||
# f"Check dataset sharding/streaming."
|
||||
# )
|
||||
# >>> DEBUG END
|
||||
|
||||
return {
|
||||
|
|
@ -459,13 +465,6 @@ def main():
|
|||
yield ex
|
||||
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
|
||||
|
||||
def has_one_sample(stream):
|
||||
it = iter(stream)
|
||||
try:
|
||||
next(it); return 1
|
||||
except StopIteration:
|
||||
return 0
|
||||
|
||||
# ====== 一致性探针(与上面保持同逻辑)=====
|
||||
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
||||
if world_size > 1 and len(files) >= world_size:
|
||||
|
|
@ -478,24 +477,55 @@ def main():
|
|||
yield ex
|
||||
probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
|
||||
|
||||
def has_at_least(stream, n: int):
|
||||
it = iter(stream)
|
||||
for _ in range(n):
|
||||
try:
|
||||
next(it)
|
||||
except StopIteration:
|
||||
return 0
|
||||
return 1
|
||||
|
||||
local_ok = has_one_sample(probe_stream)
|
||||
need = max(1, args.gradient_accumulation_steps)
|
||||
|
||||
local_ok = has_at_least(probe_stream, need)
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
# t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
t = torch.tensor(
|
||||
local_ok,
|
||||
device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
|
||||
)
|
||||
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
|
||||
dist.all_reduce(t, op=dist.ReduceOp.MIN)
|
||||
if t.item() == 0:
|
||||
if is_main_process():
|
||||
print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
|
||||
print(
|
||||
f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
|
||||
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
|
||||
flush=True
|
||||
)
|
||||
dist.barrier()
|
||||
sys.exit(2)
|
||||
else:
|
||||
if local_ok == 0:
|
||||
print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
|
||||
print(
|
||||
f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
|
||||
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
|
||||
flush=True
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
# if dist.is_available() and dist.is_initialized():
|
||||
# # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
# t = torch.tensor(
|
||||
# local_ok,
|
||||
# device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
|
||||
# )
|
||||
# dist.all_reduce(t, op=dist.ReduceOp.MIN)
|
||||
# if t.item() == 0:
|
||||
# if is_main_process():
|
||||
# print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
|
||||
# dist.barrier()
|
||||
# sys.exit(2)
|
||||
# else:
|
||||
# if local_ok == 0:
|
||||
# print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
|
||||
|
||||
# ---- Eval 构造:优先使用 --eval_data_glob;否则才用 eval_ratio 抽样 ----
|
||||
eval_dataset: Optional[Dataset] = None
|
||||
|
|
@ -552,9 +582,9 @@ def main():
|
|||
|
||||
r = len(eval_dataset) % global_bs
|
||||
if r != 0:
|
||||
need = global_bs - r
|
||||
# 你的 eval_dataset 是上面自定义的 ListDataset,带 .items
|
||||
eval_dataset.items += eval_dataset.items[:need]
|
||||
pad_need = global_bs - r
|
||||
eval_dataset.items += eval_dataset.items[:pad_need]
|
||||
|
||||
if is_main_process():
|
||||
print(f"[eval] padded eval set to {len(eval_dataset)} "
|
||||
f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
|
||||
|
|
|
|||
Loading…
Reference in New Issue