commit 7706fcf842
parent c9a00b6208
@@ -237,6 +237,12 @@ class SFTDataCollator:
         # raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
         #                    f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
 
+        if not features:
+            raise RuntimeError(
+                f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+                f"Check dataset sharding/streaming."
+            )
+
         def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids = [_to_list(f["input_ids"]) for f in features]
         attn_masks = [_to_list(f["attention_mask"]) for f in features]
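
Note: the guard lives in the collator because, with sharded or streaming eval data, a rank can legitimately end up with zero samples, and the empty feature list only surfaces once the DataLoader invokes the collator. A minimal sketch of that failure mode under round-robin sharding (the helper below is illustrative, not part of the patch):

    def per_rank_eval_counts(n_eval: int, world_size: int) -> list[int]:
        # Round-robin sharding: rank r sees indices r, r + world_size, r + 2 * world_size, ...
        return [len(range(r, n_eval, world_size)) for r in range(world_size)]

    # 3 eval samples across 8 ranks: five ranks receive nothing, so their
    # DataLoader hands the collator an empty list -> the new RuntimeError fires.
    print(per_rank_eval_counts(3, 8))  # [1, 1, 1, 0, 0, 0, 0, 0]
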
@@ -269,11 +275,11 @@ class SFTDataCollator:
                 flush=True
             )
             # Extra strict check: keep an empty batch from slipping any further
-            if not features:
-                raise RuntimeError(
-                    f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
-                    f"Check dataset sharding/streaming."
-                )
+            # if not features:
+            #     raise RuntimeError(
+            #         f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+            #         f"Check dataset sharding/streaming."
+            #     )
             # >>> DEBUG END
 
         return {
@@ -459,13 +465,6 @@ def main():
                 yield ex
         train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    def has_one_sample(stream):
-        it = iter(stream)
-        try:
-            next(it); return 1
-        except StopIteration:
-            return 0
-
     # ====== Consistency probe (same logic as above) ======
    ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    if world_size > 1 and len(files) >= world_size:
@@ -478,24 +477,55 @@ def main():
                 yield ex
         probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
 
-    local_ok = has_one_sample(probe_stream)
+    def has_at_least(stream, n: int):
+        it = iter(stream)
+        for _ in range(n):
+            try:
+                next(it)
+            except StopIteration:
+                return 0
+        return 1
+
+    need = max(1, args.gradient_accumulation_steps)
+
+    local_ok = has_at_least(probe_stream, need)
 
    if dist.is_available() and dist.is_initialized():
-        # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
-        t = torch.tensor(
-            local_ok,
-            device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
-        )
+        t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
         dist.all_reduce(t, op=dist.ReduceOp.MIN)
         if t.item() == 0:
             if is_main_process():
-                print("[FATAL] At least one rank has no samples at all. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
+                print(
+                    f"[FATAL] At least one rank cannot supply {need} micro-batches within one optimizer step (GA={need}). "
+                    f"Reduce GA or enlarge/clean the data; this run will not start.",
+                    flush=True
+                )
             dist.barrier()
             sys.exit(2)
    else:
        if local_ok == 0:
-            print("[FATAL] No samples on this machine; exiting.", flush=True); sys.exit(2)
+            print(
+                f"[FATAL] This machine cannot supply {need} micro-batches within one optimizer step (GA={need}). "
+                f"Reduce GA or enlarge/clean the data; this run will not start.",
+                flush=True
+            )
+            sys.exit(2)
+
+    # if dist.is_available() and dist.is_initialized():
+    #     # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+    #     t = torch.tensor(
+    #         local_ok,
+    #         device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
+    #     )
+    #     dist.all_reduce(t, op=dist.ReduceOp.MIN)
+    #     if t.item() == 0:
+    #         if is_main_process():
+    #             print("[FATAL] At least one rank has no samples at all. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
+    #         dist.barrier()
+    #         sys.exit(2)
+    # else:
+    #     if local_ok == 0:
+    #         print("[FATAL] No samples on this machine; exiting.", flush=True); sys.exit(2)
 
    # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling otherwise ----
    eval_dataset: Optional[Dataset] = None
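
The probe's contract is simple: each rank reports 1 only if its stream can feed one full optimizer step, i.e. gradient_accumulation_steps micro-batches, and the MIN all-reduce turns those per-rank votes into a global go/no-go before training starts. A standalone sketch of the same pattern, assuming a CPU/gloo setup (the harness around the two functions is illustrative, not the patch's code):

    import torch
    import torch.distributed as dist

    def has_at_least(stream, n: int) -> int:
        # 1 if the iterable yields at least n items, else 0
        it = iter(stream)
        for _ in range(n):
            try:
                next(it)
            except StopIteration:
                return 0
        return 1

    def all_ranks_ok(local_ok: int) -> bool:
        # MIN-reduce: the run is viable only if every rank voted 1
        if dist.is_available() and dist.is_initialized():
            t = torch.tensor(local_ok)
            dist.all_reduce(t, op=dist.ReduceOp.MIN)
            return t.item() == 1
        return local_ok == 1

    # With GA=4, a rank that can only produce 3 micro-batches vetoes the run
    print(all_ranks_ok(has_at_least(range(3), 4)))  # False

In the patch itself the vote tensor is pinned to cuda:{local_rank} when CUDA is available; NCCL all-reduce requires each rank's tensor to live on its own device, whereas the CPU tensor in the sketch above assumes a gloo backend.
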
@@ -552,9 +582,9 @@ def main():
 
        r = len(eval_dataset) % global_bs
        if r != 0:
-            need = global_bs - r
-            # your eval_dataset is the custom ListDataset defined above, which carries .items
-            eval_dataset.items += eval_dataset.items[:need]
+            pad_need = global_bs - r
+            eval_dataset.items += eval_dataset.items[:pad_need]
            if is_main_process():
                print(f"[eval] padded eval set to {len(eval_dataset)} "
                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
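
The padding rule keeps every rank's eval DataLoader full on the last batch: the set grows to the next multiple of global_bs = world_size * per_device_eval_batch_size by repeating its leading items. (The rename need -> pad_need presumably avoids colliding with the need that the GA probe now defines earlier in main().) A quick worked check of the arithmetic, with illustrative numbers:

    def padded_eval_len(n: int, world_size: int, per_device_eval_batch_size: int) -> int:
        # Grow n to the next multiple of the global eval batch size
        global_bs = world_size * per_device_eval_batch_size
        r = n % global_bs
        return n if r == 0 else n + (global_bs - r)

    # 1000 samples, 8 ranks, 4 per device -> global_bs = 32, padded to 1024
    print(padded_eval_len(1000, 8, 4))  # 1024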