From af095b448a86011bf06855f8889b6ca0d56ebc08 Mon Sep 17 00:00:00 2001 From: hailin Date: Mon, 25 Aug 2025 23:09:42 +0800 Subject: [PATCH] . --- train_sft_ds.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/train_sft_ds.py b/train_sft_ds.py index 53df433..a874b9d 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -300,8 +300,15 @@ def main(): ) # 探针已消耗流;为正式训练重建一次 + ds_stream2 = load_dataset( + "json", + data_files={"train": files}, + split="train", + streaming=True + ) + + # 多机/多卡分片(让每个全局 rank 读不同子流) world_size = int(os.environ.get("WORLD_SIZE", "1")) - #ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True) def ex_iter2(): for ex in ds_stream2: