From 746b8471eeb3c0facf07f36c10ba9e963ca0e28a Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 26 Aug 2025 15:12:35 +0800
Subject: [PATCH] train_sft_ds: shard streaming dataset via datasets.shard();
 pass tokenizer to Trainer as processing_class

---
 train_sft_ds.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/train_sft_ds.py b/train_sft_ds.py
index 0bcf065..6ba4c95 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -447,11 +447,15 @@ def main():
 
     # ====== 正式训练流 + 模数分片(不要求样本数整除 world_size) ======
     ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    def ex_iter2():
-        for i, ex in enumerate(ds_stream2):
-            if i % max(world_size, 1) == rank:
-                yield ex
-    train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+    if world_size > 1:
+        ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
+    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+
+    # def ex_iter2():
+    #     for i, ex in enumerate(ds_stream2):
+    #         if i % max(world_size, 1) == rank:
+    #             yield ex
+    # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
     # ====== 一致性探针:任意 rank 无样本 -> 全体退出 ======
     def has_one_sample(stream):
@@ -462,11 +466,16 @@ def main():
             return 0
 
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    def ex_iter2_probe():
-        for i, ex in enumerate(ds_stream_probe2):
-            if i % max(world_size, 1) == rank:
-                yield ex
-    probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+    if world_size > 1:
+        ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
+    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
+
+    # def ex_iter2_probe():
+    #     for i, ex in enumerate(ds_stream_probe2):
+    #         if i % max(world_size, 1) == rank:
+    #             yield ex
+    # probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+
     local_ok = has_one_sample(probe_stream)
 
     if dist.is_available() and dist.is_initialized():
@@ -614,8 +623,8 @@ def main():
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
-        # processing_class=tokenizer,
+        #tokenizer=tokenizer,
+        processing_class=tokenizer,
         data_collator=data_collator
     )
 