This commit is contained in:
parent
ad9b288e4a
commit
5a30fee9a6
|
|
@ -500,9 +500,14 @@ def main():
|
|||
if args.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
try:
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
# torch.backends.cuda.enable_flash_sdp(False)
|
||||
# torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
# torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
# Let PyTorch pick the SDPA backend itself, or explicitly enable the efficient implementations (choose one):
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||
torch.backends.cuda.enable_math_sdp(False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue