#!/usr/bin/env python3 # -*- coding: utf-8 -*- import os os.environ.pop("PYTHONNOUSERSITE", None) os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") os.environ.setdefault("WANDB_START_METHOD", "thread") os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb") import glob import socket import argparse import inspect import sys from typing import Dict, List, Iterable, Iterator, Tuple, Optional import torch import torch.distributed as dist from torch.utils.data import IterableDataset, Dataset from datasets import load_dataset from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, set_seed ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import get_last_checkpoint # ---------- PATH / CUDA utils ---------- import site, shutil home = os.path.expanduser("~") want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"] cur = os.environ.get("PATH", "").split(":") new = [d for d in want if d and d not in cur] + cur os.environ["PATH"] = ":".join(new) print(f"[env] PATH={os.environ['PATH']}", flush=True) print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True) os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8") ld = os.environ.get("LD_LIBRARY_PATH", "") cuda_lib = "/usr/local/cuda-11.8/lib64" if cuda_lib not in ld.split(":"): os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True) os.environ.pop("DS_BUILD_OPS", None) os.environ.pop("DS_SKIP_CUDA_BUILD", None) try: user_site = site.getusersitepackages() if user_site and user_site not in sys.path: sys.path.insert(0, user_site) except Exception: pass os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext") os.environ.setdefault("MAX_JOBS", "12") if shutil.which("ninja") is None: os.environ["USE_NINJA"] = "0" print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True) try: from deepspeed.ops.op_builder import CPUAdamBuilder CPUAdamBuilder().load() print("[env] CPUAdamBuilder JIT OK", flush=True) except Exception as e: if "Ninja is required to load C++ extensions" in str(e): os.environ["USE_NINJA"] = "0" print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True) from deepspeed.ops.op_builder import CPUAdamBuilder CPUAdamBuilder().load() print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) else: print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) # 不致命:LoRA 不依赖这个算子,继续运行 pass # ---------- helpers ---------- def is_main_process(): return int(os.environ.get("RANK", "0")) == 0 def print_once(*args, **kwargs): if is_main_process(): print(*args, **kwargs, flush=True) class DebugTrainer(Trainer): def training_step(self, model, inputs, num_items_in_batch=None): if not hasattr(self, "_dbg_printed"): rank = int(os.environ.get("RANK", "0")) host = socket.gethostname() ids = inputs["input_ids"]; msk = inputs["attention_mask"]; labs = inputs["labels"] print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} " f"supervised={(labs!=-100).sum().item()}", flush=True) print(f"[step0][host={host} RANK={rank}] " f"input_ids.shape={tuple(ids.shape)} " f"attention_mask.shape={tuple(msk.shape)} " f"labels.shape={tuple(labs.shape)} " f"num_items_in_batch={num_items_in_batch}", flush=True) self._dbg_printed = True return super().training_step(model, inputs, num_items_in_batch) class CsvLossLogger(TrainerCallback): def __init__(self, csv_path: str): self.csv_path = csv_path if is_main_process(): os.makedirs(os.path.dirname(csv_path), exist_ok=True) with open(self.csv_path, "w", encoding="utf-8") as f: f.write("step,loss,lr,total_flos\n") def on_train_begin(self, args, state, control, **kwargs): tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) def on_log(self, args, state, control, logs=None, **kwargs): if logs is None: return cur = int(getattr(state, "global_step", 0) or 0) tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a") if tot and not hasattr(self, "_tot_announced"): print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) self._tot_announced = True print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] step {cur}/{tot} ({pct}) " f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", flush=True) if not is_main_process(): return with open(self.csv_path, "a", encoding="utf-8") as f: f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n") # ---------- assistant span detection ---------- def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: spans: List[Tuple[int, int]] = [] open_tag = "<|im_start|>assistant\n" close_tag = "<|im_end|>\n" pos = 0 while True: a = rendered.find(open_tag, pos) if a == -1: break s = a + len(open_tag) b = rendered.find(close_tag, s) if b == -1: break spans.append((s, b)) pos = b + len(close_tag) return spans # ---------- Dataset (supervise assistant incl. tags) ---------- class QwenChatSFTDataset(IterableDataset): def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096): self.ex_iter = ex_iter self.tok = tokenizer self.seq_len = seq_len def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0 dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) rank = int(os.environ.get("RANK", "0")) lrank = int(os.environ.get("LOCAL_RANK", "-1")) host = socket.gethostname() for ex in self.ex_iter: msgs = ex.get("messages", None) if not msgs or not isinstance(msgs, list): continue tools = ex.get("tools", None) try: rendered: str = self.tok.apply_chat_template( msgs, tools=tools, add_generation_prompt=False, tokenize=False ) except TypeError: rendered: str = self.tok.apply_chat_template( msgs, add_generation_prompt=False, tokenize=False ) if not isinstance(rendered, str) or not rendered.strip(): continue spans = _assistant_char_spans(rendered) if not spans: continue enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True) input_ids: List[int] = enc["input_ids"] offsets: List[Tuple[int, int]] = enc["offset_mapping"] if not input_ids: continue labels = [-100] * len(input_ids) def in_any_span(lo: int, hi: int) -> bool: for s, e in spans: if not (hi <= s or lo >= e): return True return False for i, (lo, hi) in enumerate(offsets): if in_any_span(lo, hi): labels[i] = input_ids[i] if all(v == -100 for v in labels): # 无监督 token continue # ---- assistant-aware truncation: keep last assistant not cut off if len(input_ids) > self.seq_len: s_last, e_last = spans[-1] j = 0 while j < len(offsets) and offsets[j][1] <= s_last: j += 1 k_excl = j while k_excl < len(offsets) and offsets[k_excl][0] < e_last: k_excl += 1 A = max(0, k_excl - j) if A >= self.seq_len: start = max(0, k_excl - self.seq_len); end = start + self.seq_len else: start = max(0, min(j, len(input_ids) - self.seq_len)) end = start + self.seq_len if end < k_excl: end = k_excl; start = end - self.seq_len if start < 0: start = 0; end = self.seq_len leftover = self.seq_len - A left_wish = leftover // 2 start = max(0, min(j - left_wish, start)) end = start + self.seq_len if end < k_excl: end = k_excl; start = end - self.seq_len if start < 0: start = 0; end = self.seq_len input_ids = input_ids[start:end] labels = labels[start:end] pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id L = len(input_ids) if L < self.seq_len: pad = self.seq_len - L input_ids = ([pad_id]*pad) + input_ids labels = ([-100]*pad) + labels attn_mask = [0]*pad + [1]*L else: attn_mask = [1]*self.seq_len assert len(input_ids) == self.seq_len assert len(labels) == self.seq_len assert len(attn_mask) == self.seq_len if dbg_on and self._dbg_seen < dbg_limit: sup_tok = sum(1 for v in labels if v != -100) print(f"[sample][host={host} RANK={rank} LRank={lrank}] " f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True) if sup_tok == 0: print(f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", flush=True) self._dbg_seen += 1 yield { "input_ids": torch.tensor(input_ids, dtype=torch.long), "attention_mask": torch.tensor(attn_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } # ---------- Collator ---------- class SFTDataCollator: def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None): self.tok = tokenizer self.pad_to_length = pad_to_length assert self.tok.pad_token_id is not None def __call__(self, features): if not features: raise RuntimeError(f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator.") def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x) input_ids = [_to_list(f["input_ids"]) for f in features] attn_masks = [_to_list(f["attention_mask"]) for f in features] labels_list = [_to_list(f["labels"]) for f in features] max_len_in_batch = max(len(x) for x in input_ids) target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch pad_id = self.tok.pad_token_id batch_inp, batch_attn, batch_lab = [], [], [] for inp, msk, lab in zip(input_ids, attn_masks, labels_list): pad_len = target_len - len(inp) if pad_len < 0: inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len] pad_len = 0 batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long)) batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long)) batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long)) if os.environ.get("DBG_COLLATE","0") == "1": print(f"[collate][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] " f"features={len(features)} target_len={target_len}", flush=True) return { "input_ids": torch.stack(batch_inp, dim=0), "attention_mask": torch.stack(batch_attn, dim=0), "labels": torch.stack(batch_lab, dim=0), } # ---------- Args ---------- def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--model_name_or_path", type=str, required=True) ap.add_argument("--data_glob", type=str, required=True) ap.add_argument("--output_dir", type=str, required=True) ap.add_argument("--seq_len", type=int, default=4096) ap.add_argument("--learning_rate", type=float, default=1e-4) # LoRA 通常可更大学习率 ap.add_argument("--weight_decay", type=float, default=0.0) # LoRA 常设 0 或很小 ap.add_argument("--warmup_ratio", type=float, default=0.03) ap.add_argument("--num_train_epochs", type=float, default=1.0) ap.add_argument("--max_steps", type=int, default=-1) ap.add_argument("--log_interval", type=int, default=10) ap.add_argument("--save_steps", type=int, default=500) ap.add_argument("--eval_ratio", type=float, default=0.0) ap.add_argument("--seed", type=int, default=1337) ap.add_argument("--gradient_checkpointing", action="store_true") ap.add_argument("--bf16", action="store_true") ap.add_argument("--per_device_train_batch_size", type=int, default=1) ap.add_argument("--gradient_accumulation_steps", type=int, default=64) ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"]) ap.add_argument("--wandb_project", type=str, default="ds-qwen3-lora") ap.add_argument("--eval_data_glob", type=str, default=None) ap.add_argument("--local_rank", type=int, default=-1) ap.add_argument("--per_device_eval_batch_size", type=int, default=1) ap.add_argument("--deepspeed", type=str, default=None) # ---- LoRA specific ---- ap.add_argument("--lora_r", type=int, default=16) ap.add_argument("--lora_alpha", type=float, default=32) ap.add_argument("--lora_dropout", type=float, default=0.05) ap.add_argument("--lora_target", type=str, default="auto", help='逗号分隔,如 "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj";或 "auto"') ap.add_argument("--qlora", action="store_true", help="使用 4bit (NF4) QLoRA(多机 DS 不建议)") ap.add_argument("--merge_lora_and_save", action="store_true", help="训练后在 rank0 合并 LoRA 到基座并另存(注意显存/内存占用)") return ap.parse_args() # ---------- LoRA helpers ---------- def _auto_lora_targets(model) -> List[str]: """ 针对 Qwen/Llama 族,自动挑选常见的线性层名字;仅匹配存在的模块。 """ cand = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj", "w1","w2","w3", "W_pack", "o_attn", "o_proj"] # 覆盖不同实现命名 present = set() for name, module in model.named_modules(): if any(name.endswith(f".{c}") or name == c for c in cand): present.add(name.split(".")[-1]) # 回落:若一个都没匹配到,使用“所有 nn.Linear” if not present: return ["all-linear"] # 去重且保序 order = [] for c in cand: if c in present: order.append(c) return order # ---------- main ---------- def main(): args = parse_args() if os.environ.get("RANK","0") != "0" and args.report_to == "wandb": print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True) args.report_to = "none" set_seed(args.seed) # DeepSpeed enable? use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed)) dschf = None if use_ds: try: from transformers.integrations.deepspeed import HfDeepSpeedConfig src = "transformers.integrations.deepspeed" except Exception: from transformers import HfDeepSpeedConfig src = "transformers" dschf = HfDeepSpeedConfig(args.deepspeed) print(f"[dbg] HfDeepSpeedConfig loaded from {src}", flush=True) if args.report_to == "wandb": os.environ.setdefault("WANDB_PROJECT", args.wandb_project) import transformers as hf try: import deepspeed as ds ds_ver = ds.__version__ except Exception: ds_ver = "n/a" def dbg(msg): print(f"[dbg][host={socket.gethostname()} RANK={os.environ.get('RANK','0')} " f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True) dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}") dbg(f"args={args}") dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % ( os.environ.get("WORLD_SIZE"), os.environ.get("RANK"), os.environ.get("LOCAL_RANK", str(args.local_rank)), os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"), os.environ.get("CUDA_VISIBLE_DEVICES"), )) dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}") # init dist world_size = int(os.environ.get("WORLD_SIZE", "1")) rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank))) if torch.cuda.is_available() and local_rank >= 0: torch.cuda.set_device(local_rank) dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} " f"name={torch.cuda.get_device_name(torch.cuda.current_device())}") if world_size > 1 and dist.is_available() and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dbg(f"init_process_group backend={backend} via env://") dist.init_process_group(backend=backend, init_method="env://") # tokenizer tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token try: if getattr(tokenizer, "padding_side", None) != "left": tokenizer.padding_side = "left" except Exception: pass from transformers import PreTrainedTokenizerFast if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False): raise RuntimeError("需要 Fast tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14。") tokenizer.model_max_length = args.seq_len dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} model_max_length={tokenizer.model_max_length}") # dtype def _bf16_supported(): if not torch.cuda.is_available(): return False if hasattr(torch.cuda, "is_bf16_supported"): return torch.cuda.is_bf16_supported() major, minor = torch.cuda.get_device_capability() return (major, minor) >= (8, 0) use_bf16 = bool(args.bf16 and _bf16_supported()) compute_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32) # -------- load base model (with/without 4bit) -------- quantization_config = None if args.qlora: try: from transformers import BitsAndBytesConfig from peft import prepare_model_for_kbit_training except Exception as e: raise RuntimeError("使用 --qlora 需要安装 bitsandbytes>=0.41 与 peft。") from e quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=compute_dtype ) # 4bit 下不要传 attn_implementation="sdpa" 给部分旧版 torch model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=compute_dtype, trust_remote_code=True, low_cpu_mem_usage=True, quantization_config=quantization_config, device_map=None # 用 DeepSpeed/Trainer 接管 ) if args.gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) else: model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=compute_dtype, low_cpu_mem_usage=True, trust_remote_code=True, attn_implementation="sdpa", ) if args.gradient_checkpointing: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.config.pad_token_id = tokenizer.pad_token_id model.config.use_cache = False # -------- wrap with LoRA -------- from peft import LoraConfig, get_peft_model, TaskType, PeftModel if args.lora_target.strip().lower() == "auto": targets = _auto_lora_targets(model) else: targets = [x.strip() for x in args.lora_target.split(",") if x.strip()] if not targets: targets = _auto_lora_targets(model) lora_cfg = LoraConfig( task_type=TaskType.CAUSAL_LM, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, target_modules=targets, bias="none", inference_mode=False ) model = get_peft_model(model, lora_cfg) # 冻结确认 if is_main_process(): try: model.print_trainable_parameters() except Exception: trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) total = sum(p.numel() for p in model.parameters()) print(f"[LoRA] trainable={trainable:,} / total={total:,} ({trainable/total:.2%})", flush=True) # -------- data streams -------- files = sorted(glob.glob(args.data_glob)) if len(files) == 0: raise FileNotFoundError(f"No files matched DATA_GLOB={args.data_glob}") if is_main_process(): print(f"[data] matched {len(files)} files, example[0]={files[0]}", flush=True) ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def ex_iter_probe(): for ex in ds_stream_probe: yield ex train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len) try: _ = next(iter(train_stream_probe)) except StopIteration: raise RuntimeError("[data] 样本结构不合法或全部被裁切。") ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) def has_at_least(stream, n: int): it = iter(stream) for _ in range(n): try: next(it) except StopIteration: return 0 return 1 need = max(1, args.gradient_accumulation_steps) local_ok = has_at_least(probe_stream, need) if dist.is_available() and dist.is_initialized(): t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank>=0 else "cpu")) dist.all_reduce(t, op=dist.ReduceOp.MIN) if t.item() == 0: if is_main_process(): print(f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批。", flush=True) dist.barrier(); sys.exit(2) else: if local_ok == 0: print(f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批。", flush=True) sys.exit(2) # eval eval_dataset: Optional[Dataset] = None class ListDataset(Dataset): def __init__(self, items): self.items = items def __len__(self): return len(self.items) def __getitem__(self, idx): return self.items[idx] if args.eval_data_glob: eval_files = sorted(glob.glob(args.eval_data_glob)) if len(eval_files) == 0: raise FileNotFoundError(f"No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}") if is_main_process(): print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True) ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True) def ex_iter_eval(): for ex in ds_eval_stream: yield ex eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len) eval_items: List[Dict[str, torch.Tensor]] = [s for s in eval_iterable] if len(eval_items) == 0: raise RuntimeError("[eval] 读到了 0 条有效样本。") eval_dataset = ListDataset(eval_items) # pad to global batch size ws = max(int(os.environ.get("WORLD_SIZE","1")), 1) be = max(1, args.per_device_eval_batch_size) global_bs = ws * be r = len(eval_dataset) % global_bs if r != 0: pad_need = global_bs - r eval_dataset.items += eval_dataset.items[:pad_need] if is_main_process(): print(f"[eval] padded eval set to {len(eval_dataset)}", flush=True) # collator data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len) # training args os.makedirs(args.output_dir, exist_ok=True) logging_dir = os.path.join(args.output_dir, "logs"); os.makedirs(logging_dir, exist_ok=True) ta_kwargs = {} sig = inspect.signature(TrainingArguments.__init__).parameters if eval_dataset is not None: if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "steps" elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps" else: if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no" elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no" ta_kwargs2 = dict( output_dir=args.output_dir, logging_dir=logging_dir, do_train=True, do_eval=(eval_dataset is not None), eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, weight_decay=args.weight_decay, warmup_ratio=args.warmup_ratio, num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0, max_steps=args.max_steps if args.max_steps > 0 else -1, lr_scheduler_type="cosine", logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, deepspeed=(args.deepspeed if use_ds else None), dataloader_drop_last=False, dataloader_num_workers=0, per_device_eval_batch_size=args.per_device_eval_batch_size, report_to=([] if args.report_to == "none" else [args.report_to]), gradient_checkpointing=args.gradient_checkpointing, remove_unused_columns=False, save_on_each_node=True, logging_first_step=True, **ta_kwargs, ) # 精度:QLoRA/LoRA 均按 compute_dtype 设置 if "dataloader_pin_memory" in sig: ta_kwargs2["dataloader_pin_memory"] = False if "torch_compile" in sig: ta_kwargs2["torch_compile"] = False ta_kwargs2.update({ "bf16": (compute_dtype==torch.bfloat16), "fp16": (compute_dtype==torch.float16), }) training_args = TrainingArguments(**ta_kwargs2) # pass tokenizer / processing_class trainer_kwargs = {} if "processing_class" in inspect.signature(Trainer.__init__).parameters: trainer_kwargs["processing_class"] = tokenizer else: trainer_kwargs["tokenizer"] = tokenizer trainer = DebugTrainer( model=model, args=training_args, train_dataset=train_stream, eval_dataset=eval_dataset, data_collator=data_collator, **trainer_kwargs ) trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv"))) # resume (per-node local checkpoint agreement) def last_step(path: str) -> int: ck = get_last_checkpoint(path) if ck is None: return -1 base = os.path.basename(ck) try: return int(base.split("-")[-1]) except Exception: return -1 local_last = last_step(args.output_dir) device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank>=0) else "cpu") resume_flag = None if dist.is_available() and dist.is_initialized(): has_local = torch.tensor(1 if local_last >= 0 else 0, device=device) dist.all_reduce(has_local, op=dist.ReduceOp.MIN) if has_local.item() == 1: ts = torch.tensor(local_last, device=device) world = dist.get_world_size() buf = [torch.zeros_like(ts) for _ in range(world)] dist.all_gather(buf, ts) steps = [b.item() for b in buf] k = min(steps) if k >= 0: resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}") if is_main_process(): print(f"[resume] steps={steps}, resume={resume_flag}", flush=True) else: if local_last >= 0: resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}") print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}") if dist.is_available() and dist.is_initialized(): present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device) dist.all_reduce(present, op=dist.ReduceOp.MIN) if present.item() == 0: if is_main_process(): print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True) resume_flag = None dist.barrier() else: if resume_flag is not None and not os.path.isdir(resume_flag): print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True) resume_flag = None print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}") print_once("***** Starting LoRA training *****") print(f"[dbg] allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, " f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB", flush=True) train_result = trainer.train(resume_from_checkpoint=resume_flag) # save adapter (not the full base) trainer.save_model() # 对 PeftModel:只保存 adapter 权重到 output_dir metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() # eval if eval_dataset is not None: print_once("***** Running eval *****") eval_metrics = trainer.evaluate() trainer.log_metrics("eval", eval_metrics) trainer.save_metrics("eval", eval_metrics) # optional merge if args.merge_lora_and_save and is_main_process(): print("[merge] Merging LoRA into base model ...", flush=True) try: if isinstance(trainer.model, PeftModel): merged = trainer.model.merge_and_unload() else: merged = trainer.model merge_dir = os.path.join(args.output_dir, "merged-full-model") os.makedirs(merge_dir, exist_ok=True) merged.save_pretrained(merge_dir, safe_serialization=True) tokenizer.save_pretrained(merge_dir) print(f"[merge] Saved merged model to: {merge_dir}", flush=True) except Exception as e: print(f"[merge] FAILED: {e}", flush=True) print_once("Done.") if __name__ == "__main__": main()