diff --git a/train_sft_ds.py b/train_sft_ds.py index a874b9d..52b6454 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -308,8 +308,8 @@ def main(): ) # 多机/多卡分片(让每个全局 rank 读不同子流) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True) + # world_size = int(os.environ.get("WORLD_SIZE", "1")) + # ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True) def ex_iter2(): for ex in ds_stream2: yield ex