parent 32da7c0e5b
commit 552caf31f1
@@ -184,6 +184,10 @@ class SFTDataCollator:
         assert self.tok.pad_token_id is not None
 
     def __call__(self, features):
+        if not features:
+            raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
+                               f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
+
         def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids = [_to_list(f["input_ids"]) for f in features]
         attn_masks = [_to_list(f["attention_mask"]) for f in features]
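
Why the guard: with distributed eval, an undersized eval set can leave one rank an empty shard, and the first symptom is the collator choking on an empty list. A minimal standalone sketch (not the repo's code; collate and the toy sharding are illustrative assumptions) of how that arises and what the new error surfaces:

import os

def collate(features):
    # Same guard as in the diff: fail loudly with the rank and a hint,
    # instead of a cryptic stack trace from stacking an empty list.
    if not features:
        raise RuntimeError(
            f"EMPTY BATCH in collator on rank={os.environ.get('RANK', '0')}. "
            f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size."
        )
    return features

eval_samples = list(range(3))            # only 3 eval examples
world_size = 4                           # 4 ranks -> rank 3 gets an empty shard
shards = [eval_samples[r::world_size] for r in range(world_size)]
collate(shards[0])                       # ok: [0]
# collate(shards[3])                     # raises: EMPTY BATCH in collator ...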
@@ -240,6 +244,7 @@ def parse_args():
                     help="(optional) test/validation set jsonl glob; takes precedence if provided")
     ap.add_argument("--local_rank", type=int, default=-1,
                     help="for deepspeed/torchrun launcher; ignored by user code")
+    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     return ap.parse_args()
 
 
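
For reference, a quick standalone check that the new flag parses alongside --local_rank (the two argument definitions are from the diff; the rest is scaffolding for this example):

import argparse

ap = argparse.ArgumentParser()
ap.add_argument("--local_rank", type=int, default=-1,
                help="for deepspeed/torchrun launcher; ignored by user code")
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)

args = ap.parse_args(["--per_device_eval_batch_size", "2"])
assert args.per_device_eval_batch_size == 2   # default is 1 when the flag is omitted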
@@ -396,6 +401,26 @@ def main():
     if len(eval_samples) > 0:
         eval_dataset = ListDataset(eval_samples)
 
+    # ---- Pad the eval set uniformly (ensures no empty batches) ----
+    if eval_dataset is not None:
+        ws = max(world_size, 1)
+        be = max(1, args.per_device_eval_batch_size)
+        global_bs = ws * be
+
+        r = len(eval_dataset) % global_bs
+        if r != 0:
+            need = global_bs - r
+            # eval_dataset is the custom ListDataset defined above, which exposes .items
+            eval_dataset.items += eval_dataset.items[:need]
+            if is_main_process():
+                print(f"[eval] padded eval set to {len(eval_dataset)} "
+                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
+                      flush=True)
+
+        # Sanity check again after padding
+        assert len(eval_dataset) % global_bs == 0, \
+            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
+
     # More robust: don't force padding to 4096 during integration debugging
     # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
     data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
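
The padding arithmetic from this hunk, pulled out as a self-contained sketch (the function name and the worked numbers are mine, not the repo's). Note it shares the diff's implicit assumption that the eval set holds at least global_bs - r items to recycle:

def pad_to_global_batch(items, world_size, per_device_eval_batch_size):
    # global batch = ranks * per-rank batch; round the dataset up to a multiple of it
    global_bs = max(world_size, 1) * max(1, per_device_eval_batch_size)
    r = len(items) % global_bs
    if r != 0:
        items = items + items[:global_bs - r]   # duplicate leading samples to fill
    assert len(items) % global_bs == 0
    return items

# 10 samples, 4 ranks x batch 2 -> global_bs=8, remainder 2, so 6 duplicates -> 16
padded = pad_to_global_batch(list(range(10)), world_size=4, per_device_eval_batch_size=2)
print(len(padded))   # 16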
@@ -440,6 +465,7 @@ def main():
         dataloader_num_workers=0,
         dataloader_prefetch_factor=None,
         dataloader_pin_memory=False,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
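
And a hedged sketch of where the flag lands in TrainingArguments (output_dir and the stand-in values are placeholders of mine; the other field names are copied from the context lines above):

from transformers import TrainingArguments

bf16, report_to = True, "none"           # stand-ins for args.bf16 / args.report_to
training_args = TrainingArguments(
    output_dir="out",                    # placeholder, not part of the diff
    dataloader_num_workers=0,
    dataloader_prefetch_factor=None,
    dataloader_pin_memory=False,
    per_device_eval_batch_size=1,        # now fed from --per_device_eval_batch_size
    report_to=([] if report_to == "none" else [report_to]),
    bf16=bf16,                           # mirrors the diff; disable both bf16 and
    fp16=(not bf16),                     # fp16 on CPU-only machines
)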