#!/usr/bin/env python3 import os os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import glob import socket import argparse import inspect import sys from typing import Dict, List, Iterable, Iterator, Tuple, Optional import torch import torch.distributed as dist from torch.utils.data import IterableDataset, Dataset from datasets import load_dataset from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, set_seed ) from transformers.trainer_callback import TrainerCallback # ----------------- 进程工具 ----------------- def is_main_process(): return int(os.environ.get("RANK", "0")) == 0 def print_once(*args, **kwargs): if is_main_process(): print(*args, **kwargs, flush=True) class DebugTrainer(Trainer): def training_step(self, model, inputs, num_items_in_batch=None): if not hasattr(self, "_dbg_printed"): rank = int(os.environ.get("RANK", "0")) host = socket.gethostname() ids = inputs["input_ids"] msk = inputs["attention_mask"] labs = inputs["labels"] print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} " f"supervised={(labs!=-100).sum().item()}", flush=True) print( f"[step0][host={host} RANK={rank}] " f"input_ids.shape={tuple(ids.shape)} " f"attention_mask.shape={tuple(msk.shape)} " f"labels.shape={tuple(labs.shape)} " f"num_items_in_batch={num_items_in_batch}", flush=True ) self._dbg_printed = True return super().training_step(model, inputs, num_items_in_batch) # ----------------- 日志回调 ----------------- class CsvLossLogger(TrainerCallback): def __init__(self, csv_path: str): self.csv_path = csv_path if is_main_process(): os.makedirs(os.path.dirname(csv_path), exist_ok=True) with open(self.csv_path, "w", encoding="utf-8") as f: f.write("step,loss,lr,total_flos\n") def on_log(self, args, state, control, logs=None, **kwargs): if not is_main_process() or logs is None: return with open(self.csv_path, "a", encoding="utf-8") as f: f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n") # ----------------- 仅监督 assistant 的数据集 ----------------- def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: """ 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。 """ spans: List[Tuple[int, int]] = [] open_tag = "<|im_start|>assistant\n" close_tag = "<|im_end|>\n" pos = 0 while True: a = rendered.find(open_tag, pos) if a == -1: break start = a + len(open_tag) b = rendered.find(close_tag, start) if b == -1: break spans.append((start, b)) pos = b + len(close_tag) return spans class QwenChatSFTDataset(IterableDataset): """ 期望 jsonl 每行形如: {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} 可选包含工具: {"messages":[...], "tools":[{...}]} 工作流: - 使用 tokenizer.apply_chat_template 渲染 - 仅对 assistant 片段计损失(其他 token 的 label = -100) - 超长序列保留尾部(通常包含回答) """ def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096): self.ex_iter = ex_iter self.tok = tokenizer self.seq_len = seq_len def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: # >>> DEBUG BEGIN dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0 dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) rank = int(os.environ.get("RANK", "0")) lrank = int(os.environ.get("LOCAL_RANK", "-1")) host = socket.gethostname() # >>> DEBUG END for ex in self.ex_iter: msgs = ex.get("messages", None) if not msgs or not isinstance(msgs, list): continue # 可选过滤 think bad = False for m in msgs: if m.get("role") == "assistant" and isinstance(m.get("content"), str): c = m["content"] if "" in c and "" in c: inner = c.split("")[-1].split("")[0].strip() if inner: bad = True; break if bad: continue tools = ex.get("tools", None) rendered: str = self.tok.apply_chat_template( msgs, tools=tools, add_generation_prompt=False, tokenize=False ) if not isinstance(rendered, str) or not rendered.strip(): continue spans = _assistant_char_spans(rendered) if not spans: continue enc = self.tok( rendered, add_special_tokens=False, return_offsets_mapping=True ) input_ids: List[int] = enc["input_ids"] offsets: List[Tuple[int, int]] = enc["offset_mapping"] if not input_ids: continue labels = [-100] * len(input_ids) def in_any_span(lo: int, hi: int) -> bool: for s, e in spans: if not (hi <= s or lo >= e): return True return False for i, (lo, hi) in enumerate(offsets): if in_any_span(lo, hi): labels[i] = input_ids[i] # —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len —— # 1) 截断到 seq_len(保留尾部) if len(input_ids) > self.seq_len: input_ids = input_ids[-self.seq_len:] labels = labels[-self.seq_len:] # 2) 左侧补齐到 seq_len(保证所有样本长度一致) pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id L = len(input_ids) if L < self.seq_len: pad = self.seq_len - L input_ids = ([pad_id] * pad) + input_ids labels = ([-100] * pad) + labels attn_mask = [0] * pad + [1] * L else: # 恰好等于 seq_len attn_mask = [1] * self.seq_len # 若没有任何可训练 token(labels 全 -100),跳过 if all(v == -100 for v in labels): continue assert len(input_ids) == self.seq_len assert len(labels) == self.seq_len assert len(attn_mask) == self.seq_len # >>> DEBUG PRINT(此时变量已定义) if dbg_on and self._dbg_seen < dbg_limit: sup_tok = sum(1 for v in labels if v != -100) print( f"[sample][host={host} RANK={rank} LRank={lrank}] " f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} " f"seq_len={self.seq_len} pad_id={pad_id}", flush=True ) if sup_tok == 0: print( f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped", flush=True ) self._dbg_seen += 1 # <<< DEBUG PRINT yield { "input_ids": torch.tensor(input_ids, dtype=torch.long), "attention_mask": torch.tensor(attn_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } # ----------------- 专用 Collator:pad inputs, pad labels=-100 ----------------- class SFTDataCollator: def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None): self.tok = tokenizer self.pad_to_length = pad_to_length assert self.tok.pad_token_id is not None def __call__(self, features): # if not features: # raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. " # f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.") def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x) input_ids = [_to_list(f["input_ids"]) for f in features] attn_masks = [_to_list(f["attention_mask"]) for f in features] labels_list = [_to_list(f["labels"]) for f in features] max_len_in_batch = max(len(x) for x in input_ids) target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch pad_id = self.tok.pad_token_id batch_inp, batch_attn, batch_lab = [], [], [] for inp, msk, lab in zip(input_ids, attn_masks, labels_list): pad_len = target_len - len(inp) if pad_len < 0: inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len] pad_len = 0 batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long)) batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long)) batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long)) # >>> DEBUG BEGIN dbg_on = os.environ.get("DBG_COLLATE", "0") == "1" if dbg_on: rank = int(os.environ.get("RANK", "0")) host = socket.gethostname() bs = len(features) first_len = len(input_ids[0]) if bs > 0 else None print( f"[collate][host={host} RANK={rank}] features={bs} " f"target_len={target_len} first_len={first_len}", flush=True ) # 额外严苛校验:防止空 batch 继续往下走 if not features: raise RuntimeError( f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. " f"Check dataset sharding/streaming." ) # >>> DEBUG END return { "input_ids": torch.stack(batch_inp, dim=0), "attention_mask": torch.stack(batch_attn, dim=0), "labels": torch.stack(batch_lab, dim=0), } # ----------------- 参数 ----------------- def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--model_name_or_path", type=str, required=True, help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B)") ap.add_argument("--data_glob", type=str, required=True, help="本地 jsonl 通配符(每台机器都需有同路径数据;每行应含 messages/可选 tools)") ap.add_argument("--output_dir", type=str, required=True, help="本地输出目录(各节点各自本地写)") ap.add_argument("--seq_len", type=int, default=4096) ap.add_argument("--learning_rate", type=float, default=2e-5) ap.add_argument("--weight_decay", type=float, default=0.1) ap.add_argument("--warmup_ratio", type=float, default=0.02) ap.add_argument("--num_train_epochs", type=float, default=1.0) ap.add_argument("--max_steps", type=int, default=-1) ap.add_argument("--log_interval", type=int, default=10) ap.add_argument("--save_steps", type=int, default=500) ap.add_argument("--eval_ratio", type=float, default=0.0) ap.add_argument("--seed", type=int, default=1337) ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json") ap.add_argument("--gradient_checkpointing", action="store_true") ap.add_argument("--bf16", action="store_true", help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开") ap.add_argument("--per_device_train_batch_size", type=int, default=1) ap.add_argument("--gradient_accumulation_steps", type=int, default=64) ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"]) ap.add_argument("--wandb_project", type=str, default="ds-qwen3") ap.add_argument("--eval_data_glob", type=str, default=None, help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用") ap.add_argument("--local_rank", type=int, default=-1, help="for deepspeed/torchrun launcher; ignored by user code") ap.add_argument("--per_device_eval_batch_size", type=int, default=1) return ap.parse_args() # ----------------- 主函数 ----------------- def main(): args = parse_args() set_seed(args.seed) # -------- 调试打印工具(每个 rank 都打)-------- host = socket.gethostname() def dbg(msg): print( f"[dbg][host={host} RANK={os.environ.get('RANK','0')} " f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True ) # 版本 & 启动参数 & 关键环境变量 import transformers as hf try: import deepspeed as ds ds_ver = ds.__version__ except Exception: ds_ver = "n/a" dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}") dbg(f"args={args}") dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % ( os.environ.get("WORLD_SIZE"), os.environ.get("RANK"), os.environ.get("LOCAL_RANK", str(args.local_rank)), os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"), os.environ.get("CUDA_VISIBLE_DEVICES"), )) dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}") # ---- 初始化分布式(供一致性探针使用)---- world_size = int(os.environ.get("WORLD_SIZE", "1")) rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank))) dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}") if torch.cuda.is_available() and local_rank >= 0: torch.cuda.set_device(local_rank) dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} " f"name={torch.cuda.get_device_name(torch.cuda.current_device())}") else: dbg("no cuda or invalid local_rank; not calling set_device") if world_size > 1 and dist.is_available() and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dbg(f"init_process_group backend={backend} via env://") dist.init_process_group(backend=backend, init_method="env://") else: dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}") if dist.is_available() and dist.is_initialized(): try: dbg(f"dist.get_backend()={dist.get_backend()} " f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}") except Exception as e: dbg(f"dist query error: {e}") # 1) 先补 tokenizer 的 pad tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.model_max_length = args.seq_len dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} " f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}") # 2) 再加载模型 model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16), low_cpu_mem_usage=True, trust_remote_code=True ) dbg(f"model loaded: dtype={next(model.parameters()).dtype} " f"use_cache={getattr(model.config,'use_cache',None)} " f"pad_token_id={getattr(model.config,'pad_token_id',None)}") # 3) pad/alibi 等配置 model.config.pad_token_id = tokenizer.pad_token_id model.config.use_cache = False if args.gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) try: torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_math_sdp(True) except Exception: pass # ===== 数据鲁棒性检查(多机各自执行)===== host = socket.gethostname() files = sorted(glob.glob(args.data_glob)) if len(files) == 0: raise FileNotFoundError( f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n" "每台机器都必须在相同本地路径下放置数据;" "可通过 DATA_GLOB= ./run_ds.sh 覆写。" ) if is_main_process(): print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True) # ====== 小探针:样本结构 ====== ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def ex_iter_probe(): for ex in ds_stream_probe: yield ex train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len) try: _ = next(iter(train_stream_probe)) except StopIteration: raise RuntimeError( f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n" "请确认每行 JSON 至少包含 'messages'(列表,含 user/assistant)字段;" "若含 请确保不包含真实思维文本,或移除。\n" "另外检查 seq_len 是否过小导致全部被裁。" ) # ====== 正式训练流 + 模数分片(不要求样本数整除 world_size) ====== ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def ex_iter2(): for i, ex in enumerate(ds_stream2): if i % max(world_size, 1) == rank: yield ex train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len) # ====== 一致性探针:任意 rank 无样本 -> 全体退出 ====== def has_one_sample(stream): it = iter(stream) try: next(it); return 1 except StopIteration: return 0 ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def ex_iter2_probe(): for i, ex in enumerate(ds_stream_probe2): if i % max(world_size, 1) == rank: yield ex probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len) local_ok = has_one_sample(probe_stream) if dist.is_available() and dist.is_initialized(): # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu")) t = torch.tensor( local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu") ) dist.all_reduce(t, op=dist.ReduceOp.MIN) if t.item() == 0: if is_main_process(): print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True) dist.barrier() sys.exit(2) else: if local_ok == 0: print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2) # ---- Eval 构造:优先使用 --eval_data_glob;否则才用 eval_ratio 抽样 ---- eval_dataset: Optional[Dataset] = None class ListDataset(Dataset): def __init__(self, items): self.items = items def __len__(self): return len(self.items) def __getitem__(self, idx): return self.items[idx] if args.eval_data_glob: eval_files = sorted(glob.glob(args.eval_data_glob)) if len(eval_files) == 0: raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}") if is_main_process(): print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True) ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True) def ex_iter_eval(): for ex in ds_eval_stream: yield ex eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len) eval_items: List[Dict[str, torch.Tensor]] = [] for sample in eval_iterable: eval_items.append(sample) if len(eval_items) == 0: raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。") eval_dataset = ListDataset(eval_items) elif args.eval_ratio and args.eval_ratio > 0: desired_eval_batches = 200 tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def ex_iter_eval2(): for ex in tmp_stream: yield ex eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len) eval_samples = [] it = iter(eval_stream) for _ in range(desired_eval_batches): try: eval_samples.append(next(it)) except StopIteration: break if len(eval_samples) > 0: eval_dataset = ListDataset(eval_samples) # ---- 统一补齐 eval 集(确保不会出现空 batch)---- if eval_dataset is not None: ws = max(world_size, 1) be = max(1, args.per_device_eval_batch_size) global_bs = ws * be r = len(eval_dataset) % global_bs if r != 0: need = global_bs - r # 你的 eval_dataset 是上面自定义的 ListDataset,带 .items eval_dataset.items += eval_dataset.items[:need] if is_main_process(): print(f"[eval] padded eval set to {len(eval_dataset)} " f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})", flush=True) # 补齐后再做 sanity check assert len(eval_dataset) % global_bs == 0, \ f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}" # 更稳:联调阶段不强行 pad 到 4096 # data_collator = SFTDataCollator(tokenizer, pad_to_length=None) data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len) os.makedirs(args.output_dir, exist_ok=True) logging_dir = os.path.join(args.output_dir, "logs") os.makedirs(logging_dir, exist_ok=True) # ---- 兼容 4.51(eval_strategy)与旧版(evaluation_strategy) ---- ta_kwargs = {} sig = inspect.signature(TrainingArguments.__init__).parameters if eval_dataset is not None: if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "steps" elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps" else: if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no" elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no" training_args = TrainingArguments( output_dir=args.output_dir, logging_dir=logging_dir, do_train=True, do_eval=(eval_dataset is not None), eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, weight_decay=args.weight_decay, warmup_ratio=args.warmup_ratio, num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0, max_steps=args.max_steps if args.max_steps > 0 else -1, lr_scheduler_type="cosine", logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, deepspeed=args.deepspeed, dataloader_drop_last=False, # 关键:别丢尾,避免空 batch dataloader_num_workers=0, dataloader_prefetch_factor=None, dataloader_pin_memory=False, per_device_eval_batch_size=args.per_device_eval_batch_size, report_to=([] if args.report_to == "none" else [args.report_to]), bf16=args.bf16, fp16=(not args.bf16), gradient_checkpointing=args.gradient_checkpointing, remove_unused_columns=False, torch_compile=False, save_on_each_node=False, logging_first_step=True, **ta_kwargs, ) trainer = DebugTrainer( model=model, args=training_args, train_dataset=train_stream, eval_dataset=eval_dataset, tokenizer=tokenizer, # processing_class=tokenizer, data_collator=data_collator ) # trainer = Trainer( # model=model, # args=training_args, # train_dataset=train_stream, # eval_dataset=eval_dataset, # processing_class=tokenizer, # data_collator=data_collator # ) trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv"))) # 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-* ckpt_exists = (os.path.isdir(args.output_dir) and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) resume_flag = True if ckpt_exists else None print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}") print_once("***** Starting training *****") train_result = trainer.train(resume_from_checkpoint=resume_flag) trainer.save_model() # DeepSpeed stage3_gather_16bit_weights_on_model_save=true 时,在 rank0 聚合整模型 metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() if eval_dataset is not None: print_once("***** Running eval *****") eval_metrics = trainer.evaluate() trainer.log_metrics("eval", eval_metrics) trainer.save_metrics("eval", eval_metrics) print_once("Done.") if __name__ == "__main__": main()