This commit is contained in:
hailin 2025-08-27 22:33:10 +08:00
parent eaa79e566c
commit 4739fc615d
1 changed file with 33 additions and 61 deletions


@@ -232,11 +232,7 @@ class SFTDataCollator:
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None
def __call__(self, features):
# if not features:
# raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
# f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
def __call__(self, features):
if not features:
raise RuntimeError(
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
@@ -274,14 +270,6 @@ class SFTDataCollator:
f"target_len={target_len} first_len={first_len}",
flush=True
)
# Extra strict check: keep an empty batch from going any further
# if not features:
# raise RuntimeError(
# f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
# f"Check dataset sharding/streaming."
# )
# >>> DEBUG END
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
@@ -322,6 +310,7 @@ def parse_args():
help="for deepspeed/torchrun launcher; ignored by user code")
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
return ap.parse_args()
@@ -451,31 +440,39 @@ def main():
"另外检查 seq_len 是否过小导致全部被裁。"
)
# ====== Main training stream ======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
if world_size > 1 and len(files) >= world_size:
# Multiple files: contiguous sharding by file
ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
else:
# Single file or too few files: round-robin by sample index modulo world size
def ex_iter2():
for i, ex in enumerate(ds_stream2):
if i % max(world_size, 1) == rank:
yield ex
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
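The else branch above is a rank-based round-robin over the stream; as a standalone sketch (hypothetical helper, equivalent in spirit to ex_iter2):
def round_robin_shard(stream, rank, world_size):
    # Each rank keeps only every world_size-th example, offset by its rank,
    # so the ranks see disjoint subsets of the same stream.
    for i, ex in enumerate(stream):
        if i % max(world_size, 1) == rank:
            yield ex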
# # ====== Main training stream ======
# ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
# if world_size > 1 and len(files) >= world_size:
# # Multiple files: contiguous sharding by file
# ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
# train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
# else:
# # Single file or too few files: round-robin by sample index modulo world size
# def ex_iter2():
# for i, ex in enumerate(ds_stream2):
# if i % max(world_size, 1) == rank:
# yield ex
# train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
# ====== Consistency probe (same logic as above) ======
# ====== Main training stream (no manual sharding; leave it to Accelerate/Trainer) ======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
# # ====== Consistency probe (same logic as above) ======
# ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
# if world_size > 1 and len(files) >= world_size:
# ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
# probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
# else:
# def ex_iter2_probe():
# for i, ex in enumerate(ds_stream_probe2):
# if i % max(world_size, 1) == rank:
# yield ex
# probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
# ====== Consistency probe (no sharding) ======
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
if world_size > 1 and len(files) >= world_size:
ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
else:
def ex_iter2_probe():
for i, ex in enumerate(ds_stream_probe2):
if i % max(world_size, 1) == rank:
yield ex
probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
def has_at_least(stream, n: int):
it = iter(stream)
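The probe cut off by this hunk boundary only needs to know whether a stream yields at least n examples; a self-contained sketch of that idea (a possible shape, not necessarily the file's exact body):
from itertools import islice

def has_at_least(stream, n: int) -> bool:
    # Pull at most n items; True only if the stream actually produced n of them.
    return sum(1 for _ in islice(iter(stream), n)) >= n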
@@ -511,22 +508,6 @@ def main():
)
sys.exit(2)
# if dist.is_available() and dist.is_initialized():
# # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
# t = torch.tensor(
# local_ok,
# device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
# )
# dist.all_reduce(t, op=dist.ReduceOp.MIN)
# if t.item() == 0:
# if is_main_process():
# print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
# dist.barrier()
# sys.exit(2)
# else:
# if local_ok == 0:
# print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
# ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling otherwise ----
eval_dataset: Optional[Dataset] = None
@@ -595,7 +576,6 @@ def main():
f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
# More robust: during integration debugging, don't force padding to 4096
# data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
os.makedirs(args.output_dir, exist_ok=True)
@@ -661,14 +641,6 @@ def main():
data_collator=data_collator
)
# trainer = Trainer(
# model=model,
# args=training_args,
# train_dataset=train_stream,
# eval_dataset=eval_dataset,
# processing_class=tokenizer,
# data_collator=data_collator
# )
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
# No shared disk: check whether each node's local output_dir already has a checkpoint-*
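One common way to implement that per-node check is transformers' get_last_checkpoint helper; a sketch of the pattern (an assumption about how it could be done, reusing args and trainer from the surrounding code):
import os
from transformers.trainer_utils import get_last_checkpoint

# Returns the newest checkpoint-* directory in output_dir, or None if there is none.
last_ckpt = get_last_checkpoint(args.output_dir) if os.path.isdir(args.output_dir) else None
trainer.train(resume_from_checkpoint=last_ckpt)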