This commit is contained in:
parent
809e24c9c6
commit
af095b448a
|
|
@ -300,8 +300,15 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# 探针已消耗流;为正式训练重建一次
|
# 探针已消耗流;为正式训练重建一次
|
||||||
|
ds_stream2 = load_dataset(
|
||||||
|
"json",
|
||||||
|
data_files={"train": files},
|
||||||
|
split="train",
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 多机/多卡分片(让每个全局 rank 读不同子流)
|
||||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||||
#ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
|
||||||
ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
|
ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
|
||||||
def ex_iter2():
|
def ex_iter2():
|
||||||
for ex in ds_stream2:
|
for ex in ds_stream2:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue