diff --git a/train_sft_ds.py b/train_sft_ds.py index 3d5f1f0..e6bd4cb 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -487,7 +487,7 @@ def main(): torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True, - attn_implementation="eager" + attn_implementation="sdpa" ) print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)