This commit is contained in:
hailin 2025-09-08 09:24:40 +08:00
parent 30de9a5e5c
commit e22e569303
3 changed files with 988 additions and 33 deletions

16
train_mm_zero3_lora.sh Normal file
View File

@@ -0,0 +1,16 @@
deepspeed --hostfile hostfile \
--num_nodes 6 --num_gpus 4 \
train_sft_lora.py \
--model_name_or_path /home/test/Qwen3-32B \
--data_glob "/home/test/datasets/my_corpus/train*.jsonl" \
--output_dir /home/test/checkpoints/q3-32b-lora \
--seq_len 4096 \
--bf16 \
--gradient_accumulation_steps 64 \
--per_device_train_batch_size 1 \
--learning_rate 1e-4 \
--warmup_ratio 0.03 \
--lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \
--lora_target auto \
--deepspeed /home/test/jd_train/ds_config_zero3.json \
--report_to wandb --wandb_project ds-qwen3-lora
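# Illustrative sketch (assumed contents, adjust to the actual cluster) of the two files the launcher expects:
#   hostfile: standard DeepSpeed format, one "<hostname> slots=<num_gpus>" line per node, e.g.
#     node01 slots=4
#     node02 slots=4
#     ... (six such lines to match --num_nodes 6 --num_gpus 4)
#   ds_config_zero3.json: a ZeRO stage-3 config; a minimal version might look like
#     {"zero_optimization": {"stage": 3},
#      "bf16": {"enabled": "auto"},
#      "train_micro_batch_size_per_gpu": "auto",
#      "gradient_accumulation_steps": "auto",
#      "gradient_clipping": "auto",
#      "train_batch_size": "auto"}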

View File

@@ -194,10 +194,174 @@ class CsvLossLogger(TrainerCallback):
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
# ----------------- Dataset supervising only the assistant turns -----------------
# # ----------------- Dataset supervising only the assistant turns -----------------
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
# """
# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
# """
# spans: List[Tuple[int, int]] = []
# open_tag = "<|im_start|>assistant\n"
# close_tag = "<|im_end|>\n"
# pos = 0
# while True:
# a = rendered.find(open_tag, pos)
# if a == -1:
# break
# start = a + len(open_tag)
# b = rendered.find(close_tag, start)
# if b == -1:
# break
# spans.append((start, b))
# pos = b + len(close_tag)
# return spans
# class QwenChatSFTDataset(IterableDataset):
# """
# 期望 jsonl 每行形如:
# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
# 可选包含工具:
# {"messages":[...], "tools":[{...}]}
# 工作流:
# - 使用 tokenizer.apply_chat_template 渲染
# - 仅对 assistant 片段计损失(其他 token 的 label = -100
# - 超长序列保留尾部(通常包含回答)
# """
# def __init__(self,
# ex_iter: Iterable[dict],
# tokenizer: AutoTokenizer,
# seq_len: int = 4096):
# self.ex_iter = ex_iter
# self.tok = tokenizer
# self.seq_len = seq_len
# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
# # >>> DEBUG BEGIN
# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
# rank = int(os.environ.get("RANK", "0"))
# lrank = int(os.environ.get("LOCAL_RANK", "-1"))
# host = socket.gethostname()
# # >>> DEBUG END
# for ex in self.ex_iter:
# msgs = ex.get("messages", None)
# if not msgs or not isinstance(msgs, list):
# continue
# Optional filtering of <think> samples
# bad = False
# for m in msgs:
# if m.get("role") == "assistant" and isinstance(m.get("content"), str):
# c = m["content"]
# if "<think>" in c and "</think>" in c:
# inner = c.split("<think>")[-1].split("</think>")[0].strip()
# if inner:
# bad = True; break
# Comment this out to let <think></think> content take part in the supervised fine-tuning loss; leaving it active means such samples are skipped
# if bad:
# continue
# tools = ex.get("tools", None)
# Compatibility with older tokenizer.apply_chat_template versions that do not accept a tools argument
# try:
# rendered: str = self.tok.apply_chat_template(
# msgs, tools=tools, add_generation_prompt=False, tokenize=False
# )
# except TypeError:
# rendered: str = self.tok.apply_chat_template(
# msgs, add_generation_prompt=False, tokenize=False
# )
# if not isinstance(rendered, str) or not rendered.strip():
# continue
# spans = _assistant_char_spans(rendered)
# if not spans:
# continue
# enc = self.tok(
# rendered,
# add_special_tokens=False,
# return_offsets_mapping=True
# )
# input_ids: List[int] = enc["input_ids"]
# offsets: List[Tuple[int, int]] = enc["offset_mapping"]
# if not input_ids:
# continue
# labels = [-100] * len(input_ids)
# def in_any_span(lo: int, hi: int) -> bool:
# for s, e in spans:
# if not (hi <= s or lo >= e):
# return True
# return False
# for i, (lo, hi) in enumerate(offsets):
# if in_any_span(lo, hi):
# labels[i] = input_ids[i]
# —— Fixed-length policy: truncate first, then pad to a fixed seq_len at the Dataset level ——
# 1) Truncate to seq_len (keep the tail)
# if len(input_ids) > self.seq_len:
# input_ids = input_ids[-self.seq_len:]
# labels = labels[-self.seq_len:]
# 2) Left-pad to seq_len (so every sample has the same length)
# pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
# L = len(input_ids)
# if L < self.seq_len:
# pad = self.seq_len - L
# input_ids = ([pad_id] * pad) + input_ids
# labels = ([-100] * pad) + labels
# attn_mask = [0] * pad + [1] * L
# else:
# Exactly seq_len
# attn_mask = [1] * self.seq_len
# If there is no trainable token (labels are all -100): skip
# if all(v == -100 for v in labels):
# continue
# assert len(input_ids) == self.seq_len
# assert len(labels) == self.seq_len
# assert len(attn_mask) == self.seq_len
# >>> DEBUG PRINT (variables are defined at this point)
# if dbg_on and self._dbg_seen < dbg_limit:
# sup_tok = sum(1 for v in labels if v != -100)
# print(
# f"[sample][host={host} RANK={rank} LRank={lrank}] "
# f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
# f"seq_len={self.seq_len} pad_id={pad_id}",
# flush=True
# )
# if sup_tok == 0:
# print(
# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
# flush=True
# )
# self._dbg_seen += 1
# # <<< DEBUG PRINT
# yield {
# "input_ids": torch.tensor(input_ids, dtype=torch.long),
# "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
# "labels": torch.tensor(labels, dtype=torch.long),
# }
# ----------------- Helper: extract assistant character spans -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
apply_chat_template 渲染后的文本中返回所有 assistant 内容的字符区间 [start, end)
apply_chat_template 渲染后的纯文本中返回所有 assistant 段的字符区间 [start, end)
这些区间覆盖了 assistant 的全部内容包括 <think>...</think> 标签与正文
"""
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
@@ -207,14 +371,16 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
a = rendered.find(open_tag, pos)
if a == -1:
break
start = a + len(open_tag)
b = rendered.find(close_tag, start)
s = a + len(open_tag)
b = rendered.find(close_tag, s)
if b == -1:
break
spans.append((start, b))
spans.append((s, b))
pos = b + len(close_tag)
return spans
# ----------------- Dataset: SFT (supervise the full assistant segment, incl. <think> tags and content) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
Each jsonl line is expected to look like:
@@ -225,7 +391,7 @@ class QwenChatSFTDataset(IterableDataset):
Workflow:
- Render with tokenizer.apply_chat_template
- Compute loss only on assistant spans (label = -100 for all other tokens)
- For over-long sequences, keep the tail (which usually contains the answer)
- During truncation, first make sure the last assistant turn is not cut off; if it is longer than seq_len, keep its ending so the tail is never lost
"""
def __init__(self,
ex_iter: Iterable[dict],
@@ -251,18 +417,7 @@
if not msgs or not isinstance(msgs, list):
continue
# Optional filtering of <think> samples
bad = False
for m in msgs:
if m.get("role") == "assistant" and isinstance(m.get("content"), str):
c = m["content"]
if "<think>" in c and "</think>" in c:
inner = c.split("<think>")[-1].split("</think>")[0].strip()
if inner:
bad = True; break
if bad:
continue
# —— No longer filter <think>: explicitly allow it to be supervised (both the tags and the body)
tools = ex.get("tools", None)
# Compatibility with older tokenizer.apply_chat_template versions that do not accept a tools argument
@@ -275,7 +430,6 @@
msgs, add_generation_prompt=False, tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip():
continue
@@ -283,6 +437,7 @@
if not spans:
continue
# Encode and obtain character offsets so they stay aligned with rendered
enc = self.tok(
rendered,
add_special_tokens=False,
@@ -294,10 +449,12 @@
if not input_ids:
continue
# First label all assistant spans; this includes the <think> tags and content as well as the answer body
labels = [-100] * len(input_ids)
def in_any_span(lo: int, hi: int) -> bool:
for s, e in spans:
# Supervise any token that overlaps some span [s, e)
if not (hi <= s or lo >= e):
return True
return False
@@ -306,13 +463,68 @@
if in_any_span(lo, hi):
labels[i] = input_ids[i]
# —— Fixed-length policy: truncate first, then pad to a fixed seq_len at the Dataset level ——
# 1) Truncate to seq_len (keep the tail)
if len(input_ids) > self.seq_len:
input_ids = input_ids[-self.seq_len:]
labels = labels[-self.seq_len:]
# If there is no trainable token (labels are all -100): skip
if all(v == -100 for v in labels):
continue
# 2) Left-pad to seq_len (so every sample has the same length)
# ======== Assistant-aware truncation (guarantee the last assistant turn is not cut off) ========
if len(input_ids) > self.seq_len:
# Take the character span of the last assistant turn
s_last, e_last = spans[-1]
# Map the character span to a token index range [j, k_excl)
# j: the first token whose right end hi > s_last
j = 0
while j < len(offsets) and offsets[j][1] <= s_last:
j += 1
# k_excl: the first token whose left end lo >= e_last (i.e. no longer intersects [s_last, e_last))
k_excl = j
while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
k_excl += 1
A = max(0, k_excl - j)  # number of tokens covered by the last assistant turn
if A >= self.seq_len:
# The single assistant turn itself exceeds the window: keep its ending so it is not cut off
start = max(0, k_excl - self.seq_len)
end = start + self.seq_len
else:
# There is room for the whole assistant turn; align the window so it contains the assistant in full
# First try to start the window at j, while making sure k_excl also falls inside the window
start = max(0, min(j, len(input_ids) - self.seq_len))
end = start + self.seq_len
if end < k_excl:
# The assistant's end is not yet covered, so shift the window right until it just covers the end
end = k_excl
start = end - self.seq_len
if start < 0:
start = 0
end = self.seq_len
# Optional: try to "center" the window a bit (keep some preceding context) while still containing all of [j, k_excl)
leftover = self.seq_len - A
# Give roughly half of the leftover budget to the left-side context (without going out of bounds)
left_wish = leftover // 2
start = max(0, min(j - left_wish, start))
end = start + self.seq_len
if end < k_excl:
# If centering pushed the assistant's end back out of the window, correct once more
end = k_excl
start = end - self.seq_len
if start < 0:
start = 0
end = self.seq_len
# Perform the actual slicing
input_ids = input_ids[start:end]
labels = labels[start:end]
# Note: offsets are not used after this point (they only determine the slicing window), so no need to slice them in sync
# Training note: this policy guarantees that
# - if the last assistant turn is <= seq_len, it is kept in full;
# - if it is > seq_len, at least the assistant's ending stays inside the window, so its tail is never cut off.
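# Worked example (illustrative numbers, not from the training data): with seq_len = 8 and 20 tokens total,
# suppose the last assistant turn covers token indices j = 14 .. k_excl = 18 (A = 4 < 8):
#   start = min(j, 20 - 8) = 12, end = 20, which already covers [14, 18);
#   leftover = 4, left_wish = 2, start = min(j - 2, 12) = 12, end = 20,
#   so window [12, 20) keeps the whole assistant turn plus two tokens of left context.
# If instead A >= seq_len (say j = 5, k_excl = 19, seq_len = 8):
#   start = k_excl - seq_len = 11, end = 19, keeping the assistant's final 8 tokens, so its ending survives.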
# ======== Unify length: left-pad to seq_len ========
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids)
if L < self.seq_len:
@@ -321,24 +533,19 @@
labels = ([-100] * pad) + labels
attn_mask = [0] * pad + [1] * L
else:
# Exactly seq_len
attn_mask = [1] * self.seq_len
# If there is no trainable token (labels are all -100): skip
if all(v == -100 for v in labels):
continue
# Sanity
assert len(input_ids) == self.seq_len
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
# >>> DEBUG PRINT (variables are defined at this point)
# >>> DEBUG PRINT
if dbg_on and self._dbg_seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
if sup_tok == 0:
@@ -355,6 +562,8 @@
"labels": torch.tensor(labels, dtype=torch.long),
}
# ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):

730
train_sft_lora.py Normal file
View File

@@ -0,0 +1,730 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
# ---------- PATH / CUDA utils ----------
import site, shutil
home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
try:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.insert(0, user_site)
except Exception:
pass
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
if "Ninja is required to load C++ extensions" in str(e):
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
else:
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
# Not fatal: LoRA does not depend on this op, so keep running
pass
# ---------- helpers ----------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]; msk = inputs["attention_mask"]; labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}", flush=True)
print(f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}", flush=True)
self._dbg_printed = True
return super().training_step(model, inputs, num_items_in_batch)
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None: return
cur = int(getattr(state, "global_step", 0) or 0)
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
if tot and not hasattr(self, "_tot_announced"):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] step {cur}/{tot} ({pct}) "
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", flush=True)
if not is_main_process(): return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
# ---------- assistant span detection ----------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n"
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1: break
s = a + len(open_tag)
b = rendered.find(close_tag, s)
if b == -1: break
spans.append((s, b))
pos = b + len(close_tag)
return spans
# ---------- Dataset (supervise assistant incl. <think> tags) ----------
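# Input sketch (same format as documented in the companion training script of this commit):
# each jsonl line carries {"messages":[{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} and optionally "tools";
# only assistant spans (including <think>...</think>) are labeled, everything else is masked to -100.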
class QwenChatSFTDataset(IterableDataset):
def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
host = socket.gethostname()
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list): continue
tools = ex.get("tools", None)
try:
rendered: str = self.tok.apply_chat_template(
msgs, tools=tools, add_generation_prompt=False, tokenize=False
)
except TypeError:
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip(): continue
spans = _assistant_char_spans(rendered)
if not spans: continue
enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True)
input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
if not input_ids: continue
labels = [-100] * len(input_ids)
def in_any_span(lo: int, hi: int) -> bool:
for s, e in spans:
if not (hi <= s or lo >= e):
return True
return False
for i, (lo, hi) in enumerate(offsets):
if in_any_span(lo, hi):
labels[i] = input_ids[i]
if all(v == -100 for v in labels):  # no supervised tokens
continue
# ---- assistant-aware truncation: keep the last assistant turn from being cut off
if len(input_ids) > self.seq_len:
s_last, e_last = spans[-1]
j = 0
while j < len(offsets) and offsets[j][1] <= s_last: j += 1
k_excl = j
while k_excl < len(offsets) and offsets[k_excl][0] < e_last: k_excl += 1
A = max(0, k_excl - j)
if A >= self.seq_len:
start = max(0, k_excl - self.seq_len); end = start + self.seq_len
else:
start = max(0, min(j, len(input_ids) - self.seq_len))
end = start + self.seq_len
if end < k_excl:
end = k_excl; start = end - self.seq_len
if start < 0: start = 0; end = self.seq_len
leftover = self.seq_len - A
left_wish = leftover // 2
start = max(0, min(j - left_wish, start))
end = start + self.seq_len
if end < k_excl:
end = k_excl; start = end - self.seq_len
if start < 0: start = 0; end = self.seq_len
input_ids = input_ids[start:end]
labels = labels[start:end]
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = ([pad_id]*pad) + input_ids
labels = ([-100]*pad) + labels
attn_mask = [0]*pad + [1]*L
else:
attn_mask = [1]*self.seq_len
assert len(input_ids) == self.seq_len
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
if dbg_on and self._dbg_seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True)
if sup_tok == 0:
print(f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", flush=True)
self._dbg_seen += 1
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
# ---------- Collator ----------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None
def __call__(self, features):
if not features:
raise RuntimeError(f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator.")
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
if os.environ.get("DBG_COLLATE","0") == "1":
print(f"[collate][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] "
f"features={len(features)} target_len={target_len}", flush=True)
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
# ---------- Args ----------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True)
ap.add_argument("--data_glob", type=str, required=True)
ap.add_argument("--output_dir", type=str, required=True)
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=1e-4) # LoRA 通常可更大学习率
ap.add_argument("--weight_decay", type=float, default=0.0) # LoRA 常设 0 或很小
ap.add_argument("--warmup_ratio", type=float, default=0.03)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3-lora")
ap.add_argument("--eval_data_glob", type=str, default=None)
ap.add_argument("--local_rank", type=int, default=-1)
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
# ---- LoRA specific ----
ap.add_argument("--lora_r", type=int, default=16)
ap.add_argument("--lora_alpha", type=float, default=32)
ap.add_argument("--lora_dropout", type=float, default=0.05)
ap.add_argument("--lora_target", type=str, default="auto",
help='comma-separated, e.g. "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"; or "auto"')
ap.add_argument("--qlora", action="store_true", help="使用 4bit (NF4) QLoRA多机 DS 不建议)")
ap.add_argument("--merge_lora_and_save", action="store_true",
help="训练后在 rank0 合并 LoRA 到基座并另存(注意显存/内存占用)")
return ap.parse_args()
# ---------- LoRA helpers ----------
def _auto_lora_targets(model) -> List[str]:
"""
针对 Qwen/Llama 自动挑选常见的线性层名字仅匹配存在的模块
"""
cand = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj",
"w1","w2","w3", "W_pack", "o_attn", "o_proj"] # 覆盖不同实现命名
present = set()
for name, module in model.named_modules():
if any(name.endswith(f".{c}") or name == c for c in cand):
present.add(name.split(".")[-1])
# Fallback: if nothing matched, target all linear layers via peft's "all-linear" shorthand
# (recent peft versions expect this special value as a plain string, not inside a list)
if not present:
return "all-linear"
# Deduplicate while preserving order
order = []
for c in cand:
if c in present: order.append(c)
return order
# ---------- main ----------
def main():
args = parse_args()
if os.environ.get("RANK","0") != "0" and args.report_to == "wandb":
print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
args.report_to = "none"
set_seed(args.seed)
# DeepSpeed enable?
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
from transformers import HfDeepSpeedConfig
src = "transformers"
dschf = HfDeepSpeedConfig(args.deepspeed)
print(f"[dbg] HfDeepSpeedConfig loaded from {src}", flush=True)
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
import transformers as hf
try:
import deepspeed as ds
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
def dbg(msg):
print(f"[dbg][host={socket.gethostname()} RANK={os.environ.get('RANK','0')} "
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True)
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
os.environ.get("WORLD_SIZE"), os.environ.get("RANK"),
os.environ.get("LOCAL_RANK", str(args.local_rank)),
os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"),
os.environ.get("CUDA_VISIBLE_DEVICES"),
))
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
# init dist
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
try:
if getattr(tokenizer, "padding_side", None) != "left":
tokenizer.padding_side = "left"
except Exception:
pass
from transformers import PreTrainedTokenizerFast
if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
raise RuntimeError("需要 Fast tokenizer 以获取 offset_mapping请安装 tokenizers>=0.14。")
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} model_max_length={tokenizer.model_max_length}")
# dtype
def _bf16_supported():
if not torch.cuda.is_available(): return False
if hasattr(torch.cuda, "is_bf16_supported"):
return torch.cuda.is_bf16_supported()
major, minor = torch.cuda.get_device_capability()
return (major, minor) >= (8, 0)
use_bf16 = bool(args.bf16 and _bf16_supported())
compute_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)
# -------- load base model (with/without 4bit) --------
quantization_config = None
if args.qlora:
try:
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
except Exception as e:
raise RuntimeError("使用 --qlora 需要安装 bitsandbytes>=0.41 与 peft。") from e
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=compute_dtype
)
# With 4-bit loading, do not pass attn_implementation="sdpa" to some older torch versions
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=compute_dtype,
trust_remote_code=True,
low_cpu_mem_usage=True,
quantization_config=quantization_config,
device_map=None  # let DeepSpeed/Trainer take over device placement
)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=compute_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="sdpa",
)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
# -------- wrap with LoRA --------
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
if args.lora_target.strip().lower() == "auto":
targets = _auto_lora_targets(model)
else:
targets = [x.strip() for x in args.lora_target.split(",") if x.strip()]
if not targets:
targets = _auto_lora_targets(model)
lora_cfg = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=targets,
bias="none",
inference_mode=False
)
model = get_peft_model(model, lora_cfg)
# Confirm that the base weights are frozen
if is_main_process():
try:
model.print_trainable_parameters()
except Exception:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"[LoRA] trainable={trainable:,} / total={total:,} ({trainable/total:.2%})", flush=True)
# -------- data streams --------
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
raise FileNotFoundError(f"No files matched DATA_GLOB={args.data_glob}")
if is_main_process():
print(f"[data] matched {len(files)} files, example[0]={files[0]}", flush=True)
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_probe():
for ex in ds_stream_probe: yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
_ = next(iter(train_stream_probe))
except StopIteration:
raise RuntimeError("[data] 样本结构不合法或全部被裁切。")
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
def has_at_least(stream, n: int):
it = iter(stream)
for _ in range(n):
try: next(it)
except StopIteration: return 0
return 1
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank>=0 else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print(f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批。", flush=True)
dist.barrier(); sys.exit(2)
else:
if local_ok == 0:
print(f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批。", flush=True)
sys.exit(2)
# eval
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream: yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = [s for s in eval_iterable]
if len(eval_items) == 0:
raise RuntimeError("[eval] 读到了 0 条有效样本。")
eval_dataset = ListDataset(eval_items)
# pad to global batch size
ws = max(int(os.environ.get("WORLD_SIZE","1")), 1)
be = max(1, args.per_device_eval_batch_size)
global_bs = ws * be
r = len(eval_dataset) % global_bs
if r != 0:
pad_need = global_bs - r
eval_dataset.items += eval_dataset.items[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)}", flush=True)
# collator
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
# training args
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs"); os.makedirs(logging_dir, exist_ok=True)
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no"
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
do_train=True,
do_eval=(eval_dataset is not None),
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
**ta_kwargs,
)
# Precision: both QLoRA and LoRA follow compute_dtype
if "dataloader_pin_memory" in sig: ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in sig: ta_kwargs2["torch_compile"] = False
ta_kwargs2.update({
"bf16": (compute_dtype==torch.bfloat16),
"fp16": (compute_dtype==torch.float16),
})
training_args = TrainingArguments(**ta_kwargs2)
# pass tokenizer / processing_class
trainer_kwargs = {}
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
trainer_kwargs["processing_class"] = tokenizer
else:
trainer_kwargs["tokenizer"] = tokenizer
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
data_collator=data_collator,
**trainer_kwargs
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
# resume (per-node local checkpoint agreement)
def last_step(path: str) -> int:
ck = get_last_checkpoint(path)
if ck is None: return -1
base = os.path.basename(ck)
try: return int(base.split("-")[-1])
except Exception: return -1
local_last = last_step(args.output_dir)
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank>=0) else "cpu")
resume_flag = None
if dist.is_available() and dist.is_initialized():
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
if has_local.item() == 1:
ts = torch.tensor(local_last, device=device)
world = dist.get_world_size()
buf = [torch.zeros_like(ts) for _ in range(world)]
dist.all_gather(buf, ts)
steps = [b.item() for b in buf]
k = min(steps)
if k >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
if is_main_process():
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
else:
if local_last >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
if dist.is_available() and dist.is_initialized():
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
dist.all_reduce(present, op=dist.ReduceOp.MIN)
if present.item() == 0:
if is_main_process():
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
resume_flag = None
dist.barrier()
else:
if resume_flag is not None and not os.path.isdir(resume_flag):
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting LoRA training *****")
print(f"[dbg] allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB", flush=True)
train_result = trainer.train(resume_from_checkpoint=resume_flag)
# save adapter (not the full base)
trainer.save_model()  # for a PeftModel this saves only the adapter weights to output_dir
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# eval
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
# optional merge
if args.merge_lora_and_save and is_main_process():
print("[merge] Merging LoRA into base model ...", flush=True)
try:
if isinstance(trainer.model, PeftModel):
merged = trainer.model.merge_and_unload()
else:
merged = trainer.model
merge_dir = os.path.join(args.output_dir, "merged-full-model")
os.makedirs(merge_dir, exist_ok=True)
merged.save_pretrained(merge_dir, safe_serialization=True)
tokenizer.save_pretrained(merge_dir)
print(f"[merge] Saved merged model to: {merge_dir}", flush=True)
except Exception as e:
print(f"[merge] FAILED: {e}", flush=True)
print_once("Done.")
if __name__ == "__main__":
main()