diff --git a/train_sft_ds.py b/train_sft_ds.py
index 78cc51d..76bf3fc 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -31,6 +31,28 @@ def print_once(*args, **kwargs):
     if is_main_process():
         print(*args, **kwargs, flush=True)
 
+class DebugTrainer(Trainer):
+    def training_step(self, model, inputs, num_items_in_batch=None):
+        if not hasattr(self, "_dbg_printed"):
+            rank = int(os.environ.get("RANK", "0"))
+            host = socket.gethostname()
+            ids = inputs["input_ids"]
+            msk = inputs["attention_mask"]
+            labs = inputs["labels"]
+            print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
+                  f"supervised={(labs!=-100).sum().item()}",
+                  flush=True)
+            print(
+                f"[step0][host={host} RANK={rank}] "
+                f"input_ids.shape={tuple(ids.shape)} "
+                f"attention_mask.shape={tuple(msk.shape)} "
+                f"labels.shape={tuple(labs.shape)} "
+                f"num_items_in_batch={num_items_in_batch}",
+                flush=True
+            )
+            self._dbg_printed = True
+        return super().training_step(model, inputs, num_items_in_batch)
+
 
 # ----------------- Logging callback -----------------
 class CsvLossLogger(TrainerCallback):
@@ -90,6 +112,16 @@ class QwenChatSFTDataset(IterableDataset):
         self.seq_len = seq_len
 
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+
+        # >>> DEBUG BEGIN
+        dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
+        if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
+        dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
+        rank = int(os.environ.get("RANK", "0"))
+        lrank = int(os.environ.get("LOCAL_RANK", "-1"))
+        host = socket.gethostname()
+        # >>> DEBUG END
+
         for ex in self.ex_iter:
             msgs = ex.get("messages", None)
             if not msgs or not isinstance(msgs, list):
@@ -168,6 +200,23 @@ class QwenChatSFTDataset(IterableDataset):
             assert len(labels) == self.seq_len
             assert len(attn_mask) == self.seq_len
 
+            # >>> DEBUG PRINT (all variables are defined by this point)
+            if dbg_on and self._dbg_seen < dbg_limit:
+                sup_tok = sum(1 for v in labels if v != -100)
+                print(
+                    f"[sample][host={host} RANK={rank} LRank={lrank}] "
+                    f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
+                    f"seq_len={self.seq_len} pad_id={pad_id}",
+                    flush=True
+                )
+                if sup_tok == 0:
+                    print(
+                        f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
+                        flush=True
+                    )
+                self._dbg_seen += 1
+            # <<< DEBUG PRINT
+
             yield {
                 "input_ids": torch.tensor(input_ids, dtype=torch.long),
                 "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
@@ -184,9 +233,9 @@ class SFTDataCollator:
         assert self.tok.pad_token_id is not None
 
     def __call__(self, features):
-        if not features:
-            raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
-                               f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
+        # if not features:
+        #     raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
+        #                        f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
         def _to_list(x):
             return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids = [_to_list(f["input_ids"]) for f in features]
@@ -206,6 +255,27 @@ class SFTDataCollator:
             batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
+
+        # >>> DEBUG BEGIN
+        dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
+        if dbg_on:
+            rank = int(os.environ.get("RANK", "0"))
+            host = socket.gethostname()
+            bs = len(features)
+            first_len = len(input_ids[0]) if bs > 0 else None
+            print(
+                f"[collate][host={host} RANK={rank}] features={bs} "
+                f"target_len={target_len} first_len={first_len}",
+                flush=True
+            )
+            # Extra strict check: stop an empty batch from going any further
+            if not features:
+                raise RuntimeError(
+                    f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+                    f"Check dataset sharding/streaming."
+                )
+        # >>> DEBUG END
+
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
@@ -290,13 +360,6 @@ def main():
     local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
     dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
 
-    if world_size > 1 and dist.is_available() and not dist.is_initialized():
-        backend = "nccl" if torch.cuda.is_available() else "gloo"
-        dbg(f"init_process_group backend={backend} via env://")
-        dist.init_process_group(backend=backend, init_method="env://")
-    else:
-        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
-
     if torch.cuda.is_available() and local_rank >= 0:
         torch.cuda.set_device(local_rank)
         dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
@@ -304,6 +367,13 @@ def main():
     else:
         dbg("no cuda or invalid local_rank; not calling set_device")
 
+    if world_size > 1 and dist.is_available() and not dist.is_initialized():
+        backend = "nccl" if torch.cuda.is_available() else "gloo"
+        dbg(f"init_process_group backend={backend} via env://")
+        dist.init_process_group(backend=backend, init_method="env://")
+    else:
+        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
+
     if dist.is_available() and dist.is_initialized():
         try:
@@ -399,7 +469,11 @@ def main():
     local_ok = has_one_sample(probe_stream)
 
     if dist.is_available() and dist.is_initialized():
-        t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+        # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+        t = torch.tensor(
+            local_ok,
+            device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
+        )
         dist.all_reduce(t, op=dist.ReduceOp.MIN)
         if t.item() == 0:
             if is_main_process():
@@ -533,14 +607,25 @@ def main():
         **ta_kwargs,
     )
 
-    trainer = Trainer(
+
+    trainer = DebugTrainer(
         model=model,
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        processing_class=tokenizer,
+        tokenizer=tokenizer,
+        # processing_class=tokenizer,
         data_collator=data_collator
     )
+
+    # trainer = Trainer(
+    #     model=model,
+    #     args=training_args,
+    #     train_dataset=train_stream,
+    #     eval_dataset=eval_dataset,
+    #     processing_class=tokenizer,
+    #     data_collator=data_collator
+    # )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
 
     # No shared disk: check each node's local output_dir for existing checkpoint-*
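Note: the added hunks gate their output behind the DBG_SAMPLES, DBG_SAMPLE_LIMIT and DBG_COLLATE environment variables and call socket.gethostname(), so they assume `import os` and `import socket` already exist at the top of train_sft_ds.py (not shown in this diff). Below is a minimal standalone sketch, not part of the patch, that reproduces the same env-gated toggle pattern so the log format can be checked before a full multi-node launch; the helper name `debug_enabled` and the default values are illustrative assumptions, and RANK/LOCAL_RANK are normally injected by the launcher (torchrun/deepspeed).

# Standalone sketch: mirrors the env-gated debug toggles used in the patch.
import os
import socket

def debug_enabled(flag: str) -> bool:
    # Same convention as the patch: a toggle is on when its env var equals "1".
    return os.environ.get(flag, "0") == "1"

if __name__ == "__main__":
    os.environ.setdefault("DBG_SAMPLES", "1")       # per-sample prints in the dataset
    os.environ.setdefault("DBG_SAMPLE_LIMIT", "3")  # cap prints per worker
    os.environ.setdefault("DBG_COLLATE", "1")       # per-batch prints in the collator
    rank = int(os.environ.get("RANK", "0"))
    lrank = int(os.environ.get("LOCAL_RANK", "-1"))
    host = socket.gethostname()
    if debug_enabled("DBG_SAMPLES"):
        print(f"[sample][host={host} RANK={rank} LRank={lrank}] debug toggles active", flush=True)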