This commit is contained in:
parent
6ca96bbc52
commit
809e24c9c6
|
|
@ -251,7 +251,14 @@ def main():
|
||||||
model.config.use_cache = False # 训练时禁用 cache
|
model.config.use_cache = False # 训练时禁用 cache
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.enable_flash_sdp(False)
|
||||||
|
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||||
|
torch.backends.cuda.enable_math_sdp(True) # 走 math 实现
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# ===== 数据鲁棒性检查(多机各自执行)=====
|
# ===== 数据鲁棒性检查(多机各自执行)=====
|
||||||
host = socket.gethostname()
|
host = socket.gethostname()
|
||||||
|
|
@ -293,7 +300,9 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# 探针已消耗流;为正式训练重建一次
|
# 探针已消耗流;为正式训练重建一次
|
||||||
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||||
|
#ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
||||||
|
ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
|
||||||
def ex_iter2():
|
def ex_iter2():
|
||||||
for ex in ds_stream2:
|
for ex in ds_stream2:
|
||||||
yield ex
|
yield ex
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue