#!/usr/bin/env python3
"""LoRA supervised fine-tuning of a Qwen-style chat model.

The script streams JSON(L) chat examples, supervises only assistant content,
injects LoRA adapters into every linear layer (minus an exclude list) and
trains with the Hugging Face ``Trainer`` (optionally through DeepSpeed).

NOTE: the environment variables below are set *before* the heavy imports on
purpose, so that the imported libraries observe them.
"""
import os
os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")

import glob
import socket
import argparse
import inspect
import itertools
import sys
from collections.abc import Mapping
from typing import Dict, List, Iterable, Iterator, Tuple, Optional

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW
from transformers import EarlyStoppingCallback

# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
import site, shutil

home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
# Prepend the wanted directories that are not already on PATH.
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)

os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
    os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)

os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)

# Best effort: make the per-user site-packages importable.
try:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.insert(0, user_site)
except Exception:
    pass

os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
if shutil.which("ninja") is None:
    os.environ["USE_NINJA"] = "0"
    print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)

# Pre-build the DeepSpeed CPU Adam op so that a JIT failure surfaces before training.
try:
    from deepspeed.ops.op_builder import CPUAdamBuilder
    CPUAdamBuilder().load()
    print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
    if "Ninja is required to load C++ extensions" in str(e):
        os.environ["USE_NINJA"] = "0"
        print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
        from deepspeed.ops.op_builder import CPUAdamBuilder
        CPUAdamBuilder().load()
        print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
    else:
        print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
        raise


# ----------------- Process helpers -----------------
def is_main_process():
    """Return True when the ``RANK`` environment variable is 0 (or unset)."""
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs):
    """Print (flushed) on the main process only."""
    if is_main_process():
        print(*args, **kwargs, flush=True)


class DebugTrainer(Trainer):
    """``Trainer`` that prints device/shape information for the first training batch."""

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Log the first batch once, then delegate to ``Trainer.training_step``.

        ``num_items_in_batch`` is forwarded only when the parent method accepts
        it. The decision is made by inspecting the parent signature rather than
        by catching ``TypeError`` around the call, so a genuine ``TypeError``
        raised inside the step is never masked by a silent second attempt.
        """
        if not hasattr(self, "_dbg_printed"):
            rank = int(os.environ.get("RANK", "0"))
            host = socket.gethostname()
            ids = inputs["input_ids"]
            msk = inputs["attention_mask"]
            labs = inputs["labels"]
            print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
                  f"supervised={(labs!=-100).sum().item()}", flush=True)
            print(
                f"[step0][host={host} RANK={rank}] "
                f"input_ids.shape={tuple(ids.shape)} "
                f"attention_mask.shape={tuple(msk.shape)} "
                f"labels.shape={tuple(labs.shape)} "
                f"num_items_in_batch={num_items_in_batch}",
                flush=True
            )
            self._dbg_printed = True

        parent_step = super().training_step
        if not hasattr(self, "_parent_takes_num_items"):
            # Cache the signature probe; it cannot change during a run.
            self._parent_takes_num_items = (
                "num_items_in_batch" in inspect.signature(parent_step).parameters
            )
        if self._parent_takes_num_items:
            return parent_step(model, inputs, num_items_in_batch=num_items_in_batch)
        return parent_step(model, inputs)


# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
    """Print progress on every rank and append loss rows to a CSV on the main process."""

    def __init__(self, csv_path: str):
        """Create (truncate) ``csv_path`` with a header row on the main process."""
        self.csv_path = csv_path
        if is_main_process():
            # ``dirname`` is "" for a bare filename; fall back to the current directory.
            os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_train_begin(self, args, state, control, **kwargs):
        """Print the total number of optimisation steps (0 when unknown)."""
        tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
        tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
        rank = os.environ.get("RANK", "?")
        host = socket.gethostname()
        print(f"[{host} rank={rank}] total_steps={tot}", flush=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print step/loss/lr on every rank; append a CSV row on the main process."""
        if logs is None:
            return
        cur = int(getattr(state, "global_step", 0) or 0)
        tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
        tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
        pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
        if tot and not hasattr(self, "_tot_announced"):
            print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
            self._tot_announced = True
        rank = os.environ.get("RANK", "?")
        host = socket.gethostname()
        print(
            f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
            f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
            flush=True
        )
        if not is_main_process():
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")


# ----------------- Supervise assistant content only (token-id level) -----------------
class QwenChatSFTDataset(IterableDataset):
    """Stream chat examples into fixed-length, left-padded SFT samples.

    Each example must carry a ``messages`` list (and optionally ``tools``).
    Only tokens between ``<|im_start|>assistant`` and the following
    ``<|im_end|>`` are supervised; when ``mask_think_and_tags`` is set,
    ``<think>...</think>`` spans inside assistant turns are excluded again.
    """

    def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer,
                 seq_len: int = 4096, mask_think_and_tags: bool = True):
        """Resolve the special-token ids used to locate assistant turns.

        Raises:
            RuntimeError: if the tokenizer yields no ids for the assistant role
                or lacks ``<|im_start|>`` / ``<|im_end|>``.
        """
        self.ex_iter = ex_iter
        self.tok = tokenizer
        self.seq_len = seq_len
        self.mask_think_and_tags = mask_think_and_tags

        self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>")
        self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>")

        # The longer candidate ("assistant\n") is tried first so the newline is
        # not supervised when it tokenizes as expected.
        self.ids_ASSISTANT_CANDIDATES = [
            self.tok.encode("assistant\n", add_special_tokens=False),
            self.tok.encode("assistant", add_special_tokens=False),
        ]
        self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0]
        if not self.ids_ASSISTANT_CANDIDATES:
            raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.")

        # The tag literals must be non-empty: an empty needle makes the
        # think-masking loop in __iter__ never advance.
        self.ids_THINK_OPEN = self.tok.encode("<think>", add_special_tokens=False)
        self.ids_THINK_CLOSE = self.tok.encode("</think>", add_special_tokens=False)

        for name, val in {"id_START": self.id_START, "id_END": self.id_END}.items():
            if val is None or val == self.tok.unk_token_id:
                raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}")

    @staticmethod
    def _find_subseq(hay: list, needle: list, start: int) -> int:
        """Return the first index >= ``start`` where ``needle`` occurs in ``hay``, or -1.

        An empty ``needle`` matches immediately at ``start``.
        """
        n = len(needle)
        if n == 0:
            return start
        for i in range(start, len(hay) - n + 1):
            if hay[i:i+n] == needle:
                return i
        return -1

    def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]:
        """Return ``(j_start, role_len)`` if an assistant role sequence begins at ``j_start``.

        Candidates are tested in order with a direct slice comparison; ``None``
        is returned when none of them matches exactly at ``j_start``.
        """
        for cand in self.ids_ASSISTANT_CANDIDATES:
            if ids[j_start:j_start + len(cand)] == cand:
                return (j_start, len(cand))
        return None

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield dicts of ``input_ids`` / ``attention_mask`` / ``labels`` tensors.

        Examples without a ``messages`` list, with an empty tokenization, or
        with no supervised token are skipped. Set ``DBG_SAMPLES=1`` to print a
        summary for the first ``DBG_SAMPLE_LIMIT`` samples.
        """
        dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
        dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
        seen = 0
        host = socket.gethostname()
        rank = int(os.environ.get("RANK", "0"))
        lrank = int(os.environ.get("LOCAL_RANK", "-1"))

        # Accept either a factory (callable) or a plain iterable.
        it = self.ex_iter() if callable(self.ex_iter) else iter(self.ex_iter)
        # Skip think-masking entirely if either tag failed to tokenize.
        mask_think = bool(
            self.mask_think_and_tags and self.ids_THINK_OPEN and self.ids_THINK_CLOSE
        )

        for ex in it:
            msgs = ex.get("messages")
            if not msgs or not isinstance(msgs, list):
                continue
            tools = ex.get("tools", None)

            try:
                ids = self.tok.apply_chat_template(
                    msgs, tools=tools, add_generation_prompt=False,
                    tokenize=True, return_tensors=None
                )
                # Some versions return a dict-like encoding (e.g. BatchEncoding,
                # which is a UserDict and not a dict subclass).
                if isinstance(ids, Mapping):
                    ids = ids["input_ids"]
            except TypeError:
                # Tokenizers whose template call does not accept ``tools``.
                rendered: str = self.tok.apply_chat_template(
                    msgs, add_generation_prompt=False, tokenize=False
                )
                ids = self.tok(rendered, add_special_tokens=False)["input_ids"]

            if not ids:
                continue

            # mask[t] == 1 marks a supervised token.
            mask = [0] * len(ids)
            i = 0
            while i < len(ids):
                try:
                    a = ids.index(self.id_START, i)
                except ValueError:
                    break
                j = a + 1
                role_match = self._find_role_after_start(ids, j)
                if role_match is None:
                    # Not an assistant turn; resume the scan after this start marker.
                    i = a + 1
                    continue
                _, role_len = role_match
                content_lo = j + role_len
                try:
                    b = ids.index(self.id_END, content_lo)
                except ValueError:
                    i = a + 1
                    continue
                content_hi = b

                # NOTE(review): the closing <|im_end|> token itself is not
                # supervised here - confirm that this is intended.
                for t in range(content_lo, content_hi):
                    mask[t] = 1

                if mask_think:
                    p = content_lo
                    while True:
                        o = self._find_subseq(ids, self.ids_THINK_OPEN, p)
                        if o == -1 or o >= content_hi:
                            break
                        c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN))
                        if c == -1 or c > content_hi:
                            break
                        x_lo = o
                        x_hi = c + len(self.ids_THINK_CLOSE)
                        for t in range(x_lo, min(x_hi, content_hi)):
                            mask[t] = 0
                        p = x_hi

                i = b + 1

            if not any(mask):
                continue

            if len(ids) > self.seq_len:
                # Keep the window that ends at the last supervised token.
                last_on = max(idx for idx, v in enumerate(mask) if v == 1)
                end = min(len(ids), last_on + 1)
                start = max(0, end - self.seq_len)
                ids = ids[start:end]
                mask = mask[start:end]

            pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
            L = len(ids)
            if L < self.seq_len:
                # Left-pad up to seq_len; padding is never attended nor supervised.
                pad = self.seq_len - L
                input_ids = [pad_id] * pad + ids
                attention_mask = [0] * pad + [1] * L
                labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
            else:
                input_ids = ids
                attention_mask = [1] * self.seq_len
                labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)]

            if dbg_on and seen < dbg_limit:
                sup_tok = sum(1 for v in labels if v != -100)
                print(
                    f"[sample][host={host} RANK={rank} LRank={lrank}] "
                    f"toks={len(input_ids)} sup_toks={sup_tok} "
                    f"seq_len={self.seq_len} pad_id={pad_id}",
                    flush=True
                )
                seen += 1

            yield {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
            }


# ----------------- Collator -----------------
class SFTDataCollator:
    """Right-pad (or truncate) features to a common length and stack them."""

    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
        """Store the tokenizer and the optional fixed target length."""
        self.tok = tokenizer
        self.pad_to_length = pad_to_length
        assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set"

    def __call__(self, features):
        """Collate ``features`` into batched ``input_ids`` / ``attention_mask`` / ``labels``.

        The target length is ``pad_to_length`` when set, otherwise the longest
        ``input_ids`` in the batch. Longer rows are truncated on the right.

        Raises:
            RuntimeError: on an empty batch, or when no row has a supervised label.
        """
        if not features:
            raise RuntimeError("Empty batch passed to collator")

        def _to_list(x):
            return x.tolist() if isinstance(x, torch.Tensor) else list(x)

        input_ids = [_to_list(f["input_ids"]) for f in features]
        attn_masks = [_to_list(f["attention_mask"]) for f in features]
        labels_list = [_to_list(f["labels"]) for f in features]

        max_len_in_batch = max(len(x) for x in input_ids)
        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
        pad_id = self.tok.pad_token_id

        batch_inp, batch_attn, batch_lab = [], [], []
        for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
            pad_len = target_len - len(inp)
            if pad_len < 0:
                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
                pad_len = 0
            batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
            batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
            batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))

        # ensure this batch has supervised tokens
        has_sup = any((lab != -100).any().item() for lab in batch_lab)
        if not has_sup:
            raise RuntimeError("batch has zero supervised tokens; check masking or dataset.")

        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
            "labels": torch.stack(batch_lab, dim=0),
        }


# ----------------- Arguments -----------------
def parse_args():
    """Parse the command-line arguments of the training script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True)
    ap.add_argument("--data_glob", type=str, required=True)
    ap.add_argument("--output_dir", type=str, required=True)
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-4)  # LoRA usually uses a larger LR
    ap.add_argument("--weight_decay", type=float, default=0.0)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard",
                    choices=["none","tensorboard","wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    ap.add_argument("--eval_data_glob", type=str, default=None)
    ap.add_argument("--local_rank", type=int, default=-1)
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
    ap.add_argument("--deepspeed", type=str, default=None)
    ap.add_argument("--eval_steps", type=int, default=10)
    ap.add_argument("--save_total_limit", type=int, default=2)

    # ===== LoRA options =====
    ap.add_argument("--lora_r", type=int, default=16)
    ap.add_argument("--lora_alpha", type=float, default=32.0)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    ap.add_argument("--lora_bias", type=str, default="none", choices=["none","all","lora_only"])
    ap.add_argument("--lora_exclude", type=str, default="",
                    help="逗号分隔的层名后缀(如 lm_head,embed_tokens)用于排除")
    ap.add_argument("--merge_lora_and_save", action="store_true",
                    help="训练后把LoRA合并到基座并另存(占显存/内存大)")
    return ap.parse_args()


# ----------------- Small utility: coloured debug logging -----------------
def _make_dbg():
    """Build and return the ``dbg(msg, level=None)`` logging function.

    Colour is used when stdout is a TTY or ``FORCE_COLOR`` is set, and never
    when ``NO_COLOR`` is set.
    """
    try:
        import colorama
        colorama.just_fix_windows_console()
    except Exception:
        # colorama is optional.
        pass

    def _use_color() -> bool:
        if os.environ.get("NO_COLOR"):
            return False
        if os.environ.get("FORCE_COLOR"):
            return True
        return sys.stdout.isatty()

    class _C:
        reset = "\033[0m"
        gray = "\033[90m"
        green = "\033[32m"
        yellow = "\033[33m"
        red = "\033[31m"
        cyan = "\033[36m"

    _LEVEL_ALIAS = {
        "": "info", None: "info",
        "ok": "ok", "success": "ok",
        "warn": "warn", "warning": "warn",
        "err": "err", "error": "err", "fatal": "err", "fail": "err",
        "info": "info",
    }
    _LEVEL_COLOR = {"ok": _C.green, "warn": _C.yellow, "err": _C.red, "info": _C.cyan}

    def _norm_level(level):
        """Map a level name or a numeric (logging-style) level to ok/warn/err/info."""
        if level is None:
            return "info"
        if isinstance(level, (int, float)):
            if level >= 40:
                return "err"
            if level >= 30:
                return "warn"
            return "info"
        if isinstance(level, str):
            key = level.strip().lower()
            return _LEVEL_ALIAS.get(key, "info")
        return "info"

    def _paint(s, c):
        return f"{c}{s}{_C.reset}" if _use_color() else s

    def dbg(msg, level=None):
        """Print ``msg`` prefixed with host / RANK / LOCAL_RANK (every process)."""
        lvl = _norm_level(level)
        color = _LEVEL_COLOR.get(lvl, _C.cyan)
        host = socket.gethostname()
        rank = os.environ.get("RANK", "0")
        lrank = os.environ.get("LOCAL_RANK", "-1")
        prefix = f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "
        print(_paint(prefix, _C.gray) + _paint(str(msg), color), flush=True)

    return dbg


dbg = _make_dbg()


# ----------------- LoRA target discovery: all linear layers -----------------
def discover_all_linear_leaf_names(model, exclude: List[str]) -> List[str]:
    """Return the sorted, de-duplicated leaf names of all linear modules.

    The result is meant for LoRA ``target_modules``. ``torch.nn.Linear`` is
    always considered; bitsandbytes ``Linear4bit`` / ``Linear8bitLt`` are added
    when that package can be imported. A module is skipped when the last
    component of its dotted name is listed in ``exclude``.

    Raises:
        RuntimeError: if no linear module remains after exclusion.
    """
    linear_like = [torch.nn.Linear]
    try:
        import bitsandbytes.nn as bnbnn
        # Also accept the bitsandbytes linear wrappers.
        for cls_name in ("Linear4bit", "Linear8bitLt"):
            if hasattr(bnbnn, cls_name):
                linear_like.append(getattr(bnbnn, cls_name))
    except Exception:
        pass

    linear_types = tuple(linear_like)
    excluded = set(exclude)
    suffixes = set()
    for full_name, module in model.named_modules():
        if isinstance(module, linear_types):
            last = full_name.split(".")[-1]
            if last not in excluded:
                suffixes.add(last)

    targets = sorted(suffixes)
    if not targets:
        raise RuntimeError("未发现任何线性层可用于 LoRA。请检查模型结构或放宽排除列表。")
    return targets


# ----------------- Main -----------------
def main():
    """Run the full LoRA fine-tuning pipeline.

    Steps: parse arguments, optionally set up DeepSpeed and W&B, initialise the
    process group, load tokenizer and model, inject LoRA, sanity-check the
    data, build the eval set, train (resuming from a checkpoint common to all
    ranks when one exists), save the adapter and optionally a merged model.
    """
    args = parse_args()

    # The training set is a streaming IterableDataset without a length, which
    # Trainer only accepts together with a positive max_steps. Fail before the
    # model is loaded instead of at Trainer construction.
    if args.max_steps <= 0:
        raise ValueError(
            "--max_steps must be > 0: the training data is streamed "
            "(no __len__), so the number of steps cannot be derived from epochs."
        )

    # Only rank 0 reports to wandb.
    if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
        print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
        args.report_to = "none"

    set_seed(args.seed)

    # DeepSpeed
    use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
    # Bound to a local so the object created before model loading stays
    # referenced for the rest of main().
    dschf = None
    if use_ds:
        try:
            from transformers.integrations.deepspeed import HfDeepSpeedConfig
            src = "transformers.integrations.deepspeed"
        except Exception:
            try:
                from transformers import HfDeepSpeedConfig
                src = "transformers"
            except Exception as e:
                raise RuntimeError("当前 transformers 版本未提供 HfDeepSpeedConfig,请升级/降级 transformers") from e
        dschf = HfDeepSpeedConfig(args.deepspeed)
        dbg(f"HfDeepSpeedConfig loaded from {src}")
    elif args.deepspeed:
        dbg(f"--deepspeed config not found: {args.deepspeed}; continuing without DeepSpeed", "warn")

    # W&B (rank 0)
    if args.report_to == "wandb":
        os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
        is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0","-1")
        if is_rank0:
            import wandb
            try:
                os.environ.pop("WANDB_RUN_ID", None)
                extra = {}
                if os.getenv("WANDB_NAME"):
                    extra["name"] = os.getenv("WANDB_NAME")
                if os.getenv("WANDB_GROUP"):
                    extra["group"] = os.getenv("WANDB_GROUP")
                if os.getenv("WANDB_RESUME"):
                    extra["resume"] = os.getenv("WANDB_RESUME")
                run = wandb.init(
                    project=args.wandb_project,
                    entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
                    settings=wandb.Settings(
                        base_url=os.getenv("WANDB_BASE_URL","https://wandb.szaiai.com"),
                        init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT","300")),
                    ),
                    **extra,
                )
                print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
            except Exception as e:
                # Logging is optional: fall back to no reporting instead of aborting.
                print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
                os.environ["WANDB_DISABLED"] = "true"
                args.report_to = "none"
        else:
            os.environ["WANDB_DISABLED"] = "true"

    import transformers as hf
    try:
        import deepspeed as ds
        ds_ver = ds.__version__
    except Exception:
        ds_ver = "n/a"

    dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
    dbg(f"args={args}")
    dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
        os.environ.get("WORLD_SIZE"),
        os.environ.get("RANK"),
        os.environ.get("LOCAL_RANK", str(args.local_rank)),
        os.environ.get("MASTER_ADDR"),
        os.environ.get("MASTER_PORT"),
        os.environ.get("CUDA_VISIBLE_DEVICES"),
    ))
    dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")

    # Distributed initialisation
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
    dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")

    if torch.cuda.is_available() and local_rank >= 0:
        torch.cuda.set_device(local_rank)
        dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
            f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        dbg("no cuda or invalid local_rank; not calling set_device")

    if world_size > 1 and dist.is_available() and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dbg(f"init_process_group backend={backend} via env://")
        dist.init_process_group(backend=backend, init_method="env://")
    else:
        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")

    if dist.is_available() and dist.is_initialized():
        try:
            dbg(f"dist.get_backend()={dist.get_backend()} "
                f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
        except Exception as e:
            dbg(f"dist query error: {e}")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        if getattr(tokenizer, "padding_side", None) != "left":
            tokenizer.padding_side = "left"
    except Exception:
        pass
    tokenizer.model_max_length = args.seq_len
    dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
        f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")

    # dtype
    def _bf16_supported():
        if not torch.cuda.is_available():
            return False
        if hasattr(torch.cuda, "is_bf16_supported"):
            return torch.cuda.is_bf16_supported()
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 0)

    use_bf16 = bool(args.bf16 and _bf16_supported())
    dtype = (torch.bfloat16 if use_bf16
             else (torch.float16 if torch.cuda.is_available() else torch.float32))

    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    # Base model
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )

    # Keep model / generation pad ids in sync with the tokenizer.
    model.config.pad_token_id = tokenizer.pad_token_id
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False  # the KV cache is disabled for training

    # ============ Inject LoRA ============
    # 1) Choose the LoRA target modules: every linear layer minus the exclude list.
    exclude = [x.strip() for x in args.lora_exclude.split(",") if x.strip()]
    target_modules = discover_all_linear_leaf_names(model, exclude)
    if is_main_process():
        print(f"[lora] target_modules (auto, all-linear minus exclude) = {target_modules}", flush=True)

    # 2) Build the LoRA config and wrap the model.
    from peft import LoraConfig, get_peft_model
    lora_cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias=args.lora_bias,
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_cfg)
    try:
        model.print_trainable_parameters()
    except Exception:
        pass

    # 3) Configure gradient checkpointing after injection.
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        # Make the inputs require grad so checkpointing works with frozen embeddings.
        try:
            model.enable_input_require_grads()
        except AttributeError:
            # Fallback for older transformers: mark the embedding output as requiring grad.
            emb = model.get_input_embeddings()
            if hasattr(emb, "register_forward_hook"):
                emb.register_forward_hook(lambda m, inp, out: out.requires_grad_(True))

    # 4) Print the trainable-parameter ratio.
    try:
        trainable, total = 0, 0
        for n, p in model.named_parameters():
            total += p.numel()
            if p.requires_grad:
                trainable += p.numel()
        pct = (trainable / total * 100.0) if total else 0.0
        if is_main_process():
            print(f"[lora] trainable params: {trainable} / {total} ({pct:.2f}%)", flush=True)
    except Exception:
        pass
    # ============ End of LoRA injection ============

    dbg(f"post-config: use_cache={model.config.use_cache} "
        f"model.pad_token_id={model.config.pad_token_id} "
        f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
        f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")

    assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
    assert model.config.pad_token_id == tokenizer.pad_token_id, \
        f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"

    # ===== Data sanity checks =====
    host = socket.gethostname()
    files = sorted(glob.glob(args.data_glob))
    if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "每台机器都必须在相同本地路径下放置数据;"
        )
    if is_main_process():
        print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)

    ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

    def ex_iter_probe():
        for ex in ds_stream_probe:
            yield ex

    train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
    try:
        sample = next(iter(train_stream_probe))
    except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。"
        )
    ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"]
    assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample"
    assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"

    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
    train_stream = QwenChatSFTDataset(ds_stream2, tokenizer, seq_len=args.seq_len)

    # Check that every rank can supply at least one optimisation step of micro-batches.
    ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)

    def has_at_least(stream, n: int):
        """Return 1 if ``stream`` yields at least ``n`` items, else 0."""
        it = iter(stream)
        for _ in range(n):
            try:
                next(it)
            except StopIteration:
                return 0
        return 1

    need = max(1, args.gradient_accumulation_steps)
    local_ok = has_at_least(probe_stream, need)

    if dist.is_available() and dist.is_initialized():
        t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu"))
        # MIN reduction: a single starving rank makes the result 0 everywhere.
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        if t.item() == 0:
            if is_main_process():
                print(
                    f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 ",
                    flush=True
                )
            dist.barrier()
            sys.exit(2)
    else:
        if local_ok == 0:
            print(
                f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。",
                flush=True
            )
            sys.exit(2)

    # ---- Eval ----
    eval_dataset: Optional[Dataset] = None

    class ListDataset(Dataset):
        """Map-style dataset over an in-memory list of samples."""

        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    if args.eval_data_glob:
        eval_files = sorted(glob.glob(args.eval_data_glob))
        if len(eval_files) == 0:
            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
        if is_main_process():
            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)

        def ex_iter_eval():
            for ex in ds_eval_stream:
                yield ex

        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
        eval_items: List[Dict[str, torch.Tensor]] = list(eval_iterable)
        if len(eval_items) == 0:
            raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
        eval_dataset = ListDataset(eval_items)
    elif args.eval_ratio and args.eval_ratio > 0:
        # NOTE(review): this takes the first samples of the *training* files
        # (unshuffled), so the eval set overlaps the training data.
        desired_eval_batches = 200
        tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

        def ex_iter_eval2():
            for ex in tmp_stream:
                yield ex

        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
        eval_samples = list(itertools.islice(iter(eval_stream), desired_eval_batches))
        if len(eval_samples) > 0:
            eval_dataset = ListDataset(eval_samples)

    if eval_dataset is not None:
        ws = max(world_size, 1)
        be = max(1, args.per_device_eval_batch_size)
        global_bs = ws * be
        r = len(eval_dataset) % global_bs
        if r != 0:
            pad_need = global_bs - r
            # Cycle through the items: a plain slice falls short when the set
            # holds fewer than ``pad_need`` samples.
            eval_dataset.items += list(
                itertools.islice(itertools.cycle(list(eval_dataset.items)), pad_need)
            )
            if is_main_process():
                print(f"[eval] padded eval set to {len(eval_dataset)} "
                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
                      flush=True)
        assert len(eval_dataset) % global_bs == 0

    has_eval = eval_dataset is not None

    data_collator = SFTDataCollator(tokenizer, pad_to_length=None)

    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")
    os.makedirs(logging_dir, exist_ok=True)

    # ---- TrainingArguments ----
    # The evaluation-strategy keyword was renamed across transformers versions.
    ta_kwargs = {}
    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
    strategy = "steps" if has_eval else "no"
    if "eval_strategy" in ta_sig:
        ta_kwargs["eval_strategy"] = strategy
    elif "evaluation_strategy" in ta_sig:
        ta_kwargs["evaluation_strategy"] = strategy

    ta_kwargs2 = dict(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        run_name=f"lora-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
        do_train=True,
        do_eval=has_eval,
        eval_steps=(args.eval_steps if has_eval else None),
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        deepspeed=(args.deepspeed if use_ds else None),
        dataloader_drop_last=False,
        dataloader_num_workers=0,
        label_smoothing_factor=0.0,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        save_on_each_node=True,
        logging_first_step=True,
        **ta_kwargs,
    )

    if "dataloader_pin_memory" in ta_sig:
        ta_kwargs2["dataloader_pin_memory"] = False
    if "torch_compile" in ta_sig:
        ta_kwargs2["torch_compile"] = False

    ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)})

    # Best-model tracking needs evaluation: without an eval set the eval
    # strategy is "no", which does not match the save strategy.
    if has_eval:
        ta_kwargs2.update(dict(
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
        ))

    training_args = TrainingArguments(**ta_kwargs2)

    # ``tokenizer`` was renamed to ``processing_class`` in newer transformers.
    trainer_kwargs = {}
    if "processing_class" in inspect.signature(Trainer.__init__).parameters:
        trainer_kwargs["processing_class"] = tokenizer
    else:
        trainer_kwargs["tokenizer"] = tokenizer

    trainer = DebugTrainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **trainer_kwargs,
    )

    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # ==== Resume-from-checkpoint decision ====
    def last_step(path: str) -> int:
        """Return the step number of the newest checkpoint under ``path``, or -1."""
        ck = get_last_checkpoint(path)
        if ck is None:
            return -1
        base = os.path.basename(ck)
        try:
            return int(base.split("-")[-1])
        except Exception:
            return -1

    local_last = last_step(args.output_dir)
    device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")

    resume_flag = None
    if dist.is_available() and dist.is_initialized():
        # Resume only if every rank has a checkpoint, from the smallest common step.
        has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
        dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
        if has_local.item() == 1:
            ts = torch.tensor(local_last, device=device)
            world = dist.get_world_size()
            buf = [torch.zeros_like(ts) for _ in range(world)]
            dist.all_gather(buf, ts)
            steps = [b.item() for b in buf]
            k = min(steps)
            if k >= 0:
                resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
            if is_main_process():
                print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
    else:
        if local_last >= 0:
            resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")

    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")

    if dist.is_available() and dist.is_initialized():
        # The chosen checkpoint directory must exist on every rank.
        present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
        dist.all_reduce(present, op=dist.ReduceOp.MIN)
        if present.item() == 0:
            # Only report a missing directory when a checkpoint was actually selected.
            if resume_flag is not None and is_main_process():
                print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
            resume_flag = None
        dist.barrier()
    else:
        if resume_flag is not None and not os.path.isdir(resume_flag):
            print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
            resume_flag = None

    # Early stopping relies on eval metrics and best-model tracking, so it is
    # only registered when an eval set exists.
    if has_eval:
        trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=1e-3))

    print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
    print_once("***** Starting LoRA training *****")

    dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
        f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB")

    train_result = trainer.train(resume_from_checkpoint=resume_flag)

    # This saves the LoRA adapter into output_dir, not the merged full weights.
    trainer.save_model()

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if has_eval:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    # (Optional) merge LoRA into the base weights and save separately.
    if args.merge_lora_and_save and is_main_process():
        print("[lora] merging LoRA into base weights ...", flush=True)
        merged = model.merge_and_unload()  # needs enough GPU/CPU memory
        merge_dir = os.path.join(args.output_dir, "merged")
        os.makedirs(merge_dir, exist_ok=True)
        merged.save_pretrained(merge_dir, safe_serialization=True)
        tokenizer.save_pretrained(merge_dir)
        print(f"[lora] merged model saved to: {merge_dir}", flush=True)

    print_once("Done.")


if __name__ == "__main__":
    main()