This commit is contained in:
parent
ad9b288e4a
commit
5a30fee9a6
|
|
@ -500,9 +500,14 @@ def main():
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
try:
|
try:
|
||||||
torch.backends.cuda.enable_flash_sdp(False)
|
# torch.backends.cuda.enable_flash_sdp(False)
|
||||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
# torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||||
torch.backends.cuda.enable_math_sdp(True)
|
# torch.backends.cuda.enable_math_sdp(True)
|
||||||
|
|
||||||
|
# 让 PyTorch 自己选,或显式打开高效实现(任选其一):
|
||||||
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||||
|
torch.backends.cuda.enable_math_sdp(False)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue