This commit is contained in:
hailin 2025-08-28 13:42:18 +08:00
parent ad9b288e4a
commit 5a30fee9a6
1 changed file with 8 additions and 3 deletions

View File

@@ -500,9 +500,14 @@ def main():
if args.gradient_checkpointing: if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try: try:
torch.backends.cuda.enable_flash_sdp(False) # torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False) # torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True) # torch.backends.cuda.enable_math_sdp(True)
# 让 PyTorch 自己选,或显式打开高效实现(任选其一):
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
except Exception: except Exception:
pass pass