diff --git a/train_sft_ds.py b/train_sft_ds.py index 53df433..a874b9d 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -300,8 +300,15 @@ def main(): ) # 探针已消耗流;为正式训练重建一次 + ds_stream2 = load_dataset( + "json", + data_files={"train": files}, + split="train", + streaming=True + ) + + # 多机/多卡分片(让每个全局 rank 读不同子流) world_size = int(os.environ.get("WORLD_SIZE", "1")) - #ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True) def ex_iter2(): for ex in ds_stream2: