diff --git a/train_sft_ds.py b/train_sft_ds.py index 6ba4c95..1b0022d 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -445,36 +445,39 @@ def main(): "另外检查 seq_len 是否过小导致全部被裁。" ) - # ====== 正式训练流 + 模数分片(不要求样本数整除 world_size) ====== + # ====== 正式训练流 ====== ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) - if world_size > 1: + if world_size > 1 and len(files) >= world_size: + # 多文件,按文件连续分片 ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True) - train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) + train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) + else: + # 单文件或文件数不足,按样本取模轮转 + def ex_iter2(): + for i, ex in enumerate(ds_stream2): + if i % max(world_size, 1) == rank: + yield ex + train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len) - # def ex_iter2(): - # for i, ex in enumerate(ds_stream2): - # if i % max(world_size, 1) == rank: - # yield ex - # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len) - - # ====== 一致性探针:任意 rank 无样本 -> 全体退出 ====== def has_one_sample(stream): it = iter(stream) try: next(it); return 1 except StopIteration: return 0 - + + # ====== 一致性探针(与上面保持同逻辑)===== ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) - if world_size > 1: + if world_size > 1 and len(files) >= world_size: ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True) - probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) + probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) + else: + def ex_iter2_probe(): + for i, ex in enumerate(ds_stream_probe2): + if i % max(world_size, 1) == rank: + yield ex + probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len) - # def ex_iter2_probe(): - # for i, ex in enumerate(ds_stream_probe2): - # if i % max(world_size, 1) == rank: - # yield ex - # probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len) local_ok = has_one_sample(probe_stream)