From 552caf31f187a2bd7cba34c2f3ee5e52559c3628 Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 26 Aug 2025 11:16:05 +0800
Subject: [PATCH] Pad the eval set to a multiple of the global eval batch size
 to avoid empty eval batches

---
 train_sft_ds.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/train_sft_ds.py b/train_sft_ds.py
index 3753cb9..d987653 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -184,6 +184,10 @@ class SFTDataCollator:
         assert self.tok.pad_token_id is not None
 
     def __call__(self, features):
+        if not features:
+            raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
+                               f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
+
         def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids = [_to_list(f["input_ids"]) for f in features]
         attn_masks = [_to_list(f["attention_mask"]) for f in features]
@@ -240,6 +244,7 @@ def parse_args():
                     help="(optional) glob for test/validation jsonl files; takes precedence if provided")
     ap.add_argument("--local_rank", type=int, default=-1,
                     help="for deepspeed/torchrun launcher; ignored by user code")
+    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     return ap.parse_args()
 
 
@@ -396,6 +401,26 @@ def main():
     if len(eval_samples) > 0:
         eval_dataset = ListDataset(eval_samples)
 
+    # ---- Pad the eval set so that no rank ever sees an empty batch ----
+    if eval_dataset is not None:
+        ws = max(world_size, 1)
+        be = max(1, args.per_device_eval_batch_size)
+        global_bs = ws * be
+
+        r = len(eval_dataset) % global_bs
+        if r != 0:
+            need = global_bs - r
+            # eval_dataset is the custom ListDataset defined above, which exposes .items
+            eval_dataset.items += eval_dataset.items[:need]
+            if is_main_process():
+                print(f"[eval] padded eval set to {len(eval_dataset)} "
+                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
+                      flush=True)
+
+        # Sanity check after padding
+        assert len(eval_dataset) % global_bs == 0, \
+            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
+
     # Safer during bring-up: do not force padding to 4096
     # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
     data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
@@ -440,6 +465,7 @@ def main():
         dataloader_num_workers=0,
         dataloader_prefetch_factor=None,
         dataloader_pin_memory=False,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
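
Review note: below is a minimal, self-contained sketch of the padding rule the main() hunk applies, assuming only that the eval samples live in a plain Python list. `pad_eval_items` is a hypothetical helper name used for illustration; the patch itself does the same arithmetic inline on `ListDataset.items`.

def pad_eval_items(items, world_size, per_device_eval_batch_size):
    """Hypothetical helper: extend `items` (by repeating its head) until its
    length is a multiple of the global eval batch size, mirroring the inline
    logic in the patch."""
    global_bs = max(1, world_size) * max(1, per_device_eval_batch_size)
    remainder = len(items) % global_bs
    if remainder:
        items = items + items[: global_bs - remainder]
    assert len(items) % global_bs == 0
    return items

if __name__ == "__main__":
    # 13 samples across 4 ranks with per-device eval batch 2 -> global batch 8,
    # so 3 samples are duplicated and the eval set grows to 16.
    padded = pad_eval_items(list(range(13)), world_size=4, per_device_eval_batch_size=2)
    print(len(padded))  # 16

The duplicated samples are counted twice in eval metrics; that small bias is the usual trade-off for guaranteeing that every rank receives full batches.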
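
Review note: the collator guard and the padding matter because an eval set whose size is not a multiple of world_size * per_device_eval_batch_size can, depending on the sampler, leave some rank with a short or empty shard. The toy simulation below illustrates this with round-robin sharding; `shard_indices` is a hypothetical stand-in, not the sampler the Trainer actually uses.

def shard_indices(num_samples, world_size, rank):
    """Round-robin assignment of sample indices to one rank (no padding)."""
    return list(range(rank, num_samples, world_size))

if __name__ == "__main__":
    world_size = 4
    for eval_size in (3, 16):
        sizes = [len(shard_indices(eval_size, world_size, r)) for r in range(world_size)]
        print(f"eval_size={eval_size}: per-rank shard sizes = {sizes}")
        # eval_size=3  -> [1, 1, 1, 0]: rank 3 gets an empty batch, so the collator raises
        # eval_size=16 -> [4, 4, 4, 4]: every rank gets full batches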