This commit is contained in:
hailin 2025-08-26 13:45:55 +08:00
parent 9225b4b327
commit fa4d4d8b1e
1 changed files with 98 additions and 13 deletions

View File

@ -31,6 +31,28 @@ def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]
msk = inputs["attention_mask"]
labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}",
flush=True)
print(
f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}",
flush=True
)
self._dbg_printed = True
return super().training_step(model, inputs, num_items_in_batch)
# ----------------- 日志回调 -----------------
class CsvLossLogger(TrainerCallback):
@ -90,6 +112,16 @@ class QwenChatSFTDataset(IterableDataset):
self.seq_len = seq_len
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
# >>> DEBUG BEGIN
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
host = socket.gethostname()
# >>> DEBUG END
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list):
@ -168,6 +200,23 @@ class QwenChatSFTDataset(IterableDataset):
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
# >>> DEBUG PRINT此时变量已定义
if dbg_on and self._dbg_seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
if sup_tok == 0:
print(
f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
flush=True
)
self._dbg_seen += 1
# <<< DEBUG PRINT
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
@ -184,9 +233,9 @@ class SFTDataCollator:
assert self.tok.pad_token_id is not None
def __call__(self, features):
if not features:
raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
# if not features:
# raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
# f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
@ -206,6 +255,27 @@ class SFTDataCollator:
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
# >>> DEBUG BEGIN
dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
if dbg_on:
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
bs = len(features)
first_len = len(input_ids[0]) if bs > 0 else None
print(
f"[collate][host={host} RANK={rank}] features={bs} "
f"target_len={target_len} first_len={first_len}",
flush=True
)
# 额外严苛校验:防止空 batch 继续往下走
if not features:
raise RuntimeError(
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
f"Check dataset sharding/streaming."
)
# >>> DEBUG END
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
@ -290,13 +360,6 @@ def main():
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
else:
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
@ -304,6 +367,13 @@ def main():
else:
dbg("no cuda or invalid local_rank; not calling set_device")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
else:
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
if dist.is_available() and dist.is_initialized():
try:
@ -399,7 +469,11 @@ def main():
local_ok = has_one_sample(probe_stream)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
# t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
t = torch.tensor(
local_ok,
device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
)
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
@ -533,14 +607,25 @@ def main():
**ta_kwargs,
)
trainer = Trainer(
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
processing_class=tokenizer,
tokenizer=tokenizer,
# processing_class=tokenizer,
data_collator=data_collator
)
# trainer = Trainer(
# model=model,
# args=training_args,
# train_dataset=train_stream,
# eval_dataset=eval_dataset,
# processing_class=tokenizer,
# data_collator=data_collator
# )
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
# 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-*