This commit is contained in:
parent
9225b4b327
commit
fa4d4d8b1e
111
train_sft_ds.py
111
train_sft_ds.py
|
|
@ -31,6 +31,28 @@ def print_once(*args, **kwargs):
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(*args, **kwargs, flush=True)
|
print(*args, **kwargs, flush=True)
|
||||||
|
|
||||||
|
class DebugTrainer(Trainer):
    """Trainer subclass that prints a one-off description of the first batch.

    On the first call to ``training_step`` it prints, on every process, the
    device and shape of ``input_ids`` / ``attention_mask`` / ``labels`` and
    the number of label positions that are not ``-100``. Later calls go
    straight to the parent implementation.
    """

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Print first-batch diagnostics once, then run the normal step.

        Args:
            model: Forwarded unchanged to ``Trainer.training_step``.
            inputs: Batch mapping; must contain ``"input_ids"``,
                ``"attention_mask"`` and ``"labels"`` for the first call.
            num_items_in_batch: Forwarded to the parent only when the caller
                supplied it (see the note at the end of the method).

        Returns:
            Whatever ``Trainer.training_step`` returns.

        Raises:
            KeyError: On the first call if one of the three keys is missing.
        """
        if not hasattr(self, "_dbg_printed"):
            rank = int(os.environ.get("RANK", "0"))
            host = socket.gethostname()
            ids = inputs["input_ids"]
            msk = inputs["attention_mask"]
            labs = inputs["labels"]

            # Label positions equal to -100 are excluded from the count.
            supervised = (labs != -100).sum().item()
            print(
                f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
                f"supervised={supervised}",
                flush=True,
            )
            print(
                f"[step0][host={host} RANK={rank}] "
                f"input_ids.shape={tuple(ids.shape)} "
                f"attention_mask.shape={tuple(msk.shape)} "
                f"labels.shape={tuple(labs.shape)} "
                f"num_items_in_batch={num_items_in_batch}",
                flush=True,
            )
            self._dbg_printed = True

        # Transformers releases that predate the ``num_items_in_batch``
        # parameter define ``training_step(self, model, inputs)``; passing a
        # third positional argument there raises TypeError. Forward it only
        # when the training loop actually provided one.
        if num_items_in_batch is None:
            return super().training_step(model, inputs)
        return super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
|
|
||||||
# ----------------- 日志回调 -----------------
|
# ----------------- 日志回调 -----------------
|
||||||
class CsvLossLogger(TrainerCallback):
|
class CsvLossLogger(TrainerCallback):
|
||||||
|
|
@ -90,6 +112,16 @@ class QwenChatSFTDataset(IterableDataset):
|
||||||
self.seq_len = seq_len
|
self.seq_len = seq_len
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
||||||
|
|
||||||
|
# >>> DEBUG BEGIN
|
||||||
|
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
|
||||||
|
if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
|
||||||
|
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
|
||||||
|
rank = int(os.environ.get("RANK", "0"))
|
||||||
|
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
|
||||||
|
host = socket.gethostname()
|
||||||
|
# >>> DEBUG END
|
||||||
|
|
||||||
for ex in self.ex_iter:
|
for ex in self.ex_iter:
|
||||||
msgs = ex.get("messages", None)
|
msgs = ex.get("messages", None)
|
||||||
if not msgs or not isinstance(msgs, list):
|
if not msgs or not isinstance(msgs, list):
|
||||||
|
|
@ -168,6 +200,23 @@ class QwenChatSFTDataset(IterableDataset):
|
||||||
assert len(labels) == self.seq_len
|
assert len(labels) == self.seq_len
|
||||||
assert len(attn_mask) == self.seq_len
|
assert len(attn_mask) == self.seq_len
|
||||||
|
|
||||||
|
# >>> DEBUG PRINT(此时变量已定义)
|
||||||
|
if dbg_on and self._dbg_seen < dbg_limit:
|
||||||
|
sup_tok = sum(1 for v in labels if v != -100)
|
||||||
|
print(
|
||||||
|
f"[sample][host={host} RANK={rank} LRank={lrank}] "
|
||||||
|
f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
|
||||||
|
f"seq_len={self.seq_len} pad_id={pad_id}",
|
||||||
|
flush=True
|
||||||
|
)
|
||||||
|
if sup_tok == 0:
|
||||||
|
print(
|
||||||
|
f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
|
||||||
|
flush=True
|
||||||
|
)
|
||||||
|
self._dbg_seen += 1
|
||||||
|
# <<< DEBUG PRINT
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
||||||
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
|
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
|
||||||
|
|
@ -184,9 +233,9 @@ class SFTDataCollator:
|
||||||
assert self.tok.pad_token_id is not None
|
assert self.tok.pad_token_id is not None
|
||||||
|
|
||||||
def __call__(self, features):
|
def __call__(self, features):
|
||||||
if not features:
|
# if not features:
|
||||||
raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
|
# raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
|
||||||
f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
|
# f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
|
||||||
|
|
||||||
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
|
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
|
||||||
input_ids = [_to_list(f["input_ids"]) for f in features]
|
input_ids = [_to_list(f["input_ids"]) for f in features]
|
||||||
|
|
@ -206,6 +255,27 @@ class SFTDataCollator:
|
||||||
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
|
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
|
||||||
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
|
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
|
||||||
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
|
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
|
||||||
|
|
||||||
|
# >>> DEBUG BEGIN
|
||||||
|
dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
|
||||||
|
if dbg_on:
|
||||||
|
rank = int(os.environ.get("RANK", "0"))
|
||||||
|
host = socket.gethostname()
|
||||||
|
bs = len(features)
|
||||||
|
first_len = len(input_ids[0]) if bs > 0 else None
|
||||||
|
print(
|
||||||
|
f"[collate][host={host} RANK={rank}] features={bs} "
|
||||||
|
f"target_len={target_len} first_len={first_len}",
|
||||||
|
flush=True
|
||||||
|
)
|
||||||
|
# 额外严苛校验:防止空 batch 继续往下走
|
||||||
|
if not features:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
|
||||||
|
f"Check dataset sharding/streaming."
|
||||||
|
)
|
||||||
|
# >>> DEBUG END
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.stack(batch_inp, dim=0),
|
"input_ids": torch.stack(batch_inp, dim=0),
|
||||||
"attention_mask": torch.stack(batch_attn, dim=0),
|
"attention_mask": torch.stack(batch_attn, dim=0),
|
||||||
|
|
@ -290,13 +360,6 @@ def main():
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
|
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
|
||||||
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
|
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
|
||||||
|
|
||||||
if world_size > 1 and dist.is_available() and not dist.is_initialized():
|
|
||||||
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
|
||||||
dbg(f"init_process_group backend={backend} via env://")
|
|
||||||
dist.init_process_group(backend=backend, init_method="env://")
|
|
||||||
else:
|
|
||||||
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
|
|
||||||
|
|
||||||
if torch.cuda.is_available() and local_rank >= 0:
|
if torch.cuda.is_available() and local_rank >= 0:
|
||||||
torch.cuda.set_device(local_rank)
|
torch.cuda.set_device(local_rank)
|
||||||
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
|
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
|
||||||
|
|
@ -304,6 +367,13 @@ def main():
|
||||||
else:
|
else:
|
||||||
dbg("no cuda or invalid local_rank; not calling set_device")
|
dbg("no cuda or invalid local_rank; not calling set_device")
|
||||||
|
|
||||||
|
if world_size > 1 and dist.is_available() and not dist.is_initialized():
|
||||||
|
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
||||||
|
dbg(f"init_process_group backend={backend} via env://")
|
||||||
|
dist.init_process_group(backend=backend, init_method="env://")
|
||||||
|
else:
|
||||||
|
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
|
||||||
|
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
try:
|
try:
|
||||||
|
|
@ -399,7 +469,11 @@ def main():
|
||||||
local_ok = has_one_sample(probe_stream)
|
local_ok = has_one_sample(probe_stream)
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
|
# t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
|
||||||
|
t = torch.tensor(
|
||||||
|
local_ok,
|
||||||
|
device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
|
||||||
|
)
|
||||||
dist.all_reduce(t, op=dist.ReduceOp.MIN)
|
dist.all_reduce(t, op=dist.ReduceOp.MIN)
|
||||||
if t.item() == 0:
|
if t.item() == 0:
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
|
|
@ -533,14 +607,25 @@ def main():
|
||||||
**ta_kwargs,
|
**ta_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
|
trainer = DebugTrainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_stream,
|
train_dataset=train_stream,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
processing_class=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
# processing_class=tokenizer,
|
||||||
data_collator=data_collator
|
data_collator=data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# trainer = Trainer(
|
||||||
|
# model=model,
|
||||||
|
# args=training_args,
|
||||||
|
# train_dataset=train_stream,
|
||||||
|
# eval_dataset=eval_dataset,
|
||||||
|
# processing_class=tokenizer,
|
||||||
|
# data_collator=data_collator
|
||||||
|
# )
|
||||||
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
|
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
|
||||||
|
|
||||||
# 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-*
|
# 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-*
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue