commit fa4d4d8b1e
parent 9225b4b327

train_sft_ds.py
@@ -31,6 +31,28 @@ def print_once(*args, **kwargs):
     if is_main_process():
         print(*args, **kwargs, flush=True)
 
+class DebugTrainer(Trainer):
+    def training_step(self, model, inputs, num_items_in_batch=None):
+        if not hasattr(self, "_dbg_printed"):
+            rank = int(os.environ.get("RANK", "0"))
+            host = socket.gethostname()
+            ids = inputs["input_ids"]
+            msk = inputs["attention_mask"]
+            labs = inputs["labels"]
+            print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
+                  f"supervised={(labs!=-100).sum().item()}",
+                  flush=True)
+            print(
+                f"[step0][host={host} RANK={rank}] "
+                f"input_ids.shape={tuple(ids.shape)} "
+                f"attention_mask.shape={tuple(msk.shape)} "
+                f"labels.shape={tuple(labs.shape)} "
+                f"num_items_in_batch={num_items_in_batch}",
+                flush=True
+            )
+            self._dbg_printed = True
+        return super().training_step(model, inputs, num_items_in_batch)
+
 
 # ----------------- Logging callback -----------------
 class CsvLossLogger(TrainerCallback):
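The "supervised=..." figure in the step-0 print counts the label positions that actually contribute to the loss: -100 is the ignore index that the cross-entropy loss skips, which is also why the collator below pads labels with -100. A small illustrative sketch (not part of the commit):

    import torch

    # Two length-5 sequences; -100 marks prompt/padding positions the loss ignores.
    labels = torch.tensor([[-100, -100, 11, 12, 13],
                           [-100,   21, 22, -100, -100]])
    supervised = (labels != -100).sum().item()   # -> 5 positions contribute to the loss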
@@ -90,6 +112,16 @@ class QwenChatSFTDataset(IterableDataset):
         self.seq_len = seq_len
 
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+
+        # >>> DEBUG BEGIN
+        dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
+        if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
+        dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
+        rank = int(os.environ.get("RANK", "0"))
+        lrank = int(os.environ.get("LOCAL_RANK", "-1"))
+        host = socket.gethostname()
+        # >>> DEBUG END
+
         for ex in self.ex_iter:
             msgs = ex.get("messages", None)
             if not msgs or not isinstance(msgs, list):
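The per-sample prints above are gated by environment variables rather than being always on. A minimal sketch of switching them on; the names DBG_SAMPLES, DBG_SAMPLE_LIMIT and DBG_COLLATE are the ones this commit reads, while setting them through os.environ is only illustrative (in practice they would normally be exported in the shell that launches training):

    import os

    os.environ["DBG_SAMPLES"] = "1"        # enable the [sample] prints in __iter__
    os.environ["DBG_SAMPLE_LIMIT"] = "5"   # print at most 5 samples per process (default 3)
    os.environ["DBG_COLLATE"] = "1"        # enable the [collate] prints added further down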
@@ -168,6 +200,23 @@ class QwenChatSFTDataset(IterableDataset):
             assert len(labels) == self.seq_len
             assert len(attn_mask) == self.seq_len
 
+            # >>> DEBUG PRINT (the variables are defined by this point)
+            if dbg_on and self._dbg_seen < dbg_limit:
+                sup_tok = sum(1 for v in labels if v != -100)
+                print(
+                    f"[sample][host={host} RANK={rank} LRank={lrank}] "
+                    f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
+                    f"seq_len={self.seq_len} pad_id={pad_id}",
+                    flush=True
+                )
+                if sup_tok == 0:
+                    print(
+                        f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
+                        flush=True
+                    )
+                self._dbg_seen += 1
+            # <<< DEBUG PRINT
+
             yield {
                 "input_ids": torch.tensor(input_ids, dtype=torch.long),
                 "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
@@ -184,9 +233,9 @@ class SFTDataCollator:
         assert self.tok.pad_token_id is not None
 
     def __call__(self, features):
-        if not features:
-            raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
-                               f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
+        # if not features:
+        #     raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
+        #                        f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
 
         def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
         input_ids = [_to_list(f["input_ids"]) for f in features]
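The error message being commented out here encodes the actual constraint: under distributed evaluation every rank must receive at least one sample, so the eval set needs at least world_size * per_device_eval_batch_size examples. A worked example with purely illustrative numbers:

    # Illustrative numbers, not taken from the commit.
    world_size = 16                      # e.g. 2 nodes x 8 GPUs
    per_device_eval_batch_size = 1
    min_eval_samples = world_size * per_device_eval_batch_size   # -> 16
    # With fewer eval examples than this, some rank gets an empty shard and the
    # collator would be called with an empty feature list.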
@@ -206,6 +255,27 @@ class SFTDataCollator:
             batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
+
+        # >>> DEBUG BEGIN
+        dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
+        if dbg_on:
+            rank = int(os.environ.get("RANK", "0"))
+            host = socket.gethostname()
+            bs = len(features)
+            first_len = len(input_ids[0]) if bs > 0 else None
+            print(
+                f"[collate][host={host} RANK={rank}] features={bs} "
+                f"target_len={target_len} first_len={first_len}",
+                flush=True
+            )
+        # Extra strict check: stop an empty batch from going any further
+        if not features:
+            raise RuntimeError(
+                f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+                f"Check dataset sharding/streaming."
+            )
+        # >>> DEBUG END
+
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
@@ -290,13 +360,6 @@ def main():
     local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
     dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
 
-    if world_size > 1 and dist.is_available() and not dist.is_initialized():
-        backend = "nccl" if torch.cuda.is_available() else "gloo"
-        dbg(f"init_process_group backend={backend} via env://")
-        dist.init_process_group(backend=backend, init_method="env://")
-    else:
-        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
-
     if torch.cuda.is_available() and local_rank >= 0:
         torch.cuda.set_device(local_rank)
         dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
@@ -304,6 +367,13 @@ def main():
     else:
         dbg("no cuda or invalid local_rank; not calling set_device")
 
+    if world_size > 1 and dist.is_available() and not dist.is_initialized():
+        backend = "nccl" if torch.cuda.is_available() else "gloo"
+        dbg(f"init_process_group backend={backend} via env://")
+        dist.init_process_group(backend=backend, init_method="env://")
+    else:
+        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
+
 
     if dist.is_available() and dist.is_initialized():
         try:
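Together with the block removed earlier, the net effect is that dist.init_process_group now runs only after torch.cuda.set_device(local_rank), so each rank is already pinned to its own GPU when the (typically NCCL) process group is created. A stripped-down sketch of the resulting order, with the dbg() logging and argument parsing omitted:

    import os
    import torch
    import torch.distributed as dist

    local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    # 1) pin this process to its GPU first
    if torch.cuda.is_available() and local_rank >= 0:
        torch.cuda.set_device(local_rank)

    # 2) only then create the process group
    if world_size > 1 and dist.is_available() and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")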
@@ -399,7 +469,11 @@ def main():
     local_ok = has_one_sample(probe_stream)
 
     if dist.is_available() and dist.is_initialized():
-        t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+        # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+        t = torch.tensor(
+            local_ok,
+            device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
+        )
         dist.all_reduce(t, op=dist.ReduceOp.MIN)
         if t.item() == 0:
             if is_main_process():
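Two things change in this probe: the old one-line tensor construction survives only as a comment, and the flag tensor is now pinned to cuda:{local_rank} rather than the bare "cuda" device, which only resolves to the right GPU if set_device has already been called on that rank. The all_reduce itself is a MIN over per-rank flags; a single-process analogue of what it computes:

    # Illustrative single-process analogue: ReduceOp.MIN over the per-rank flags
    # behaves like min() over the flags gathered from every rank.
    per_rank_has_sample = [1, 1, 0, 1]      # e.g. rank 2 received an empty probe stream
    everyone_ok = min(per_rank_has_sample)  # 0 -> some rank has no data, and all ranks see it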
@@ -533,14 +607,25 @@ def main():
         **ta_kwargs,
     )
 
-    trainer = Trainer(
+    trainer = DebugTrainer(
         model=model,
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        processing_class=tokenizer,
+        tokenizer=tokenizer,
+        # processing_class=tokenizer,
         data_collator=data_collator
     )
+
+    # trainer = Trainer(
+    #     model=model,
+    #     args=training_args,
+    #     train_dataset=train_stream,
+    #     eval_dataset=eval_dataset,
+    #     processing_class=tokenizer,
+    #     data_collator=data_collator
+    # )
+
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
 
     # No shared disk: check whether each node's local output_dir already contains checkpoint-*
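The constructor swap also trades processing_class= for tokenizer=, keeping the old keyword as a comment. This looks like a transformers version issue: recent releases accept processing_class (and deprecate tokenizer), while older ones only know tokenizer. A hedged sketch of picking whichever keyword the installed Trainer actually exposes; all names other than Trainer come from the surrounding script:

    import inspect
    from transformers import Trainer

    _kw = "processing_class" if "processing_class" in inspect.signature(Trainer.__init__).parameters else "tokenizer"
    trainer = DebugTrainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **{_kw: tokenizer},
    )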