diff --git a/train_sft_ds.py b/train_sft_ds.py index c255748..e6bd4cb 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -487,7 +487,7 @@ def main(): torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True, - attn_implementation="flash_attention_2" + attn_implementation="sdpa" ) print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)