From 5a30fee9a60d4dc0bbe925217b04aa9938b5fd45 Mon Sep 17 00:00:00 2001 From: hailin Date: Thu, 28 Aug 2025 13:42:18 +0800 Subject: [PATCH] . --- train_sft_ds.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train_sft_ds.py b/train_sft_ds.py index 13cb148..a478012 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -500,9 +500,14 @@ def main(): if args.gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) try: - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_mem_efficient_sdp(False) - torch.backends.cuda.enable_math_sdp(True) + # torch.backends.cuda.enable_flash_sdp(False) + # torch.backends.cuda.enable_mem_efficient_sdp(False) + # torch.backends.cuda.enable_math_sdp(True) + + # 让 PyTorch 自己选,或显式打开高效实现(任选其一): + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_math_sdp(False) except Exception: pass