From 809e24c9c6a6051e006b8539ae349651e457acc1 Mon Sep 17 00:00:00 2001
From: hailin
Date: Mon, 25 Aug 2025 23:03:24 +0800
Subject: [PATCH] .

---
 train_sft_ds.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/train_sft_ds.py b/train_sft_ds.py
index 6fdeb92..53df433 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -251,7 +251,14 @@ def main():
     model.config.use_cache = False  # disable cache during training
 
     if args.gradient_checkpointing:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
+    try:
+        torch.backends.cuda.enable_flash_sdp(False)          # turn off the FlashAttention SDP kernel
+        torch.backends.cuda.enable_mem_efficient_sdp(False)  # turn off the memory-efficient SDP kernel
+        torch.backends.cuda.enable_math_sdp(True)            # use the math implementation
+    except Exception:
+        pass
 
     # ===== Data robustness check (run independently on each node) =====
     host = socket.gethostname()
@@ -293,7 +300,9 @@ def main():
     )
 
     # The probe has consumed the stream; rebuild it once for the actual training run
-    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)  # keep only this rank's shard
     def ex_iter2():
         for ex in ds_stream2:
             yield ex
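
Note on the last hunk: each rank rebuilds the JSON stream and then keeps only its own slice, so ranks do not iterate over duplicate examples. A minimal sketch of the same idea, assuming RANK/WORLD_SIZE are set by the launcher and that `files` here is a placeholder for the file list the script builds earlier; it uses the datasets helper split_dataset_by_node instead of the patch's IterableDataset.shard call:

    import os
    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node

    rank = int(os.environ.get("RANK", "0"))              # global rank from the launcher (assumption)
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    files = ["train.jsonl"]  # placeholder for the JSONL file list assembled earlier in the script
    stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    stream = split_dataset_by_node(stream, rank=rank, world_size=world_size)

    for ex in stream:  # each rank now sees a disjoint subset of the stream
        ...

Either way, the goal is the same: no rank re-reads the full stream during the actual training pass.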