diff --git a/.deepspeed_env b/.deepspeed_env index 84b51fe..a84c99f 100644 --- a/.deepspeed_env +++ b/.deepspeed_env @@ -2,7 +2,7 @@ WANDB_BASE_URL=https://wandb.szaiai.com WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1 WANDB_PROJECT=ds-qwen3 WANDB_ENTITY=hailin -WANDB_GROUP=q3-32b-ds4-2025-09-05 +WANDB_GROUP=q3-32b-ds4-2025-09-04 WANDB_NAME=q3-32b-lr2e-5-train2 WANDB_RESUME=allow WANDB_INIT_TIMEOUT=300 diff --git a/train_mm_zero3_lora.sh b/train_mm_zero3_lora.sh index 6f0f562..b0643ee 100644 --- a/train_mm_zero3_lora.sh +++ b/train_mm_zero3_lora.sh @@ -1,4 +1,4 @@ -deepspeed --hostfile hostfile \ +FORCE_COLOR=1 deepspeed --hostfile hostfile \ --num_nodes 6 --num_gpus 4 \ train_sft_lora.py \ --model_name_or_path /home/test/Qwen3-32B \ @@ -11,6 +11,5 @@ deepspeed --hostfile hostfile \ --learning_rate 1e-4 \ --warmup_ratio 0.03 \ --lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \ - --lora_target auto \ --deepspeed /home/test/jd_train/ds_config_zero3_lora.json \ --report_to wandb --wandb_project ds-qwen3-lora diff --git a/train_sft_lora.py b/train_sft_lora.py index 107f9bc..3232513 100644 --- a/train_sft_lora.py +++ b/train_sft_lora.py @@ -1,8 +1,11 @@ +#!/usr/bin/env python3 import os os.environ.pop("PYTHONNOUSERSITE", None) os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") os.environ.setdefault("WANDB_START_METHOD", "thread") os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb") +os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com") +os.environ.setdefault("WANDB_INIT_TIMEOUT", "300") import glob import socket @@ -14,6 +17,7 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional import torch import torch.distributed as dist from torch.utils.data import IterableDataset, Dataset + from datasets import load_dataset from transformers import ( AutoTokenizer, @@ -24,14 +28,17 @@ from transformers import ( ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import get_last_checkpoint +from torch.optim import AdamW as TorchAdamW -# ---------- PATH / CUDA utils ---------- +# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ==== import site, shutil + home = os.path.expanduser("~") want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"] cur = os.environ.get("PATH", "").split(":") new = [d for d in want if d and d not in cur] + cur os.environ["PATH"] = ":".join(new) + print(f"[env] PATH={os.environ['PATH']}", flush=True) print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True) @@ -45,17 +52,21 @@ print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CU os.environ.pop("DS_BUILD_OPS", None) os.environ.pop("DS_SKIP_CUDA_BUILD", None) + try: user_site = site.getusersitepackages() if user_site and user_site not in sys.path: sys.path.insert(0, user_site) except Exception: pass + os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext") os.environ.setdefault("MAX_JOBS", "12") + if shutil.which("ninja") is None: os.environ["USE_NINJA"] = "0" print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True) + try: from deepspeed.ops.op_builder import CPUAdamBuilder CPUAdamBuilder().load() @@ -68,11 +79,11 @@ except Exception as e: CPUAdamBuilder().load() print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) else: + import socket print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) - # 不致命:LoRA 
不依赖这个算子,继续运行 - pass + raise -# ---------- helpers ---------- +# ----------------- 进程工具 ----------------- def is_main_process(): return int(os.environ.get("RANK", "0")) == 0 @@ -85,17 +96,27 @@ class DebugTrainer(Trainer): if not hasattr(self, "_dbg_printed"): rank = int(os.environ.get("RANK", "0")) host = socket.gethostname() - ids = inputs["input_ids"]; msk = inputs["attention_mask"]; labs = inputs["labels"] + ids = inputs["input_ids"] + msk = inputs["attention_mask"] + labs = inputs["labels"] print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} " - f"supervised={(labs!=-100).sum().item()}", flush=True) - print(f"[step0][host={host} RANK={rank}] " - f"input_ids.shape={tuple(ids.shape)} " - f"attention_mask.shape={tuple(msk.shape)} " - f"labels.shape={tuple(labs.shape)} " - f"num_items_in_batch={num_items_in_batch}", flush=True) + f"supervised={(labs!=-100).sum().item()}", + flush=True) + print( + f"[step0][host={host} RANK={rank}] " + f"input_ids.shape={tuple(ids.shape)} " + f"attention_mask.shape={tuple(msk.shape)} " + f"labels.shape={tuple(labs.shape)} " + f"num_items_in_batch={num_items_in_batch}", + flush=True + ) self._dbg_printed = True - return super().training_step(model, inputs, num_items_in_batch) + try: + return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch) + except TypeError: + return super().training_step(model, inputs) +# ----------------- 日志回调 ----------------- class CsvLossLogger(TrainerCallback): def __init__(self, csv_path: str): self.csv_path = csv_path @@ -107,10 +128,13 @@ class CsvLossLogger(TrainerCallback): def on_train_begin(self, args, state, control, **kwargs): tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 - print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) + rank = os.environ.get("RANK", "?") + host = socket.gethostname() + print(f"[{host} rank={rank}] total_steps={tot}", flush=True) def on_log(self, args, state, control, logs=None, **kwargs): - if logs is None: return + if logs is None: + return cur = int(getattr(state, "global_step", 0) or 0) tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 @@ -118,151 +142,180 @@ class CsvLossLogger(TrainerCallback): if tot and not hasattr(self, "_tot_announced"): print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) self._tot_announced = True - print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] step {cur}/{tot} ({pct}) " - f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", flush=True) - if not is_main_process(): return + rank = os.environ.get("RANK", "?") + host = socket.gethostname() + print( + f"[{host} rank={rank}] step {cur}/{tot} ({pct}) " + f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", + flush=True + ) + if not is_main_process(): + return with open(self.csv_path, "a", encoding="utf-8") as f: f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n") -# ---------- assistant span detection ---------- -def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: - spans: List[Tuple[int, int]] = [] - open_tag = "<|im_start|>assistant\n" - close_tag = "<|im_end|>\n" - pos = 0 - while True: - a = rendered.find(open_tag, pos) - if a == -1: break - s = a + len(open_tag) - b = rendered.find(close_tag, s) - if b == -1: break - spans.append((s, 
b))
-        pos = b + len(close_tag)
-    return spans
-
-# ---------- Dataset (supervise assistant incl. tags) ----------
+# ----------------- supervise only assistant content (token-id level) -----------------
 class QwenChatSFTDataset(IterableDataset):
-    def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096):
+    def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer,
+                 seq_len: int = 4096, mask_think_and_tags: bool = True):
         self.ex_iter = ex_iter
         self.tok = tokenizer
         self.seq_len = seq_len
+        self.mask_think_and_tags = mask_think_and_tags
+        self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>")
+        self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>")
+        self.ids_ASSISTANT_CANDIDATES = [
+            self.tok.encode("assistant\n", add_special_tokens=False),
+            self.tok.encode("assistant", add_special_tokens=False),
+        ]
+        self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0]
+        if not self.ids_ASSISTANT_CANDIDATES:
+            raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.")
+        self.ids_THINK_OPEN = self.tok.encode("<think>", add_special_tokens=False)
+        self.ids_THINK_CLOSE = self.tok.encode("</think>", add_special_tokens=False)
+        for name, val in {"id_START": self.id_START, "id_END": self.id_END}.items():
+            if val is None or val == self.tok.unk_token_id:
+                raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}")
+
+    @staticmethod
+    def _find_subseq(hay: list, needle: list, start: int) -> int:
+        n = len(needle)
+        if n == 0: return start
+        for i in range(start, len(hay) - n + 1):
+            if hay[i:i+n] == needle:
+                return i
+        return -1
+
+    def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]:
+        for cand in self.ids_ASSISTANT_CANDIDATES:
+            pos = self._find_subseq(ids, cand, j_start)
+            if pos == j_start:
+                return (pos, len(cand))
+        return None

     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
-        if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
         dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
+        seen = 0
+        host = socket.gethostname()
         rank = int(os.environ.get("RANK", "0"))
         lrank = int(os.environ.get("LOCAL_RANK", "-1"))
-        host = socket.gethostname()

         for ex in self.ex_iter:
-            msgs = ex.get("messages", None)
-            if not msgs or not isinstance(msgs, list): continue
-
+            msgs = ex.get("messages")
+            if not msgs or not isinstance(msgs, list):
+                continue
             tools = ex.get("tools", None)
             try:
-                rendered: str = self.tok.apply_chat_template(
-                    msgs, tools=tools, add_generation_prompt=False, tokenize=False
+                ids = self.tok.apply_chat_template(
+                    msgs, tools=tools, add_generation_prompt=False,
+                    tokenize=True, return_tensors=None
                 )
+                if isinstance(ids, dict):
+                    ids = ids["input_ids"]
             except TypeError:
                 rendered: str = self.tok.apply_chat_template(
                     msgs, add_generation_prompt=False, tokenize=False
                 )
-            if not isinstance(rendered, str) or not rendered.strip(): continue
+                ids = self.tok(rendered, add_special_tokens=False)["input_ids"]

-            spans = _assistant_char_spans(rendered)
-            if not spans: continue
+            if not ids:
+                continue
+            mask = [0] * len(ids)
+            i = 0
+            while i < len(ids):
+                try:
+                    a = ids.index(self.id_START, i)
+                except ValueError:
+                    break
+                j = a + 1
+                role_match = self._find_role_after_start(ids, j)
+                if role_match is None:
+                    i = a + 1
+                    continue
+                _, role_len = role_match
+                content_lo = j + role_len
+                try:
+                    b = ids.index(self.id_END, content_lo)
+                except ValueError:
+                    i = a + 1
+                    continue
+                content_hi = b
+                for t
in range(content_lo, content_hi): + mask[t] = 1 + if self.mask_think_and_tags: + p = content_lo + while True: + o = self._find_subseq(ids, self.ids_THINK_OPEN, p) + if o == -1 or o >= content_hi: + break + c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN)) + if c == -1 or c > content_hi: + break + x_lo = o + x_hi = c + len(self.ids_THINK_CLOSE) + for t in range(x_lo, min(x_hi, content_hi)): + mask[t] = 0 + p = x_hi + i = b + 1 - enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True) - input_ids: List[int] = enc["input_ids"] - offsets: List[Tuple[int, int]] = enc["offset_mapping"] - if not input_ids: continue - - labels = [-100] * len(input_ids) - - def in_any_span(lo: int, hi: int) -> bool: - for s, e in spans: - if not (hi <= s or lo >= e): - return True - return False - - for i, (lo, hi) in enumerate(offsets): - if in_any_span(lo, hi): - labels[i] = input_ids[i] - - if all(v == -100 for v in labels): # 无监督 token + if not any(mask): continue - # ---- assistant-aware truncation: keep last assistant not cut off - if len(input_ids) > self.seq_len: - s_last, e_last = spans[-1] - j = 0 - while j < len(offsets) and offsets[j][1] <= s_last: j += 1 - k_excl = j - while k_excl < len(offsets) and offsets[k_excl][0] < e_last: k_excl += 1 - A = max(0, k_excl - j) - if A >= self.seq_len: - start = max(0, k_excl - self.seq_len); end = start + self.seq_len - else: - start = max(0, min(j, len(input_ids) - self.seq_len)) - end = start + self.seq_len - if end < k_excl: - end = k_excl; start = end - self.seq_len - if start < 0: start = 0; end = self.seq_len - leftover = self.seq_len - A - left_wish = leftover // 2 - start = max(0, min(j - left_wish, start)) - end = start + self.seq_len - if end < k_excl: - end = k_excl; start = end - self.seq_len - if start < 0: start = 0; end = self.seq_len - input_ids = input_ids[start:end] - labels = labels[start:end] + if len(ids) > self.seq_len: + last_on = max(idx for idx, v in enumerate(mask) if v == 1) + end = min(len(ids), last_on + 1) + start = max(0, end - self.seq_len) + ids = ids[start:end] + mask = mask[start:end] pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id - L = len(input_ids) + L = len(ids) if L < self.seq_len: pad = self.seq_len - L - input_ids = ([pad_id]*pad) + input_ids - labels = ([-100]*pad) + labels - attn_mask = [0]*pad + [1]*L + input_ids = [pad_id] * pad + ids + attention_mask = [0] * pad + [1] * L + labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)] else: - attn_mask = [1]*self.seq_len + input_ids = ids + attention_mask = [1] * self.seq_len + labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)] - assert len(input_ids) == self.seq_len - assert len(labels) == self.seq_len - assert len(attn_mask) == self.seq_len - - if dbg_on and self._dbg_seen < dbg_limit: + if dbg_on and seen < dbg_limit: sup_tok = sum(1 for v in labels if v != -100) - print(f"[sample][host={host} RANK={rank} LRank={lrank}] " - f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True) - if sup_tok == 0: - print(f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", flush=True) - self._dbg_seen += 1 + print( + f"[sample][host={host} RANK={rank} LRank={lrank}] " + f"toks={len(input_ids)} sup_toks={sup_tok} " + f"seq_len={self.seq_len} pad_id={pad_id}", + flush=True + ) + seen += 1 yield { "input_ids": torch.tensor(input_ids, dtype=torch.long), - "attention_mask": 
torch.tensor(attn_mask, dtype=torch.long), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } -# ---------- Collator ---------- +# ----------------- Collator ----------------- class SFTDataCollator: def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None): self.tok = tokenizer self.pad_to_length = pad_to_length - assert self.tok.pad_token_id is not None + assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set" def __call__(self, features): if not features: - raise RuntimeError(f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator.") + raise RuntimeError("Empty batch passed to collator") + def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x) input_ids = [_to_list(f["input_ids"]) for f in features] attn_masks = [_to_list(f["attention_mask"]) for f in features] labels_list = [_to_list(f["labels"]) for f in features] + max_len_in_batch = max(len(x) for x in input_ids) target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch + pad_id = self.tok.pad_token_id batch_inp, batch_attn, batch_lab = [], [], [] for inp, msk, lab in zip(input_ids, attn_masks, labels_list): @@ -273,25 +326,23 @@ class SFTDataCollator: batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long)) batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long)) batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long)) - if os.environ.get("DBG_COLLATE","0") == "1": - print(f"[collate][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] " - f"features={len(features)} target_len={target_len}", flush=True) + return { "input_ids": torch.stack(batch_inp, dim=0), "attention_mask": torch.stack(batch_attn, dim=0), "labels": torch.stack(batch_lab, dim=0), } -# ---------- Args ---------- +# ----------------- 参数 ----------------- def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--model_name_or_path", type=str, required=True) ap.add_argument("--data_glob", type=str, required=True) ap.add_argument("--output_dir", type=str, required=True) ap.add_argument("--seq_len", type=int, default=4096) - ap.add_argument("--learning_rate", type=float, default=1e-4) # LoRA 通常可更大学习率 - ap.add_argument("--weight_decay", type=float, default=0.0) # LoRA 常设 0 或很小 - ap.add_argument("--warmup_ratio", type=float, default=0.03) + ap.add_argument("--learning_rate", type=float, default=2e-4) # LoRA通常更大lr + ap.add_argument("--weight_decay", type=float, default=0.0) + ap.add_argument("--warmup_ratio", type=float, default=0.02) ap.add_argument("--num_train_epochs", type=float, default=1.0) ap.add_argument("--max_steps", type=int, default=-1) ap.add_argument("--log_interval", type=int, default=10) @@ -302,56 +353,99 @@ def parse_args(): ap.add_argument("--bf16", action="store_true") ap.add_argument("--per_device_train_batch_size", type=int, default=1) ap.add_argument("--gradient_accumulation_steps", type=int, default=64) - ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"]) - ap.add_argument("--wandb_project", type=str, default="ds-qwen3-lora") + ap.add_argument("--report_to", type=str, default="tensorboard", + choices=["none","tensorboard","wandb"]) + ap.add_argument("--wandb_project", type=str, default="ds-qwen3") ap.add_argument("--eval_data_glob", type=str, default=None) ap.add_argument("--local_rank", type=int, default=-1) 
ap.add_argument("--per_device_eval_batch_size", type=int, default=1) ap.add_argument("--deepspeed", type=str, default=None) + ap.add_argument("--eval_steps", type=int, default=10) - # ---- LoRA specific ---- + # ===== LoRA 相关 ===== ap.add_argument("--lora_r", type=int, default=16) - ap.add_argument("--lora_alpha", type=float, default=32) + ap.add_argument("--lora_alpha", type=float, default=32.0) ap.add_argument("--lora_dropout", type=float, default=0.05) - ap.add_argument("--lora_target", type=str, default="auto", - help='逗号分隔,如 "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj";或 "auto"') + ap.add_argument("--lora_bias", type=str, default="none", choices=["none","all","lora_only"]) + ap.add_argument("--lora_exclude", type=str, default="", help="逗号分隔的层名后缀(如 lm_head,embed_tokens)用于排除") + ap.add_argument("--merge_lora_and_save", action="store_true", help="训练后把LoRA合并到基座并另存(占显存/内存大)") - ap.add_argument("--qlora", action="store_true", help="使用 4bit (NF4) QLoRA(多机 DS 不建议)") - ap.add_argument("--merge_lora_and_save", action="store_true", - help="训练后在 rank0 合并 LoRA 到基座并另存(注意显存/内存占用)") return ap.parse_args() -# ---------- LoRA helpers ---------- -def _auto_lora_targets(model) -> List[str]: - """ - 针对 Qwen/Llama 族,自动挑选常见的线性层名字;仅匹配存在的模块。 - """ - cand = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj", - "w1","w2","w3", "W_pack", "o_attn", "o_proj"] # 覆盖不同实现命名 - present = set() - for name, module in model.named_modules(): - if any(name.endswith(f".{c}") or name == c for c in cand): - present.add(name.split(".")[-1]) - # 回落:若一个都没匹配到,使用“所有 nn.Linear” - if not present: - return ["all-linear"] - # 去重且保序 - order = [] - for c in cand: - if c in present: order.append(c) - return order +# ----------------- 小工具:日志与颜色 ----------------- +def _make_dbg(): + try: + import colorama + colorama.just_fix_windows_console() + except Exception: + pass + def _use_color() -> bool: + if os.environ.get("NO_COLOR"): return False + if os.environ.get("FORCE_COLOR"): return True + return sys.stdout.isatty() + class _C: + reset="\033[0m"; gray="\033[90m"; green="\033[32m"; yellow="\033[33m"; red="\033[31m"; cyan="\033[36m" + _LEVEL_ALIAS={"": "info", None: "info", "ok":"ok","success":"ok","warn":"warn","warning":"warn","err":"err","error":"err","fatal":"err","fail":"err","info":"info"} + _LEVEL_COLOR={"ok":_C.green,"warn":_C.yellow,"err":_C.red,"info":_C.cyan} + def _norm_level(level): + if level is None: return "info" + if isinstance(level,(int,float)): + if level>=40: return "err" + if level>=30: return "warn" + return "info" + if isinstance(level,str): + key=level.strip().lower() + return _LEVEL_ALIAS.get(key,"info") + return "info" + def _paint(s,c): return f"{c}{s}{_C.reset}" if _use_color() else s + def dbg(msg, level=None): + lvl=_norm_level(level); color=_LEVEL_COLOR.get(lvl,_C.cyan) + host=socket.gethostname(); rank=os.environ.get("RANK","0"); lrank=os.environ.get("LOCAL_RANK","-1") + prefix=f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] " + print(_paint(prefix,_C.gray)+_paint(str(msg),color), flush=True) + return dbg +dbg=_make_dbg() -# ---------- main ---------- +# ----------------- LoRA 目标层自动发现:所有线性层 ----------------- +def discover_all_linear_leaf_names(model, exclude: List[str]) -> List[str]: + """ + 返回 LoRA target_modules 需要的“叶子模块名后缀”集合(去重)。 + 默认遍历 nn.Linear / bitsandbytes 的 Linear4bit/8bit 等线性类。 + """ + linear_like = [torch.nn.Linear] + try: + import bitsandbytes as bnb + import bitsandbytes.nn as bnbnn + # 兼容 bnb 线性封装 + for cls_name in ("Linear4bit", "Linear8bitLt"): + if 
hasattr(bnbnn, cls_name): + linear_like.append(getattr(bnbnn, cls_name)) + except Exception: + pass + + suffixes=set() + for full_name, module in model.named_modules(): + if any(isinstance(module, cls) for cls in linear_like): + last = full_name.split(".")[-1] + if last not in exclude: + suffixes.add(last) + targets = sorted(suffixes) + if not targets: + raise RuntimeError("未发现任何线性层可用于 LoRA。请检查模型结构或放宽排除列表。") + return targets + +# ----------------- 主函数 ----------------- def main(): args = parse_args() - if os.environ.get("RANK","0") != "0" and args.report_to == "wandb": + # 只有 rank0 用 wandb + if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb": print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True) args.report_to = "none" set_seed(args.seed) - # DeepSpeed enable? + # DeepSpeed use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed)) dschf = None if use_ds: @@ -359,13 +453,42 @@ def main(): from transformers.integrations.deepspeed import HfDeepSpeedConfig src = "transformers.integrations.deepspeed" except Exception: - from transformers import HfDeepSpeedConfig - src = "transformers" + try: + from transformers import HfDeepSpeedConfig + src = "transformers" + except Exception as e: + raise RuntimeError("当前 transformers 版本未提供 HfDeepSpeedConfig,请升级/降级 transformers") from e dschf = HfDeepSpeedConfig(args.deepspeed) - print(f"[dbg] HfDeepSpeedConfig loaded from {src}", flush=True) + dbg(f"HfDeepSpeedConfig loaded from {src}") + # W&B(rank0) if args.report_to == "wandb": os.environ.setdefault("WANDB_PROJECT", args.wandb_project) + is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0","-1") + if is_rank0: + import wandb + try: + os.environ.pop("WANDB_RUN_ID", None) + extra={} + if os.getenv("WANDB_NAME"): extra["name"]=os.getenv("WANDB_NAME") + if os.getenv("WANDB_GROUP"): extra["group"]=os.getenv("WANDB_GROUP") + if os.getenv("WANDB_RESUME"): extra["resume"]=os.getenv("WANDB_RESUME") + run = wandb.init( + project=args.wandb_project, + entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin", + settings=wandb.Settings( + base_url=os.getenv("WANDB_BASE_URL","https://wandb.szaiai.com"), + init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT","300")), + ), + **extra, + ) + print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True) + except Exception as e: + print(f"[wandb] init failed -> disable logging, reason={e}", flush=True) + os.environ["WANDB_DISABLED"]="true" + args.report_to="none" + else: + os.environ["WANDB_DISABLED"]="true" import transformers as hf try: @@ -374,32 +497,44 @@ def main(): except Exception: ds_ver = "n/a" - def dbg(msg): - print(f"[dbg][host={socket.gethostname()} RANK={os.environ.get('RANK','0')} " - f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True) - dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}") dbg(f"args={args}") dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % ( - os.environ.get("WORLD_SIZE"), os.environ.get("RANK"), + os.environ.get("WORLD_SIZE"), + os.environ.get("RANK"), os.environ.get("LOCAL_RANK", str(args.local_rank)), - os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"), + os.environ.get("MASTER_ADDR"), + os.environ.get("MASTER_PORT"), os.environ.get("CUDA_VISIBLE_DEVICES"), )) dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}") - # init dist + # 分布式初始化 world_size = 
int(os.environ.get("WORLD_SIZE", "1")) rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank))) + dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}") + if torch.cuda.is_available() and local_rank >= 0: torch.cuda.set_device(local_rank) dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} " f"name={torch.cuda.get_device_name(torch.cuda.current_device())}") + else: + dbg("no cuda or invalid local_rank; not calling set_device") + if world_size > 1 and dist.is_available() and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dbg(f"init_process_group backend={backend} via env://") dist.init_process_group(backend=backend, init_method="env://") + else: + dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}") + + if dist.is_available() and dist.is_initialized(): + try: + dbg(f"dist.get_backend()={dist.get_backend()} " + f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}") + except Exception as e: + dbg(f"dist query error: {e}") # tokenizer tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) @@ -410,109 +545,118 @@ def main(): tokenizer.padding_side = "left" except Exception: pass - - from transformers import PreTrainedTokenizerFast - if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False): - raise RuntimeError("需要 Fast tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14。") tokenizer.model_max_length = args.seq_len - dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} model_max_length={tokenizer.model_max_length}") + dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} " + f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}") # dtype def _bf16_supported(): - if not torch.cuda.is_available(): return False + if not torch.cuda.is_available(): + return False if hasattr(torch.cuda, "is_bf16_supported"): return torch.cuda.is_bf16_supported() major, minor = torch.cuda.get_device_capability() return (major, minor) >= (8, 0) use_bf16 = bool(args.bf16 and _bf16_supported()) - compute_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32) + dtype = (torch.bfloat16 if use_bf16 else + (torch.float16 if torch.cuda.is_available() else torch.float32)) - # -------- load base model (with/without 4bit) -------- - quantization_config = None - if args.qlora: - try: - from transformers import BitsAndBytesConfig - from peft import prepare_model_for_kbit_training - except Exception as e: - raise RuntimeError("使用 --qlora 需要安装 bitsandbytes>=0.41 与 peft。") from e - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - bnb_4bit_compute_dtype=compute_dtype - ) - # 4bit 下不要传 attn_implementation="sdpa" 给部分旧版 torch - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - torch_dtype=compute_dtype, - trust_remote_code=True, - low_cpu_mem_usage=True, - quantization_config=quantization_config, - device_map=None # 用 DeepSpeed/Trainer 接管 - ) - if args.gradient_checkpointing: - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) - else: - model = 
AutoModelForCausalLM.from_pretrained(
-            args.model_name_or_path,
-            torch_dtype=compute_dtype,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            attn_implementation="sdpa",
-        )
-        if args.gradient_checkpointing:
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.set_float32_matmul_precision("high")
+    except Exception:
+        pass
+    # base model
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name_or_path,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+        attn_implementation="sdpa",
+    )
+
+    # pad token id etc.
     model.config.pad_token_id = tokenizer.pad_token_id
-    model.config.use_cache = False
+    if getattr(model, "generation_config", None) is not None:
+        model.generation_config.pad_token_id = tokenizer.pad_token_id
+    model.config.use_cache = False  # cache must be disabled during training

-    # -------- wrap with LoRA --------
-    from peft import LoraConfig, get_peft_model, TaskType, PeftModel
-    if args.lora_target.strip().lower() == "auto":
-        targets = _auto_lora_targets(model)
-    else:
-        targets = [x.strip() for x in args.lora_target.split(",") if x.strip()]
-        if not targets:
-            targets = _auto_lora_targets(model)
+    # ============ key change: inject LoRA ============
+    # 1) pick the LoRA target modules: by default, every linear layer in the model
+    exclude = [x.strip() for x in args.lora_exclude.split(",") if x.strip()]
+    target_modules = discover_all_linear_leaf_names(model, exclude)
+    if is_main_process():
+        print(f"[lora] target_modules (auto, all-linear minus exclude) = {target_modules}", flush=True)
+    # 2) build the LoRA config and wrap the model
+    from peft import LoraConfig, get_peft_model
     lora_cfg = LoraConfig(
-        task_type=TaskType.CAUSAL_LM,
         r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout,
-        target_modules=targets,
-        bias="none",
-        inference_mode=False
+        bias=args.lora_bias,
+        task_type="CAUSAL_LM",
+        target_modules=target_modules,
     )
     model = get_peft_model(model, lora_cfg)
-    # 冻结确认
-    if is_main_process():
-        try:
-            model.print_trainable_parameters()
-        except Exception:
-            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
-            total = sum(p.numel() for p in model.parameters())
-            print(f"[LoRA] trainable={trainable:,} / total={total:,} ({trainable/total:.2%})", flush=True)
+    # 3) (re-)enable gradient checkpointing; calling it after injection is more robust
+    if args.gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-    # -------- data streams --------
+    # 4) report the trainable-parameter ratio
+    try:
+        from peft import get_peft_model_state_dict
+        trainable, total = 0, 0
+        for n, p in model.named_parameters():
+            total += p.numel()
+            if p.requires_grad:
+                trainable += p.numel()
+        pct = (trainable / total * 100.0) if total else 0.0
+        if is_main_process():
+            print(f"[lora] trainable params: {trainable} / {total} ({pct:.2f}%)", flush=True)
+    except Exception:
+        pass
+    # ============ end of LoRA injection ============
+
+    dbg(f"post-config: use_cache={model.config.use_cache} "
+        f"model.pad_token_id={model.config.pad_token_id} "
+        f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
+        f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
+
+    assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
+    assert model.config.pad_token_id == tokenizer.pad_token_id, \
+        f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
+
+    # ===== data sanity checks =====
+    host = socket.gethostname()
     files = sorted(glob.glob(args.data_glob))
if len(files) == 0: - raise FileNotFoundError(f"No files matched DATA_GLOB={args.data_glob}") + raise FileNotFoundError( + f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n" + "每台机器都必须在相同本地路径下放置数据;" + ) if is_main_process(): - print(f"[data] matched {len(files)} files, example[0]={files[0]}", flush=True) + print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True) ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def ex_iter_probe(): - for ex in ds_stream_probe: yield ex + for ex in ds_stream_probe: + yield ex train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len) try: - _ = next(iter(train_stream_probe)) + sample = next(iter(train_stream_probe)) except StopIteration: - raise RuntimeError("[data] 样本结构不合法或全部被裁切。") + raise RuntimeError( + f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。" + ) + ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"] + assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample" + assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100" - ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) + ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed) train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) @@ -521,25 +665,35 @@ def main(): def has_at_least(stream, n: int): it = iter(stream) for _ in range(n): - try: next(it) - except StopIteration: return 0 + try: + next(it) + except StopIteration: + return 0 return 1 need = max(1, args.gradient_accumulation_steps) local_ok = has_at_least(probe_stream, need) + if dist.is_available() and dist.is_initialized(): - t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank>=0 else "cpu")) + t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")) dist.all_reduce(t, op=dist.ReduceOp.MIN) if t.item() == 0: if is_main_process(): - print(f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批。", flush=True) - dist.barrier(); sys.exit(2) + print( + f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 ", + flush=True + ) + dist.barrier() + sys.exit(2) else: if local_ok == 0: - print(f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批。", flush=True) + print( + f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。", + flush=True + ) sys.exit(2) - # eval + # ---- Eval ---- eval_dataset: Optional[Dataset] = None class ListDataset(Dataset): def __init__(self, items): self.items = items @@ -549,19 +703,39 @@ def main(): if args.eval_data_glob: eval_files = sorted(glob.glob(args.eval_data_glob)) if len(eval_files) == 0: - raise FileNotFoundError(f"No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}") + raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}") if is_main_process(): print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True) ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True) def ex_iter_eval(): - for ex in ds_eval_stream: yield ex + for ex in ds_eval_stream: + yield ex eval_iterable = 
QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len) - eval_items: List[Dict[str, torch.Tensor]] = [s for s in eval_iterable] + eval_items: List[Dict[str, torch.Tensor]] = [] + for sample in eval_iterable: + eval_items.append(sample) if len(eval_items) == 0: - raise RuntimeError("[eval] 读到了 0 条有效样本。") + raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。") eval_dataset = ListDataset(eval_items) - # pad to global batch size - ws = max(int(os.environ.get("WORLD_SIZE","1")), 1) + elif args.eval_ratio and args.eval_ratio > 0: + desired_eval_batches = 200 + tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True) + def ex_iter_eval2(): + for ex in tmp_stream: + yield ex + eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len) + eval_samples = [] + it = iter(eval_stream) + for _ in range(desired_eval_batches): + try: + eval_samples.append(next(it)) + except StopIteration: + break + if len(eval_samples) > 0: + eval_dataset = ListDataset(eval_samples) + + if eval_dataset is not None: + ws = max(world_size, 1) be = max(1, args.per_device_eval_batch_size) global_bs = ws * be r = len(eval_dataset) % global_bs @@ -569,30 +743,39 @@ def main(): pad_need = global_bs - r eval_dataset.items += eval_dataset.items[:pad_need] if is_main_process(): - print(f"[eval] padded eval set to {len(eval_dataset)}", flush=True) + print(f"[eval] padded eval set to {len(eval_dataset)} " + f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})", + flush=True) + assert len(eval_dataset) % global_bs == 0 - # collator - data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len) + data_collator = SFTDataCollator(tokenizer, pad_to_length=None) - # training args os.makedirs(args.output_dir, exist_ok=True) - logging_dir = os.path.join(args.output_dir, "logs"); os.makedirs(logging_dir, exist_ok=True) + logging_dir = os.path.join(args.output_dir, "logs") + os.makedirs(logging_dir, exist_ok=True) + # ---- TrainingArguments ---- ta_kwargs = {} sig = inspect.signature(TrainingArguments.__init__).parameters if eval_dataset is not None: - if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "steps" - elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps" + if "eval_strategy" in sig: + ta_kwargs["eval_strategy"] = "steps" + elif "evaluation_strategy" in sig: + ta_kwargs["evaluation_strategy"] = "steps" else: - if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no" - elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no" + if "eval_strategy" in sig: + ta_kwargs["eval_strategy"] = "no" + elif "evaluation_strategy" in sig: + ta_kwargs["evaluation_strategy"] = "no" + ta_sig = inspect.signature(TrainingArguments.__init__).parameters ta_kwargs2 = dict( output_dir=args.output_dir, logging_dir=logging_dir, + run_name=f"lora-{os.path.basename(args.output_dir)}-{socket.gethostname()}", do_train=True, do_eval=(eval_dataset is not None), - eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, + eval_steps=(args.eval_steps if eval_dataset is not None else None), per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, @@ -607,6 +790,7 @@ def main(): deepspeed=(args.deepspeed if use_ds else None), dataloader_drop_last=False, dataloader_num_workers=0, + label_smoothing_factor=0.0, per_device_eval_batch_size=args.per_device_eval_batch_size, 
report_to=([] if args.report_to == "none" else [args.report_to]), gradient_checkpointing=args.gradient_checkpointing, @@ -615,16 +799,14 @@ def main(): logging_first_step=True, **ta_kwargs, ) - # 精度:QLoRA/LoRA 均按 compute_dtype 设置 - if "dataloader_pin_memory" in sig: ta_kwargs2["dataloader_pin_memory"] = False - if "torch_compile" in sig: ta_kwargs2["torch_compile"] = False - ta_kwargs2.update({ - "bf16": (compute_dtype==torch.bfloat16), - "fp16": (compute_dtype==torch.float16), - }) + if "dataloader_pin_memory" in ta_sig: + ta_kwargs2["dataloader_pin_memory"] = False + if "torch_compile" in ta_sig: + ta_kwargs2["torch_compile"] = False + ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)}) + training_args = TrainingArguments(**ta_kwargs2) - # pass tokenizer / processing_class trainer_kwargs = {} if "processing_class" in inspect.signature(Trainer.__init__).parameters: trainer_kwargs["processing_class"] = tokenizer @@ -637,20 +819,24 @@ def main(): train_dataset=train_stream, eval_dataset=eval_dataset, data_collator=data_collator, - **trainer_kwargs + **trainer_kwargs, ) trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv"))) - # resume (per-node local checkpoint agreement) + # ==== 断点恢复判定 ==== def last_step(path: str) -> int: ck = get_last_checkpoint(path) - if ck is None: return -1 + if ck is None: + return -1 base = os.path.basename(ck) - try: return int(base.split("-")[-1]) - except Exception: return -1 + try: + return int(base.split("-")[-1]) + except Exception: + return -1 local_last = last_step(args.output_dir) - device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank>=0) else "cpu") + device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu") + resume_flag = None if dist.is_available() and dist.is_initialized(): has_local = torch.tensor(1 if local_last >= 0 else 0, device=device) @@ -671,6 +857,7 @@ def main(): resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}") print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}") + if dist.is_available() and dist.is_initialized(): present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device) dist.all_reduce(present, op=dist.ReduceOp.MIN) @@ -686,40 +873,34 @@ def main(): print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}") print_once("***** Starting LoRA training *****") - print(f"[dbg] allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, " - f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB", flush=True) + + dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, " + f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB") train_result = trainer.train(resume_from_checkpoint=resume_flag) + # 保存:此处保存的是“LoRA 适配器”(非合并的整权重) + trainer.save_model() # 保存到 output_dir, 包含 adapter_model.bin & adapter_config.json - # save adapter (not the full base) - trainer.save_model() # 对 PeftModel:只保存 adapter 权重到 output_dir metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() - # eval if eval_dataset is not None: print_once("***** Running eval *****") eval_metrics = trainer.evaluate() trainer.log_metrics("eval", eval_metrics) trainer.save_metrics("eval", eval_metrics) - # optional merge + # (可选)合并 LoRA 并另存 if args.merge_lora_and_save and is_main_process(): - print("[merge] Merging LoRA into base model ...", 
flush=True)
-        try:
-            if isinstance(trainer.model, PeftModel):
-                merged = trainer.model.merge_and_unload()
-            else:
-                merged = trainer.model
-            merge_dir = os.path.join(args.output_dir, "merged-full-model")
-            os.makedirs(merge_dir, exist_ok=True)
-            merged.save_pretrained(merge_dir, safe_serialization=True)
-            tokenizer.save_pretrained(merge_dir)
-            print(f"[merge] Saved merged model to: {merge_dir}", flush=True)
-        except Exception as e:
-            print(f"[merge] FAILED: {e}", flush=True)
+        print("[lora] merging LoRA into base weights ...", flush=True)
+        merged = model.merge_and_unload()  # needs enough GPU/host memory
+        merge_dir = os.path.join(args.output_dir, "merged")
+        os.makedirs(merge_dir, exist_ok=True)
+        merged.save_pretrained(merge_dir, safe_serialization=True)
+        tokenizer.save_pretrained(merge_dir)
+        print(f"[lora] merged model saved to: {merge_dir}", flush=True)

     print_once("Done.")
diff --git a/wipe_wandb_by_run_group.sh b/wipe_wandb_by_run_group.sh
index 557bafe..8329b80 100755
--- a/wipe_wandb_by_run_group.sh
+++ b/wipe_wandb_by_run_group.sh
@@ -3,7 +3,7 @@
 export WANDB_BASE_URL=https://wandb.szaiai.com
 export WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
 export WANDB_PROJECT=ds-qwen3
 export WANDB_GROUP=q3-32b-ds4-2025-09-04   # 如果训练时没用 WANDB_RUN_GROUP,这里只是“期望值”
-export MATCH_NAME_REGEX='q3-32b-ds4($|/|-)'   # 回退方案:按名字匹配
+export MATCH_NAME_REGEX='^q3-32b-ds4'   # fallback: match runs by name
 python3 - <<'PY'
 import os, re, wandb
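
Note (reviewer sketch, not part of the patch): the label-masking rule that the new QwenChatSFTDataset in train_sft_lora.py implements — supervise only tokens between "<|im_start|>assistant" and the matching "<|im_end|>", and drop <think>...</think> spans inside that region — can be exercised in isolation with the standalone snippet below. The token ids in the demo are made up for illustration; a real run would take them from the Qwen tokenizer (convert_tokens_to_ids("<|im_start|>"), encode("<think>"), ...), and the empty-needle handling here deliberately returns "not found" instead of the patch's early return.

from typing import List

def find_subseq(hay: List[int], needle: List[int], start: int) -> int:
    """Return the first index >= start where needle occurs in hay, else -1."""
    n = len(needle)
    if n == 0:
        return -1  # treat an empty needle as "not found" so the caller cannot loop forever
    for i in range(start, len(hay) - n + 1):
        if hay[i:i + n] == needle:
            return i
    return -1

def assistant_label_mask(ids: List[int], id_start: int, id_end: int,
                         role_ids: List[int], think_open: List[int],
                         think_close: List[int]) -> List[int]:
    """1 = supervised token (keep label), 0 = ignored (label -100)."""
    mask = [0] * len(ids)
    i = 0
    while i < len(ids):
        try:
            a = ids.index(id_start, i)          # next <|im_start|>
        except ValueError:
            break
        lo = a + 1 + len(role_ids)
        if ids[a + 1:lo] != role_ids:           # not an assistant turn
            i = a + 1
            continue
        try:
            hi = ids.index(id_end, lo)          # matching <|im_end|>
        except ValueError:
            break
        for t in range(lo, hi):                 # supervise the assistant content
            mask[t] = 1
        p = lo
        while True:                             # un-supervise <think>...</think> spans
            o = find_subseq(ids, think_open, p)
            if o == -1 or o >= hi:
                break
            c = find_subseq(ids, think_close, o + len(think_open))
            if c == -1 or c > hi:
                break
            for t in range(o, min(c + len(think_close), hi)):
                mask[t] = 0
            p = c + len(think_close)
        i = hi + 1
    return mask

if __name__ == "__main__":
    # Toy ids: 1=<|im_start|>, 2=<|im_end|>, 9=assistant role, 7=<think>, 8=</think>
    ids = [1, 5, 5, 2, 1, 9, 7, 6, 8, 3, 4, 2]
    print(assistant_label_mask(ids, 1, 2, [9], [7], [8]))
    # -> [0,0,0,0,0,0,0,0,0,1,1,0]: only the answer tokens (3, 4) stay supervised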