hailin 2025-08-26 11:16:05 +08:00
parent 32da7c0e5b
commit 552caf31f1
1 changed file with 26 additions and 0 deletions

@@ -184,6 +184,10 @@ class SFTDataCollator:
         assert self.tok.pad_token_id is not None

     def __call__(self, features):
+        if not features:
+            raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
+                               f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
         def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids = [_to_list(f["input_ids"]) for f in features]
         attn_masks = [_to_list(f["attention_mask"]) for f in features]
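The guard above exists because distributed eval can hand a rank an empty feature list. A minimal standalone sketch of how that happens (not part of the commit; shard() and the sizes are made up for illustration) when samples are split round-robin without padding:

# Illustrative sketch, not from the commit: round-robin sharding with no
# padding (drop_last-style semantics) can hand some ranks an empty list
# whenever the eval set is smaller than world_size.
def shard(indices, rank, world_size):
    return indices[rank::world_size]

eval_indices = list(range(3))          # eval set of only 3 samples
for rank in range(4):                  # world_size = 4
    print(rank, shard(eval_indices, rank, 4))
# rank 3 receives [] -> its DataLoader would call the collator with an
# empty feature list, which the new guard turns into a clear RuntimeError
# instead of a cryptic downstream stack trace.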
@@ -240,6 +244,7 @@ def parse_args():
                     help="(optional) jsonl glob for the test/validation set; takes precedence if provided")
     ap.add_argument("--local_rank", type=int, default=-1,
                     help="for deepspeed/torchrun launcher; ignored by user code")
+    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     return ap.parse_args()
@@ -396,6 +401,26 @@ def main():
         if len(eval_samples) > 0:
             eval_dataset = ListDataset(eval_samples)
+    # ---- Uniformly pad the eval set (make sure no empty batch appears) ----
+    if eval_dataset is not None:
+        ws = max(world_size, 1)
+        be = max(1, args.per_device_eval_batch_size)
+        global_bs = ws * be
+        r = len(eval_dataset) % global_bs
+        if r != 0:
+            need = global_bs - r
+            # eval_dataset is the custom ListDataset defined above (it has .items)
+            eval_dataset.items += eval_dataset.items[:need]
+            if is_main_process():
+                print(f"[eval] padded eval set to {len(eval_dataset)} "
+                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
+                      flush=True)
+        # sanity check after padding
+        assert len(eval_dataset) % global_bs == 0, \
+            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
     # More stable: don't force padding to 4096 during integration testing
     # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
     data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
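Replaying the padding arithmetic from this hunk with assumed sizes (world_size=2, per_device_eval_batch_size=4; neither value comes from the commit):

# Standalone replay of the padding logic above with made-up sizes.
items = list(range(10))   # stands in for eval_dataset.items, len == 10
ws, be = 2, 4             # assumed world_size / per_device_eval_batch_size
global_bs = ws * be       # 8
r = len(items) % global_bs             # 10 % 8 == 2
if r != 0:
    need = global_bs - r               # 6
    items += items[:need]              # duplicate leading samples
assert len(items) % global_bs == 0     # 16 % 8 == 0
print(len(items))                      # 16

Note the padding repeats real samples, so aggregate eval metrics lean very slightly toward the head of the set; as a divisibility fix that keeps every rank's batch non-empty, this is usually an acceptable trade-off.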
@@ -440,6 +465,7 @@ def main():
         dataloader_num_workers=0,
         dataloader_prefetch_factor=None,
         dataloader_pin_memory=False,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
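Finally, a minimal sketch of where the new flag lands: the value parsed by argparse is forwarded verbatim into the HF TrainingArguments field of the same name (output_dir and the batch size below are placeholders, not values from the commit):

from transformers import TrainingArguments

# Placeholder values; in the script this comes from args.per_device_eval_batch_size.
targs = TrainingArguments(output_dir="out", per_device_eval_batch_size=4)
print(targs.per_device_eval_batch_size)   # 4
print(targs.eval_batch_size)              # per-device size * max(1, n_gpu)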