From e22e569303e449b43734663f5757d8ef36262c38 Mon Sep 17 00:00:00 2001 From: hailin Date: Mon, 8 Sep 2025 09:24:40 +0800 Subject: [PATCH] . --- train_mm_zero3_lora.sh | 16 + train_sft_ds.py | 275 ++++++++++++++-- train_sft_lora.py | 730 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 988 insertions(+), 33 deletions(-) create mode 100644 train_mm_zero3_lora.sh create mode 100644 train_sft_lora.py diff --git a/train_mm_zero3_lora.sh b/train_mm_zero3_lora.sh new file mode 100644 index 0000000..6d5ab5f --- /dev/null +++ b/train_mm_zero3_lora.sh @@ -0,0 +1,16 @@ +deepspeed --hostfile hostfile \ + --num_nodes 6 --num_gpus 4 \ + train_sft_lora.py \ + --model_name_or_path /home/test/Qwen3-32B \ + --data_glob "/home/test/datasets/my_corpus/train*.jsonl" \ + --output_dir /home/test/checkpoints/q3-32b-lora \ + --seq_len 4096 \ + --bf16 \ + --gradient_accumulation_steps 64 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --warmup_ratio 0.03 \ + --lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \ + --lora_target auto \ + --deepspeed /home/test/jd_train/ds_config_zero3.json \ + --report_to wandb --wandb_project ds-qwen3-lora diff --git a/train_sft_ds.py b/train_sft_ds.py index 91b3c9b..43a7f4d 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -194,10 +194,174 @@ class CsvLossLogger(TrainerCallback): f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n" ) -# ----------------- 仅监督 assistant 的数据集 ----------------- +# # ----------------- 仅监督 assistant 的数据集 ----------------- +# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: +# """ +# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。 +# """ +# spans: List[Tuple[int, int]] = [] +# open_tag = "<|im_start|>assistant\n" +# close_tag = "<|im_end|>\n" +# pos = 0 +# while True: +# a = rendered.find(open_tag, pos) +# if a == -1: +# break +# start = a + len(open_tag) +# b = rendered.find(close_tag, start) +# if b == -1: +# break +# spans.append((start, b)) +# pos = b + len(close_tag) +# return spans + +# class QwenChatSFTDataset(IterableDataset): +# """ +# 期望 jsonl 每行形如: +# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} +# 可选包含工具: +# {"messages":[...], "tools":[{...}]} + +# 工作流: +# - 使用 tokenizer.apply_chat_template 渲染 +# - 仅对 assistant 片段计损失(其他 token 的 label = -100) +# - 超长序列保留尾部(通常包含回答) +# """ +# def __init__(self, +# ex_iter: Iterable[dict], +# tokenizer: AutoTokenizer, +# seq_len: int = 4096): +# self.ex_iter = ex_iter +# self.tok = tokenizer +# self.seq_len = seq_len + +# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + +# # >>> DEBUG BEGIN +# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" +# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0 +# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) +# rank = int(os.environ.get("RANK", "0")) +# lrank = int(os.environ.get("LOCAL_RANK", "-1")) +# host = socket.gethostname() +# # >>> DEBUG END + +# for ex in self.ex_iter: +# msgs = ex.get("messages", None) +# if not msgs or not isinstance(msgs, list): +# continue + +# # 可选过滤 think +# bad = False +# for m in msgs: +# if m.get("role") == "assistant" and isinstance(m.get("content"), str): +# c = m["content"] +# if "" in c and "" in c: +# inner = c.split("")[-1].split("")[0].strip() +# if inner: +# bad = True; break +# # 注销这里就可以确保参与计算监督微调,打开就表示跳过 +# if bad: +# continue + +# tools = ex.get("tools", None) + +# # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况 +# try: +# 
rendered: str = self.tok.apply_chat_template( +# msgs, tools=tools, add_generation_prompt=False, tokenize=False +# ) +# except TypeError: +# rendered: str = self.tok.apply_chat_template( +# msgs, add_generation_prompt=False, tokenize=False +# ) + + +# if not isinstance(rendered, str) or not rendered.strip(): +# continue + +# spans = _assistant_char_spans(rendered) +# if not spans: +# continue + +# enc = self.tok( +# rendered, +# add_special_tokens=False, +# return_offsets_mapping=True +# ) +# input_ids: List[int] = enc["input_ids"] +# offsets: List[Tuple[int, int]] = enc["offset_mapping"] + +# if not input_ids: +# continue + +# labels = [-100] * len(input_ids) + +# def in_any_span(lo: int, hi: int) -> bool: +# for s, e in spans: +# if not (hi <= s or lo >= e): +# return True +# return False + +# for i, (lo, hi) in enumerate(offsets): +# if in_any_span(lo, hi): +# labels[i] = input_ids[i] + +# # —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len —— +# # 1) 截断到 seq_len(保留尾部) +# if len(input_ids) > self.seq_len: +# input_ids = input_ids[-self.seq_len:] +# labels = labels[-self.seq_len:] + +# # 2) 左侧补齐到 seq_len(保证所有样本长度一致) +# pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id +# L = len(input_ids) +# if L < self.seq_len: +# pad = self.seq_len - L +# input_ids = ([pad_id] * pad) + input_ids +# labels = ([-100] * pad) + labels +# attn_mask = [0] * pad + [1] * L +# else: +# # 恰好等于 seq_len +# attn_mask = [1] * self.seq_len + +# # 若没有任何可训练 token(labels 全 -100),跳过 +# if all(v == -100 for v in labels): +# continue + +# assert len(input_ids) == self.seq_len +# assert len(labels) == self.seq_len +# assert len(attn_mask) == self.seq_len + +# # >>> DEBUG PRINT(此时变量已定义) +# if dbg_on and self._dbg_seen < dbg_limit: +# sup_tok = sum(1 for v in labels if v != -100) +# print( +# f"[sample][host={host} RANK={rank} LRank={lrank}] " +# f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} " +# f"seq_len={self.seq_len} pad_id={pad_id}", +# flush=True +# ) +# if sup_tok == 0: +# print( +# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped", +# flush=True +# ) +# self._dbg_seen += 1 +# # <<< DEBUG PRINT + +# yield { +# "input_ids": torch.tensor(input_ids, dtype=torch.long), +# "attention_mask": torch.tensor(attn_mask, dtype=torch.long), +# "labels": torch.tensor(labels, dtype=torch.long), +# } + + +# ----------------- 工具:提取 assistant 字符区间 ----------------- def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: """ - 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。 + 在 apply_chat_template 渲染后的纯文本中,返回所有 assistant 段的字符区间 [start, end) + 这些区间覆盖了 assistant 的全部内容(包括 ... 
标签与正文)。 """ spans: List[Tuple[int, int]] = [] open_tag = "<|im_start|>assistant\n" @@ -207,14 +371,16 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: a = rendered.find(open_tag, pos) if a == -1: break - start = a + len(open_tag) - b = rendered.find(close_tag, start) + s = a + len(open_tag) + b = rendered.find(close_tag, s) if b == -1: break - spans.append((start, b)) + spans.append((s, b)) pos = b + len(close_tag) return spans + +# ----------------- 数据集:SFT(监督 assistant 全段,含 标签与内容) ----------------- class QwenChatSFTDataset(IterableDataset): """ 期望 jsonl 每行形如: @@ -225,7 +391,7 @@ class QwenChatSFTDataset(IterableDataset): 工作流: - 使用 tokenizer.apply_chat_template 渲染 - 仅对 assistant 片段计损失(其他 token 的 label = -100) - - 超长序列保留尾部(通常包含回答) + - 截断时“优先确保最后一个 assistant 不被截断”;若其长度 > seq_len,则保留其“结尾”以避免切尾 """ def __init__(self, ex_iter: Iterable[dict], @@ -251,18 +417,7 @@ class QwenChatSFTDataset(IterableDataset): if not msgs or not isinstance(msgs, list): continue - # 可选过滤 think - bad = False - for m in msgs: - if m.get("role") == "assistant" and isinstance(m.get("content"), str): - c = m["content"] - if "" in c and "" in c: - inner = c.split("")[-1].split("")[0].strip() - if inner: - bad = True; break - if bad: - continue - + # —— 不再过滤 :显式允许其参与监督(包括标签与正文) tools = ex.get("tools", None) # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况 @@ -275,7 +430,6 @@ class QwenChatSFTDataset(IterableDataset): msgs, add_generation_prompt=False, tokenize=False ) - if not isinstance(rendered, str) or not rendered.strip(): continue @@ -283,6 +437,7 @@ class QwenChatSFTDataset(IterableDataset): if not spans: continue + # 编码并拿到字符偏移,确保与 rendered 对齐 enc = self.tok( rendered, add_special_tokens=False, @@ -294,10 +449,12 @@ class QwenChatSFTDataset(IterableDataset): if not input_ids: continue + # 先对“所有 assistant 片段”打标签;包含 标签与内容、以及回答正文 labels = [-100] * len(input_ids) def in_any_span(lo: int, hi: int) -> bool: for s, e in spans: + # 与任一 [s, e) 有交集即监督 if not (hi <= s or lo >= e): return True return False @@ -306,13 +463,68 @@ class QwenChatSFTDataset(IterableDataset): if in_any_span(lo, hi): labels[i] = input_ids[i] - # —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len —— - # 1) 截断到 seq_len(保留尾部) - if len(input_ids) > self.seq_len: - input_ids = input_ids[-self.seq_len:] - labels = labels[-self.seq_len:] + # 若没有任何可训练 token(labels 全 -100),跳过 + if all(v == -100 for v in labels): + continue - # 2) 左侧补齐到 seq_len(保证所有样本长度一致) + # ======== Assistant 感知的截断策略(保证“最后一个 assistant 不被截掉”)======== + if len(input_ids) > self.seq_len: + # 取“最后一个 assistant”的字符区间 + s_last, e_last = spans[-1] + + # 将字符区间映射到 token 索引区间 [j, k_excl) + # j: 第一个 token,其右端 hi > s_last + j = 0 + while j < len(offsets) and offsets[j][1] <= s_last: + j += 1 + # k_excl: 第一个 token,其左端 lo >= e_last(即不再与 [s_last, e_last) 相交) + k_excl = j + while k_excl < len(offsets) and offsets[k_excl][0] < e_last: + k_excl += 1 + + A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数 + + if A >= self.seq_len: + # 单个 assistant 本身超过窗口 —— 保“结尾”,避免被切尾 + start = max(0, k_excl - self.seq_len) + end = start + self.seq_len + else: + # 有空间容纳整个 assistant:尽量把窗口对齐到包括完整 assistant + # 先试图把窗口从 j 开始,但要保证 k_excl 也在窗口内 + start = max(0, min(j, len(input_ids) - self.seq_len)) + end = start + self.seq_len + if end < k_excl: + # 还没覆盖到 assistant 末尾,则右移窗口到恰好覆盖末尾 + end = k_excl + start = end - self.seq_len + if start < 0: + start = 0 + end = self.seq_len + + # 可选:尝试“居中”一点(留部分历史上下文),但仍需包含完整 [j, k_excl) + leftover = self.seq_len - A + # 把剩余的一半尽量分配给左侧上下文(不越界) + left_wish = leftover // 2 + start = max(0, 
min(j - left_wish, start)) + end = start + self.seq_len + if end < k_excl: + # 若居中导致末尾又被排除,再纠正一次 + end = k_excl + start = end - self.seq_len + if start < 0: + start = 0 + end = self.seq_len + + # 真正切片 + input_ids = input_ids[start:end] + labels = labels[start:end] + # 注意:offsets 后续不再使用(只为确定切片窗口),无需同步裁剪 + + # 训练注意:这里的策略保证: + # - 若最后一个 assistant <= seq_len:完整保留; + # - 若 > seq_len:至少保证 assistant 的“结尾”在窗口内,不会“切尾”。 + + # ======== 统一长度:左侧补齐到 seq_len ======== pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id L = len(input_ids) if L < self.seq_len: @@ -321,24 +533,19 @@ class QwenChatSFTDataset(IterableDataset): labels = ([-100] * pad) + labels attn_mask = [0] * pad + [1] * L else: - # 恰好等于 seq_len attn_mask = [1] * self.seq_len - # 若没有任何可训练 token(labels 全 -100),跳过 - if all(v == -100 for v in labels): - continue - + # Sanity assert len(input_ids) == self.seq_len assert len(labels) == self.seq_len assert len(attn_mask) == self.seq_len - # >>> DEBUG PRINT(此时变量已定义) + # >>> DEBUG PRINT if dbg_on and self._dbg_seen < dbg_limit: sup_tok = sum(1 for v in labels if v != -100) print( f"[sample][host={host} RANK={rank} LRank={lrank}] " - f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} " - f"seq_len={self.seq_len} pad_id={pad_id}", + f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True ) if sup_tok == 0: @@ -355,6 +562,8 @@ class QwenChatSFTDataset(IterableDataset): "labels": torch.tensor(labels, dtype=torch.long), } + + # ----------------- 专用 Collator:pad inputs, pad labels=-100 ----------------- class SFTDataCollator: def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None): diff --git a/train_sft_lora.py b/train_sft_lora.py new file mode 100644 index 0000000..c6fa521 --- /dev/null +++ b/train_sft_lora.py @@ -0,0 +1,730 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +os.environ.pop("PYTHONNOUSERSITE", None) +os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") +os.environ.setdefault("WANDB_START_METHOD", "thread") +os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb") + +import glob +import socket +import argparse +import inspect +import sys +from typing import Dict, List, Iterable, Iterator, Tuple, Optional + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset, Dataset +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + TrainingArguments, + Trainer, + set_seed +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import get_last_checkpoint + +# ---------- PATH / CUDA utils ---------- +import site, shutil +home = os.path.expanduser("~") +want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"] +cur = os.environ.get("PATH", "").split(":") +new = [d for d in want if d and d not in cur] + cur +os.environ["PATH"] = ":".join(new) +print(f"[env] PATH={os.environ['PATH']}", flush=True) +print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True) + +os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8") +ld = os.environ.get("LD_LIBRARY_PATH", "") +cuda_lib = "/usr/local/cuda-11.8/lib64" +if cuda_lib not in ld.split(":"): + os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib + +print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True) + +os.environ.pop("DS_BUILD_OPS", None) 
+os.environ.pop("DS_SKIP_CUDA_BUILD", None) +try: + user_site = site.getusersitepackages() + if user_site and user_site not in sys.path: + sys.path.insert(0, user_site) +except Exception: + pass +os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext") +os.environ.setdefault("MAX_JOBS", "12") +if shutil.which("ninja") is None: + os.environ["USE_NINJA"] = "0" + print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True) +try: + from deepspeed.ops.op_builder import CPUAdamBuilder + CPUAdamBuilder().load() + print("[env] CPUAdamBuilder JIT OK", flush=True) +except Exception as e: + if "Ninja is required to load C++ extensions" in str(e): + os.environ["USE_NINJA"] = "0" + print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True) + from deepspeed.ops.op_builder import CPUAdamBuilder + CPUAdamBuilder().load() + print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True) + else: + print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True) + # 不致命:LoRA 不依赖这个算子,继续运行 + pass + +# ---------- helpers ---------- +def is_main_process(): + return int(os.environ.get("RANK", "0")) == 0 + +def print_once(*args, **kwargs): + if is_main_process(): + print(*args, **kwargs, flush=True) + +class DebugTrainer(Trainer): + def training_step(self, model, inputs, num_items_in_batch=None): + if not hasattr(self, "_dbg_printed"): + rank = int(os.environ.get("RANK", "0")) + host = socket.gethostname() + ids = inputs["input_ids"]; msk = inputs["attention_mask"]; labs = inputs["labels"] + print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} " + f"supervised={(labs!=-100).sum().item()}", flush=True) + print(f"[step0][host={host} RANK={rank}] " + f"input_ids.shape={tuple(ids.shape)} " + f"attention_mask.shape={tuple(msk.shape)} " + f"labels.shape={tuple(labs.shape)} " + f"num_items_in_batch={num_items_in_batch}", flush=True) + self._dbg_printed = True + return super().training_step(model, inputs, num_items_in_batch) + +class CsvLossLogger(TrainerCallback): + def __init__(self, csv_path: str): + self.csv_path = csv_path + if is_main_process(): + os.makedirs(os.path.dirname(csv_path), exist_ok=True) + with open(self.csv_path, "w", encoding="utf-8") as f: + f.write("step,loss,lr,total_flos\n") + + def on_train_begin(self, args, state, control, **kwargs): + tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) + tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 + print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) + + def on_log(self, args, state, control, logs=None, **kwargs): + if logs is None: return + cur = int(getattr(state, "global_step", 0) or 0) + tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0) + tot = tmp if isinstance(tmp, int) and tmp > 0 else 0 + pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a") + if tot and not hasattr(self, "_tot_announced"): + print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True) + self._tot_announced = True + print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] step {cur}/{tot} ({pct}) " + f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", flush=True) + if not is_main_process(): return + with open(self.csv_path, "a", encoding="utf-8") as f: + f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n") + +# ---------- assistant span detection 
---------- +def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: + spans: List[Tuple[int, int]] = [] + open_tag = "<|im_start|>assistant\n" + close_tag = "<|im_end|>\n" + pos = 0 + while True: + a = rendered.find(open_tag, pos) + if a == -1: break + s = a + len(open_tag) + b = rendered.find(close_tag, s) + if b == -1: break + spans.append((s, b)) + pos = b + len(close_tag) + return spans + +# ---------- Dataset (supervise assistant incl. tags) ---------- +class QwenChatSFTDataset(IterableDataset): + def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096): + self.ex_iter = ex_iter + self.tok = tokenizer + self.seq_len = seq_len + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1" + if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0 + dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3")) + rank = int(os.environ.get("RANK", "0")) + lrank = int(os.environ.get("LOCAL_RANK", "-1")) + host = socket.gethostname() + + for ex in self.ex_iter: + msgs = ex.get("messages", None) + if not msgs or not isinstance(msgs, list): continue + + tools = ex.get("tools", None) + try: + rendered: str = self.tok.apply_chat_template( + msgs, tools=tools, add_generation_prompt=False, tokenize=False + ) + except TypeError: + rendered: str = self.tok.apply_chat_template( + msgs, add_generation_prompt=False, tokenize=False + ) + if not isinstance(rendered, str) or not rendered.strip(): continue + + spans = _assistant_char_spans(rendered) + if not spans: continue + + enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True) + input_ids: List[int] = enc["input_ids"] + offsets: List[Tuple[int, int]] = enc["offset_mapping"] + if not input_ids: continue + + labels = [-100] * len(input_ids) + + def in_any_span(lo: int, hi: int) -> bool: + for s, e in spans: + if not (hi <= s or lo >= e): + return True + return False + + for i, (lo, hi) in enumerate(offsets): + if in_any_span(lo, hi): + labels[i] = input_ids[i] + + if all(v == -100 for v in labels): # 无监督 token + continue + + # ---- assistant-aware truncation: keep last assistant not cut off + if len(input_ids) > self.seq_len: + s_last, e_last = spans[-1] + j = 0 + while j < len(offsets) and offsets[j][1] <= s_last: j += 1 + k_excl = j + while k_excl < len(offsets) and offsets[k_excl][0] < e_last: k_excl += 1 + A = max(0, k_excl - j) + if A >= self.seq_len: + start = max(0, k_excl - self.seq_len); end = start + self.seq_len + else: + start = max(0, min(j, len(input_ids) - self.seq_len)) + end = start + self.seq_len + if end < k_excl: + end = k_excl; start = end - self.seq_len + if start < 0: start = 0; end = self.seq_len + leftover = self.seq_len - A + left_wish = leftover // 2 + start = max(0, min(j - left_wish, start)) + end = start + self.seq_len + if end < k_excl: + end = k_excl; start = end - self.seq_len + if start < 0: start = 0; end = self.seq_len + input_ids = input_ids[start:end] + labels = labels[start:end] + + pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id + L = len(input_ids) + if L < self.seq_len: + pad = self.seq_len - L + input_ids = ([pad_id]*pad) + input_ids + labels = ([-100]*pad) + labels + attn_mask = [0]*pad + [1]*L + else: + attn_mask = [1]*self.seq_len + + assert len(input_ids) == self.seq_len + assert len(labels) == self.seq_len + assert len(attn_mask) == self.seq_len + + if dbg_on and self._dbg_seen < dbg_limit: + sup_tok = sum(1 for v in labels if v != -100) + 
print(f"[sample][host={host} RANK={rank} LRank={lrank}] " + f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True) + if sup_tok == 0: + print(f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", flush=True) + self._dbg_seen += 1 + + yield { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attn_mask, dtype=torch.long), + "labels": torch.tensor(labels, dtype=torch.long), + } + +# ---------- Collator ---------- +class SFTDataCollator: + def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None): + self.tok = tokenizer + self.pad_to_length = pad_to_length + assert self.tok.pad_token_id is not None + + def __call__(self, features): + if not features: + raise RuntimeError(f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator.") + def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x) + input_ids = [_to_list(f["input_ids"]) for f in features] + attn_masks = [_to_list(f["attention_mask"]) for f in features] + labels_list = [_to_list(f["labels"]) for f in features] + max_len_in_batch = max(len(x) for x in input_ids) + target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch + pad_id = self.tok.pad_token_id + batch_inp, batch_attn, batch_lab = [], [], [] + for inp, msk, lab in zip(input_ids, attn_masks, labels_list): + pad_len = target_len - len(inp) + if pad_len < 0: + inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len] + pad_len = 0 + batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long)) + batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long)) + batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long)) + if os.environ.get("DBG_COLLATE","0") == "1": + print(f"[collate][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] " + f"features={len(features)} target_len={target_len}", flush=True) + return { + "input_ids": torch.stack(batch_inp, dim=0), + "attention_mask": torch.stack(batch_attn, dim=0), + "labels": torch.stack(batch_lab, dim=0), + } + +# ---------- Args ---------- +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--model_name_or_path", type=str, required=True) + ap.add_argument("--data_glob", type=str, required=True) + ap.add_argument("--output_dir", type=str, required=True) + ap.add_argument("--seq_len", type=int, default=4096) + ap.add_argument("--learning_rate", type=float, default=1e-4) # LoRA 通常可更大学习率 + ap.add_argument("--weight_decay", type=float, default=0.0) # LoRA 常设 0 或很小 + ap.add_argument("--warmup_ratio", type=float, default=0.03) + ap.add_argument("--num_train_epochs", type=float, default=1.0) + ap.add_argument("--max_steps", type=int, default=-1) + ap.add_argument("--log_interval", type=int, default=10) + ap.add_argument("--save_steps", type=int, default=500) + ap.add_argument("--eval_ratio", type=float, default=0.0) + ap.add_argument("--seed", type=int, default=1337) + ap.add_argument("--gradient_checkpointing", action="store_true") + ap.add_argument("--bf16", action="store_true") + ap.add_argument("--per_device_train_batch_size", type=int, default=1) + ap.add_argument("--gradient_accumulation_steps", type=int, default=64) + ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"]) + ap.add_argument("--wandb_project", type=str, default="ds-qwen3-lora") + ap.add_argument("--eval_data_glob", type=str, default=None) + 
ap.add_argument("--local_rank", type=int, default=-1) + ap.add_argument("--per_device_eval_batch_size", type=int, default=1) + ap.add_argument("--deepspeed", type=str, default=None) + + # ---- LoRA specific ---- + ap.add_argument("--lora_r", type=int, default=16) + ap.add_argument("--lora_alpha", type=float, default=32) + ap.add_argument("--lora_dropout", type=float, default=0.05) + ap.add_argument("--lora_target", type=str, default="auto", + help='逗号分隔,如 "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj";或 "auto"') + + ap.add_argument("--qlora", action="store_true", help="使用 4bit (NF4) QLoRA(多机 DS 不建议)") + ap.add_argument("--merge_lora_and_save", action="store_true", + help="训练后在 rank0 合并 LoRA 到基座并另存(注意显存/内存占用)") + return ap.parse_args() + +# ---------- LoRA helpers ---------- +def _auto_lora_targets(model) -> List[str]: + """ + 针对 Qwen/Llama 族,自动挑选常见的线性层名字;仅匹配存在的模块。 + """ + cand = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj", + "w1","w2","w3", "W_pack", "o_attn", "o_proj"] # 覆盖不同实现命名 + present = set() + for name, module in model.named_modules(): + if any(name.endswith(f".{c}") or name == c for c in cand): + present.add(name.split(".")[-1]) + # 回落:若一个都没匹配到,使用“所有 nn.Linear” + if not present: + return ["all-linear"] + # 去重且保序 + order = [] + for c in cand: + if c in present: order.append(c) + return order + +# ---------- main ---------- +def main(): + args = parse_args() + + if os.environ.get("RANK","0") != "0" and args.report_to == "wandb": + print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True) + args.report_to = "none" + + set_seed(args.seed) + + # DeepSpeed enable? + use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed)) + dschf = None + if use_ds: + try: + from transformers.integrations.deepspeed import HfDeepSpeedConfig + src = "transformers.integrations.deepspeed" + except Exception: + from transformers import HfDeepSpeedConfig + src = "transformers" + dschf = HfDeepSpeedConfig(args.deepspeed) + print(f"[dbg] HfDeepSpeedConfig loaded from {src}", flush=True) + + if args.report_to == "wandb": + os.environ.setdefault("WANDB_PROJECT", args.wandb_project) + + import transformers as hf + try: + import deepspeed as ds + ds_ver = ds.__version__ + except Exception: + ds_ver = "n/a" + + def dbg(msg): + print(f"[dbg][host={socket.gethostname()} RANK={os.environ.get('RANK','0')} " + f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True) + + dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}") + dbg(f"args={args}") + dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % ( + os.environ.get("WORLD_SIZE"), os.environ.get("RANK"), + os.environ.get("LOCAL_RANK", str(args.local_rank)), + os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"), + os.environ.get("CUDA_VISIBLE_DEVICES"), + )) + dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}") + + # init dist + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank))) + if torch.cuda.is_available() and local_rank >= 0: + torch.cuda.set_device(local_rank) + dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} " + f"name={torch.cuda.get_device_name(torch.cuda.current_device())}") + if world_size > 1 and dist.is_available() and not dist.is_initialized(): + backend = "nccl" if torch.cuda.is_available() else 
"gloo" + dbg(f"init_process_group backend={backend} via env://") + dist.init_process_group(backend=backend, init_method="env://") + + # tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + try: + if getattr(tokenizer, "padding_side", None) != "left": + tokenizer.padding_side = "left" + except Exception: + pass + + from transformers import PreTrainedTokenizerFast + if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False): + raise RuntimeError("需要 Fast tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14。") + tokenizer.model_max_length = args.seq_len + dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} model_max_length={tokenizer.model_max_length}") + + # dtype + def _bf16_supported(): + if not torch.cuda.is_available(): return False + if hasattr(torch.cuda, "is_bf16_supported"): + return torch.cuda.is_bf16_supported() + major, minor = torch.cuda.get_device_capability() + return (major, minor) >= (8, 0) + use_bf16 = bool(args.bf16 and _bf16_supported()) + compute_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32) + + # -------- load base model (with/without 4bit) -------- + quantization_config = None + if args.qlora: + try: + from transformers import BitsAndBytesConfig + from peft import prepare_model_for_kbit_training + except Exception as e: + raise RuntimeError("使用 --qlora 需要安装 bitsandbytes>=0.41 与 peft。") from e + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=compute_dtype + ) + # 4bit 下不要传 attn_implementation="sdpa" 给部分旧版 torch + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + torch_dtype=compute_dtype, + trust_remote_code=True, + low_cpu_mem_usage=True, + quantization_config=quantization_config, + device_map=None # 用 DeepSpeed/Trainer 接管 + ) + if args.gradient_checkpointing: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) + else: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + torch_dtype=compute_dtype, + low_cpu_mem_usage=True, + trust_remote_code=True, + attn_implementation="sdpa", + ) + if args.gradient_checkpointing: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + model.config.pad_token_id = tokenizer.pad_token_id + model.config.use_cache = False + + # -------- wrap with LoRA -------- + from peft import LoraConfig, get_peft_model, TaskType, PeftModel + if args.lora_target.strip().lower() == "auto": + targets = _auto_lora_targets(model) + else: + targets = [x.strip() for x in args.lora_target.split(",") if x.strip()] + if not targets: + targets = _auto_lora_targets(model) + + lora_cfg = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=targets, + bias="none", + inference_mode=False + ) + model = get_peft_model(model, lora_cfg) + + # 冻结确认 + if is_main_process(): + try: + model.print_trainable_parameters() + except Exception: + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + print(f"[LoRA] trainable={trainable:,} / 
total={total:,} ({trainable/total:.2%})", flush=True) + + # -------- data streams -------- + files = sorted(glob.glob(args.data_glob)) + if len(files) == 0: + raise FileNotFoundError(f"No files matched DATA_GLOB={args.data_glob}") + if is_main_process(): + print(f"[data] matched {len(files)} files, example[0]={files[0]}", flush=True) + + ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True) + def ex_iter_probe(): + for ex in ds_stream_probe: yield ex + train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len) + try: + _ = next(iter(train_stream_probe)) + except StopIteration: + raise RuntimeError("[data] 样本结构不合法或全部被裁切。") + + ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) + train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len) + + ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) + probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len) + + def has_at_least(stream, n: int): + it = iter(stream) + for _ in range(n): + try: next(it) + except StopIteration: return 0 + return 1 + + need = max(1, args.gradient_accumulation_steps) + local_ok = has_at_least(probe_stream, need) + if dist.is_available() and dist.is_initialized(): + t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank>=0 else "cpu")) + dist.all_reduce(t, op=dist.ReduceOp.MIN) + if t.item() == 0: + if is_main_process(): + print(f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批。", flush=True) + dist.barrier(); sys.exit(2) + else: + if local_ok == 0: + print(f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批。", flush=True) + sys.exit(2) + + # eval + eval_dataset: Optional[Dataset] = None + class ListDataset(Dataset): + def __init__(self, items): self.items = items + def __len__(self): return len(self.items) + def __getitem__(self, idx): return self.items[idx] + + if args.eval_data_glob: + eval_files = sorted(glob.glob(args.eval_data_glob)) + if len(eval_files) == 0: + raise FileNotFoundError(f"No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}") + if is_main_process(): + print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True) + ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True) + def ex_iter_eval(): + for ex in ds_eval_stream: yield ex + eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len) + eval_items: List[Dict[str, torch.Tensor]] = [s for s in eval_iterable] + if len(eval_items) == 0: + raise RuntimeError("[eval] 读到了 0 条有效样本。") + eval_dataset = ListDataset(eval_items) + # pad to global batch size + ws = max(int(os.environ.get("WORLD_SIZE","1")), 1) + be = max(1, args.per_device_eval_batch_size) + global_bs = ws * be + r = len(eval_dataset) % global_bs + if r != 0: + pad_need = global_bs - r + eval_dataset.items += eval_dataset.items[:pad_need] + if is_main_process(): + print(f"[eval] padded eval set to {len(eval_dataset)}", flush=True) + + # collator + data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len) + + # training args + os.makedirs(args.output_dir, exist_ok=True) + logging_dir = os.path.join(args.output_dir, "logs"); os.makedirs(logging_dir, exist_ok=True) + + ta_kwargs = {} + sig = inspect.signature(TrainingArguments.__init__).parameters + if eval_dataset is not None: + if "eval_strategy" in sig: 
ta_kwargs["eval_strategy"] = "steps" + elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps" + else: + if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no" + elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no" + + ta_kwargs2 = dict( + output_dir=args.output_dir, + logging_dir=logging_dir, + do_train=True, + do_eval=(eval_dataset is not None), + eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, + per_device_train_batch_size=args.per_device_train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + warmup_ratio=args.warmup_ratio, + num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0, + max_steps=args.max_steps if args.max_steps > 0 else -1, + lr_scheduler_type="cosine", + logging_steps=args.log_interval, + save_steps=args.save_steps, + save_total_limit=2, + deepspeed=(args.deepspeed if use_ds else None), + dataloader_drop_last=False, + dataloader_num_workers=0, + per_device_eval_batch_size=args.per_device_eval_batch_size, + report_to=([] if args.report_to == "none" else [args.report_to]), + gradient_checkpointing=args.gradient_checkpointing, + remove_unused_columns=False, + save_on_each_node=True, + logging_first_step=True, + **ta_kwargs, + ) + # 精度:QLoRA/LoRA 均按 compute_dtype 设置 + if "dataloader_pin_memory" in sig: ta_kwargs2["dataloader_pin_memory"] = False + if "torch_compile" in sig: ta_kwargs2["torch_compile"] = False + ta_kwargs2.update({ + "bf16": (compute_dtype==torch.bfloat16), + "fp16": (compute_dtype==torch.float16), + }) + training_args = TrainingArguments(**ta_kwargs2) + + # pass tokenizer / processing_class + trainer_kwargs = {} + if "processing_class" in inspect.signature(Trainer.__init__).parameters: + trainer_kwargs["processing_class"] = tokenizer + else: + trainer_kwargs["tokenizer"] = tokenizer + + trainer = DebugTrainer( + model=model, + args=training_args, + train_dataset=train_stream, + eval_dataset=eval_dataset, + data_collator=data_collator, + **trainer_kwargs + ) + trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv"))) + + # resume (per-node local checkpoint agreement) + def last_step(path: str) -> int: + ck = get_last_checkpoint(path) + if ck is None: return -1 + base = os.path.basename(ck) + try: return int(base.split("-")[-1]) + except Exception: return -1 + + local_last = last_step(args.output_dir) + device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank>=0) else "cpu") + resume_flag = None + if dist.is_available() and dist.is_initialized(): + has_local = torch.tensor(1 if local_last >= 0 else 0, device=device) + dist.all_reduce(has_local, op=dist.ReduceOp.MIN) + if has_local.item() == 1: + ts = torch.tensor(local_last, device=device) + world = dist.get_world_size() + buf = [torch.zeros_like(ts) for _ in range(world)] + dist.all_gather(buf, ts) + steps = [b.item() for b in buf] + k = min(steps) + if k >= 0: + resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}") + if is_main_process(): + print(f"[resume] steps={steps}, resume={resume_flag}", flush=True) + else: + if local_last >= 0: + resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}") + + print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}") + if dist.is_available() and dist.is_initialized(): + present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, 
device=device) + dist.all_reduce(present, op=dist.ReduceOp.MIN) + if present.item() == 0: + if is_main_process(): + print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True) + resume_flag = None + dist.barrier() + else: + if resume_flag is not None and not os.path.isdir(resume_flag): + print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True) + resume_flag = None + + print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}") + print_once("***** Starting LoRA training *****") + print(f"[dbg] allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, " + f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB", flush=True) + + train_result = trainer.train(resume_from_checkpoint=resume_flag) + + # save adapter (not the full base) + trainer.save_model() # 对 PeftModel:只保存 adapter 权重到 output_dir + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # eval + if eval_dataset is not None: + print_once("***** Running eval *****") + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + trainer.save_metrics("eval", eval_metrics) + + # optional merge + if args.merge_lora_and_save and is_main_process(): + print("[merge] Merging LoRA into base model ...", flush=True) + try: + if isinstance(trainer.model, PeftModel): + merged = trainer.model.merge_and_unload() + else: + merged = trainer.model + merge_dir = os.path.join(args.output_dir, "merged-full-model") + os.makedirs(merge_dir, exist_ok=True) + merged.save_pretrained(merge_dir, safe_serialization=True) + tokenizer.save_pretrained(merge_dir) + print(f"[merge] Saved merged model to: {merge_dir}", flush=True) + except Exception as e: + print(f"[merge] FAILED: {e}", flush=True) + + print_once("Done.") + +if __name__ == "__main__": + main()
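
Since trainer.save_model() on a PeftModel writes only the adapter weights to output_dir (the full base model is not duplicated), the checkpoint this patch produces is used at inference time by attaching the adapter back onto the original base model. A minimal sketch, assuming a recent peft/transformers/accelerate install and reusing the example paths from train_mm_zero3_lora.sh; file name, prompt, and generation settings are illustrative only:

# load_lora_for_inference.py -- illustrative sketch, not part of the patch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_path = "/home/test/Qwen3-32B"                   # base model used during training
adapter_path = "/home/test/checkpoints/q3-32b-lora"  # output_dir holding the saved adapter

tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    base_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",  # requires accelerate; shards the 32B weights across available GPUs
)
model = PeftModel.from_pretrained(base, adapter_path)
model.eval()

# Same chat template as training; only the assistant turn is generated.
messages = [{"role": "user", "content": "Hello"}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
out = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))

# Optional: bake the adapter into the base weights for adapter-free serving
# (same effect as the --merge_lora_and_save path in train_sft_lora.py).
# merged = model.merge_and_unload()
# merged.save_pretrained(adapter_path + "/merged-full-model", safe_serialization=True)

Loading the adapter this way keeps the base checkpoint untouched, so several adapters trained from the same base can share one copy of the 32B weights; merging is only needed when the serving stack cannot load peft adapters directly.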