#!/usr/bin/env python3
"""Assistant-only supervised fine-tuning (SFT) for Qwen-style (ChatML) chat models.

Local jsonl files are read as a ``datasets`` stream. Each sample is rendered with
the model's own chat template and trained with HF ``Trainer`` + DeepSpeed. Only
tokens inside assistant turns (including the closing ``<|im_end|>``) carry
labels; every other token gets -100.
"""
import argparse
import glob
import inspect
import itertools
import os
import socket
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import torch
from torch.utils.data import Dataset, IterableDataset
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback

# ChatML markers emitted by the Qwen chat template.
_IM_START_ASSISTANT = "<|im_start|>assistant\n"
_IM_END = "<|im_end|>"

# Reasoning-block tags. They are assembled from pieces because a previous copy of
# this file had the bare tag literals stripped to "", which made the filter call
# str.split("") and raise ValueError on every assistant message.
_THINK_OPEN = "<" + "think>"
_THINK_CLOSE = "</" + "think>"

# Number of leading training samples reused as a rough eval set (--eval_ratio path).
_EVAL_HEAD_SAMPLES = 200


# ----------------- process helpers -----------------
def is_main_process() -> bool:
    """Return True on global rank 0 (taken from the RANK env var, default "0")."""
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs) -> None:
    """``print`` only on the main process; flushes unless the caller says otherwise."""
    if is_main_process():
        kwargs.setdefault("flush", True)
        print(*args, **kwargs)


# ----------------- logging callback -----------------
class CsvLossLogger(TrainerCallback):
    """Write every Trainer log event as a ``step,loss,lr,total_flos`` CSV row.

    Only the main process touches the file. Constructing the callback recreates
    the file with a fresh header; missing log keys are written as empty cells.
    """

    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            parent = os.path.dirname(csv_path)
            if parent:  # os.makedirs("") raises for a bare filename
                os.makedirs(parent, exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not is_main_process() or logs is None:
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(
                f"{state.global_step},{logs.get('loss','')},"
                f"{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
            )


# ----------------- assistant-only dataset -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """Return the ``[start, end)`` character spans of all assistant turns in *rendered*.

    A span starts right after the ``<|im_start|>assistant`` header line. It ends
    right after the matching ``<|im_end|>``, so the end-of-turn marker is
    supervised too; otherwise the fine-tuned model is never taught to stop. An
    assistant header with no closing marker ends the scan.
    """
    spans: List[Tuple[int, int]] = []
    pos = 0
    while True:
        header = rendered.find(_IM_START_ASSISTANT, pos)
        if header == -1:
            break
        start = header + len(_IM_START_ASSISTANT)
        close = rendered.find(_IM_END, start)
        if close == -1:
            break
        end = close + len(_IM_END)
        spans.append((start, end))
        pos = end
    return spans


class QwenChatSFTDataset(IterableDataset):
    """Stream tokenized SFT samples in which only assistant tokens carry labels.

    Expected shape of each jsonl record::

        {"messages": [{"role": "system", "content": "..."},
                      {"role": "user", "content": "..."},
                      {"role": "assistant", "content": "..."}]}

    optionally with ``"tools": [{...}]``.

    Per record:
      - render with ``tokenizer.apply_chat_template`` (the model's own template);
      - label only tokens overlapping an assistant span; all others get -100;
      - when longer than ``seq_len``, keep the tail (which usually holds the answer);
      - skip records without a ``messages`` list, records whose assistant
        messages contain a non-empty reasoning block, and records left with no
        supervised token.

    ``ex_iter`` is iterated once per ``__iter__`` call. Pass a re-iterable
    source (e.g. a ``datasets`` streaming dataset) if the dataset may be
    iterated more than once; a plain generator is exhausted after one pass.
    """

    def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096):
        self.ex_iter = ex_iter
        self.tok = tokenizer
        self.seq_len = seq_len

    @staticmethod
    def _has_reasoning(msgs: List[dict]) -> bool:
        """True if the last reasoning block of some assistant message is non-empty."""
        for m in msgs:
            if m.get("role") != "assistant":
                continue
            content = m.get("content")
            if isinstance(content, str) and _THINK_OPEN in content and _THINK_CLOSE in content:
                inner = content.split(_THINK_OPEN)[-1].split(_THINK_CLOSE)[0].strip()
                if inner:
                    return True
        return False

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        for ex in self.ex_iter:
            msgs = ex.get("messages", None)
            # Strictly require a list of message dicts; legacy "text" records are skipped.
            if not msgs or not isinstance(msgs, list) or not all(isinstance(m, dict) for m in msgs):
                continue
            # Optional filter: avoid training on real chain-of-thought text.
            if self._has_reasoning(msgs):
                continue

            # 1) Render with the model's own template (never hand-written).
            rendered = self.tok.apply_chat_template(
                msgs,
                tools=ex.get("tools", None),
                add_generation_prompt=False,  # the assistant answer is part of the sample
                tokenize=False,
            )
            if not isinstance(rendered, str) or not rendered.strip():
                continue

            # 2) Character spans of the assistant turns.
            spans = _assistant_char_spans(rendered)
            if not spans:
                continue

            # 3) Tokenize with offsets so tokens can be mapped back to characters.
            enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True)
            input_ids: List[int] = list(enc["input_ids"])
            offsets: List[Tuple[int, int]] = enc["offset_mapping"]

            # 4) Supervise a token iff its [lo, hi) character range overlaps a span.
            labels = [
                tok_id if any(lo < end and hi > start for start, end in spans) else -100
                for tok_id, (lo, hi) in zip(input_ids, offsets)
            ]

            # 5) Truncate over-long samples, keeping the tail.
            if len(input_ids) > self.seq_len:
                input_ids = input_ids[-self.seq_len:]
                labels = labels[-self.seq_len:]

            # Causal-LM loss shifts labels by one, so labels[0] is never used; a
            # sample with nothing supervised after it only adds noise (or NaN loss).
            if all(lab == -100 for lab in labels[1:]):
                continue

            yield {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
            }


class _ListDataset(Dataset):
    """Map-style dataset over an in-memory list of already-tokenized samples."""

    def __init__(self, items: List[Dict[str, torch.Tensor]]):
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return self.items[idx]


# ----------------- collator: pad inputs, pad labels with -100 -----------------
class SFTDataCollator:
    """Right-pad variable-length samples to the longest one in the batch.

    ``input_ids`` are padded with ``tokenizer.pad_token_id``, ``attention_mask``
    with 0, and ``labels`` with -100, so padding never contributes to the loss.
    """

    def __init__(self, tokenizer: AutoTokenizer):
        self.tok = tokenizer
        # Raise rather than assert: asserts are stripped under `python -O`.
        if self.tok.pad_token_id is None:
            raise ValueError("tokenizer.pad_token 不能为空;已在主函数里兜底为 eos_token")

    @staticmethod
    def _to_list(x) -> List[int]:
        return x.tolist() if isinstance(x, torch.Tensor) else list(x)

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        rows = [
            (self._to_list(f["input_ids"]), self._to_list(f["attention_mask"]), self._to_list(f["labels"]))
            for f in features
        ]
        max_len = max(len(inp) for inp, _, _ in rows)
        pad_id = self.tok.pad_token_id

        batch_inp, batch_attn, batch_lab = [], [], []
        for inp, msk, lab in rows:
            pad_len = max_len - len(inp)
            batch_inp.append(torch.tensor(inp + [pad_id] * pad_len, dtype=torch.long))
            batch_attn.append(torch.tensor(msk + [0] * pad_len, dtype=torch.long))
            batch_lab.append(torch.tensor(lab + [-100] * pad_len, dtype=torch.long))

        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
            "labels": torch.stack(batch_lab, dim=0),
        }


# ----------------- arguments -----------------
def parse_args():
    """Parse CLI arguments.

    ``--max_steps`` must be positive: the training set is a streaming
    IterableDataset without ``__len__``, and Trainer refuses to start in that
    case. It is checked here so the run fails before the model is loaded.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="本地 jsonl 通配符(每台机器都需有同路径数据;每行应含 messages/可选 tools)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="本地输出目录(各节点各自本地写)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)  # > 0 enables the fallback head-sample eval
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
                    help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard",
                    choices=["none","tensorboard","wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    ap.add_argument("--eval_data_glob", type=str, default=None,
                    help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
    ap.add_argument("--local_rank", type=int, default=-1,
                    help="for deepspeed/torchrun launcher; ignored by user code")
    args = ap.parse_args()
    if args.max_steps <= 0:
        ap.error(
            "--max_steps must be a positive integer: the training set is a streaming "
            "IterableDataset without __len__, so the step count cannot be derived "
            "from --num_train_epochs"
        )
    return args


# ----------------- helpers for main -----------------
def _json_stream(files: List[str], split: str = "train"):
    """Open *files* as a streaming ``datasets`` jsonl dataset under *split*.

    Unlike a plain generator, the returned object can be iterated more than once.
    """
    return load_dataset("json", data_files={split: files}, split=split, streaming=True)


def _deepspeed_zero3_enabled() -> bool:
    """Whether transformers currently sees an active DeepSpeed ZeRO-3 config.

    The import location moved between transformers releases, so several paths
    are tried; if none exists, ZeRO-3 detection is reported as disabled.
    """
    try:
        from transformers.integrations import is_deepspeed_zero3_enabled
    except ImportError:
        try:
            from transformers.deepspeed import is_deepspeed_zero3_enabled
        except ImportError:
            return False
    return is_deepspeed_zero3_enabled()


def _build_eval_dataset(args, train_files: List[str], tokenizer, host: str, rank: int) -> Optional[Dataset]:
    """Materialize the eval set in memory.

    Uses ``--eval_data_glob`` when given. Otherwise, when ``--eval_ratio > 0``,
    the first samples of the training stream are reused. Returns None when
    neither applies.
    """
    if args.eval_data_glob:
        eval_files = sorted(glob.glob(args.eval_data_glob))
        if not eval_files:
            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
        print_once(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}")
        eval_items = list(QwenChatSFTDataset(_json_stream(eval_files, "eval"), tokenizer, seq_len=args.seq_len))
        if not eval_items:
            raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
        return _ListDataset(eval_items)

    if args.eval_ratio and args.eval_ratio > 0:
        # Rough signal only: these samples come from the head of the *training*
        # stream, so they are not held out. The value of --eval_ratio itself
        # only switches this path on.
        head_stream = QwenChatSFTDataset(_json_stream(train_files), tokenizer, seq_len=args.seq_len)
        head = list(itertools.islice(head_stream, _EVAL_HEAD_SAMPLES))
        if head:
            return _ListDataset(head)
    return None


def _build_training_args(args, logging_dir: str, has_eval: bool) -> TrainingArguments:
    """Build TrainingArguments, handling the eval_strategy / evaluation_strategy rename."""
    sig = inspect.signature(TrainingArguments.__init__).parameters
    strategy = "steps" if has_eval else "no"
    ta_kwargs = {}
    if "eval_strategy" in sig:  # newer transformers (e.g. 4.51)
        ta_kwargs["eval_strategy"] = strategy
    elif "evaluation_strategy" in sig:  # older transformers
        ta_kwargs["evaluation_strategy"] = strategy

    return TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        do_train=True,
        do_eval=has_eval,
        eval_steps=max(50, args.save_steps // 5) if has_eval else None,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,  # ignored by Trainer since max_steps > 0
        max_steps=args.max_steps,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        deepspeed=args.deepspeed,
        dataloader_drop_last=True,
        dataloader_num_workers=2,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        bf16=args.bf16,
        fp16=(not args.bf16),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,  # keep our custom fields
        torch_compile=False,
        save_on_each_node=False,
        **ta_kwargs,
    )


# ----------------- main -----------------
def main():
    """Validate data, build datasets and TrainingArguments, load the model, train, save, evaluate."""
    args = parse_args()
    set_seed(args.seed)

    if args.report_to == "wandb":
        # --wandb_project was previously parsed but never applied. The W&B
        # integration reads the project from this env var; an explicitly
        # exported WANDB_PROJECT still wins.
        os.environ.setdefault("WANDB_PROJECT", args.wandb_project)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # needed for padding

    # ===== data sanity checks (each node runs them on its local copy) =====
    host = socket.gethostname()
    rank = int(os.environ.get("RANK", "0"))
    files = sorted(glob.glob(args.data_glob))
    if not files:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "每台机器都必须在相同本地路径下放置数据;"
            "可通过 DATA_GLOB= ./run_ds.sh 覆写。"
        )
    print_once(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}")

    # Probe: make sure the data yields at least one trainable sample. A
    # separate stream is used so the training stream starts from the beginning.
    probe = QwenChatSFTDataset(_json_stream(files), tokenizer, seq_len=args.seq_len)
    if next(iter(probe), None) is None:
        raise RuntimeError(
            f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
            "请确认每行 JSON 至少包含 'messages'(列表,含 user/assistant)字段;"
            f"若含 {_THINK_OPEN} 请确保不包含真实思维文本,或移除。\n"
            "另外检查 seq_len 是否过小导致全部被裁。"
        )

    # The datasets stream is passed directly (not via a generator) so the
    # training dataset stays re-iterable.
    train_stream = QwenChatSFTDataset(_json_stream(files), tokenizer, seq_len=args.seq_len)
    eval_dataset: Optional[Dataset] = _build_eval_dataset(args, files, tokenizer, host, rank)
    data_collator = SFTDataCollator(tokenizer)

    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")
    os.makedirs(logging_dir, exist_ok=True)

    # TrainingArguments must exist BEFORE from_pretrained. Constructing it with a
    # ZeRO-3 DeepSpeed config is what lets from_pretrained load the model
    # already partitioned (zero.Init), instead of materializing it on every rank.
    training_args = _build_training_args(args, logging_dir, has_eval=eval_dataset is not None)

    zero3 = _deepspeed_zero3_enabled()
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
        # transformers rejects low_cpu_mem_usage=True when ZeRO-3 is active.
        low_cpu_mem_usage=not zero3,
        trust_remote_code=True,
    )
    model.config.use_cache = False  # the KV cache is useless (and conflicts with checkpointing) in training
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Newer transformers renamed Trainer(tokenizer=...) to processing_class=...
    trainer_sig = inspect.signature(Trainer.__init__).parameters
    tok_kwarg = "processing_class" if "processing_class" in trainer_sig else "tokenizer"
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **{tok_kwarg: tokenizer},
    )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # No shared disk: each node checks for checkpoint-* under its OWN local output_dir.
    # NOTE(review): with save_on_each_node=False only the main process writes the
    # non-DeepSpeed checkpoint files (e.g. trainer_state.json). Auto-resume on other
    # nodes may therefore fail without shared storage — confirm for your deployment.
    ckpt_exists = os.path.isdir(args.output_dir) and any(
        name.startswith("checkpoint-") for name in os.listdir(args.output_dir)
    )
    resume_flag = True if ckpt_exists else None
    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")

    print_once("***** Starting training *****")
    train_result = trainer.train(resume_from_checkpoint=resume_flag)
    # With DeepSpeed stage3_gather_16bit_weights_on_model_save=true the full
    # model is gathered for saving.
    trainer.save_model()

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if eval_dataset is not None:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    print_once("Done.")


if __name__ == "__main__":
    main()