#!/usr/bin/env python3
import os
os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
import site, shutil
home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
try:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.insert(0, user_site)
except Exception:
pass
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
if "Ninja is required to load C++ extensions" in str(e):
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
else:
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
raise
# ----------------- Process utilities -----------------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]
msk = inputs["attention_mask"]
labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}",
flush=True)
print(
f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}",
flush=True
)
self._dbg_printed = True
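        # Newer transformers versions pass num_items_in_batch to training_step;
        # the TypeError fallback below keeps older versions working.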
try:
return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch)
except TypeError:
return super().training_step(model, inputs)
# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
cur = int(getattr(state, "global_step", 0) or 0)
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
if tot and not hasattr(self, "_tot_announced"):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(
f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
flush=True
)
if not is_main_process():
return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
# ----------------- Supervise only assistant content (token-id level) -----------------
class QwenChatSFTDataset(IterableDataset):
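    """Streams chat examples and yields fixed-length tensors in which only
    assistant message content is supervised (labels != -100 there).

    Assumes a Qwen-style ChatML layout per message::

        <|im_start|>assistant
        ...content...
        <|im_end|>

    When mask_think_and_tags is True, <think>...</think> spans inside the
    assistant content are masked out of the loss as well.
    """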
def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer,
seq_len: int = 4096, mask_think_and_tags: bool = True):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
self.mask_think_and_tags = mask_think_and_tags
self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>")
self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>")
self.ids_ASSISTANT_CANDIDATES = [
self.tok.encode("assistant\n", add_special_tokens=False),
self.tok.encode("assistant", add_special_tokens=False),
]
self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0]
if not self.ids_ASSISTANT_CANDIDATES:
raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.")
self.ids_THINK_OPEN = self.tok.encode("<think>", add_special_tokens=False)
self.ids_THINK_CLOSE = self.tok.encode("</think>", add_special_tokens=False)
for name, val in {"id_START": self.id_START, "id_END": self.id_END}.items():
if val is None or val == self.tok.unk_token_id:
raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}")
@staticmethod
def _find_subseq(hay: list, needle: list, start: int) -> int:
n = len(needle)
if n == 0: return start
for i in range(start, len(hay) - n + 1):
if hay[i:i+n] == needle:
return i
return -1
def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]:
for cand in self.ids_ASSISTANT_CANDIDATES:
pos = self._find_subseq(ids, cand, j_start)
if pos == j_start:
return (pos, len(cand))
return None
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
seen = 0
host = socket.gethostname()
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
it = self.ex_iter() if callable(self.ex_iter) else iter(self.ex_iter)
for ex in it:
msgs = ex.get("messages")
if not msgs or not isinstance(msgs, list):
continue
tools = ex.get("tools", None)
try:
ids = self.tok.apply_chat_template(
msgs, tools=tools, add_generation_prompt=False,
tokenize=True, return_tensors=None
)
if isinstance(ids, dict):
ids = ids["input_ids"]
except TypeError:
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
ids = self.tok(rendered, add_special_tokens=False)["input_ids"]
if not ids:
continue
mask = [0] * len(ids)
i = 0
while i < len(ids):
try:
a = ids.index(self.id_START, i)
except ValueError:
break
j = a + 1
role_match = self._find_role_after_start(ids, j)
if role_match is None:
i = a + 1
continue
_, role_len = role_match
content_lo = j + role_len
try:
b = ids.index(self.id_END, content_lo)
except ValueError:
i = a + 1
continue
content_hi = b
for t in range(content_lo, content_hi):
mask[t] = 1
if self.mask_think_and_tags:
p = content_lo
while True:
o = self._find_subseq(ids, self.ids_THINK_OPEN, p)
if o == -1 or o >= content_hi:
break
c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN))
if c == -1 or c > content_hi:
break
x_lo = o
x_hi = c + len(self.ids_THINK_CLOSE)
for t in range(x_lo, min(x_hi, content_hi)):
mask[t] = 0
p = x_hi
i = b + 1
if not any(mask):
continue
if len(ids) > self.seq_len:
last_on = max(idx for idx, v in enumerate(mask) if v == 1)
end = min(len(ids), last_on + 1)
start = max(0, end - self.seq_len)
ids = ids[start:end]
mask = mask[start:end]
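            # Every yielded sample is exactly seq_len long: shorter sequences are
            # left-padded (pads get label -100); longer ones are truncated from the
            # left, keeping the window that ends at the last supervised token.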
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = [pad_id] * pad + ids
attention_mask = [0] * pad + [1] * L
labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
else:
input_ids = ids
attention_mask = [1] * self.seq_len
labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
if dbg_on and seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
seen += 1
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
# ----------------- Collator -----------------
class SFTDataCollator:
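    """Dynamically pads each batch to its longest sequence (or to pad_to_length
    when given); features longer than the target length are truncated."""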
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set"
def __call__(self, features):
if not features:
raise RuntimeError("Empty batch passed to collator")
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
# ----------------- Arguments -----------------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True)
ap.add_argument("--data_glob", type=str, required=True)
ap.add_argument("--output_dir", type=str, required=True)
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=2e-4) # LoRA通常更大lr
ap.add_argument("--weight_decay", type=float, default=0.0)
ap.add_argument("--warmup_ratio", type=float, default=0.02)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
ap.add_argument("--eval_data_glob", type=str, default=None)
ap.add_argument("--local_rank", type=int, default=-1)
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
ap.add_argument("--eval_steps", type=int, default=10)
    # ===== LoRA options =====
ap.add_argument("--lora_r", type=int, default=16)
ap.add_argument("--lora_alpha", type=float, default=32.0)
ap.add_argument("--lora_dropout", type=float, default=0.05)
ap.add_argument("--lora_bias", type=str, default="none", choices=["none","all","lora_only"])
ap.add_argument("--lora_exclude", type=str, default="", help="逗号分隔的层名后缀(如 lm_head,embed_tokens用于排除")
ap.add_argument("--merge_lora_and_save", action="store_true", help="训练后把LoRA合并到基座并另存占显存/内存大)")
return ap.parse_args()
# ----------------- Helpers: logging and color -----------------
def _make_dbg():
try:
import colorama
colorama.just_fix_windows_console()
except Exception:
pass
def _use_color() -> bool:
if os.environ.get("NO_COLOR"): return False
if os.environ.get("FORCE_COLOR"): return True
return sys.stdout.isatty()
class _C:
reset="\033[0m"; gray="\033[90m"; green="\033[32m"; yellow="\033[33m"; red="\033[31m"; cyan="\033[36m"
_LEVEL_ALIAS={"": "info", None: "info", "ok":"ok","success":"ok","warn":"warn","warning":"warn","err":"err","error":"err","fatal":"err","fail":"err","info":"info"}
_LEVEL_COLOR={"ok":_C.green,"warn":_C.yellow,"err":_C.red,"info":_C.cyan}
def _norm_level(level):
if level is None: return "info"
if isinstance(level,(int,float)):
if level>=40: return "err"
if level>=30: return "warn"
return "info"
if isinstance(level,str):
key=level.strip().lower()
return _LEVEL_ALIAS.get(key,"info")
return "info"
def _paint(s,c): return f"{c}{s}{_C.reset}" if _use_color() else s
def dbg(msg, level=None):
lvl=_norm_level(level); color=_LEVEL_COLOR.get(lvl,_C.cyan)
host=socket.gethostname(); rank=os.environ.get("RANK","0"); lrank=os.environ.get("LOCAL_RANK","-1")
prefix=f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "
print(_paint(prefix,_C.gray)+_paint(str(msg),color), flush=True)
return dbg
dbg=_make_dbg()
# ----------------- LoRA target discovery: all linear layers -----------------
def discover_all_linear_leaf_names(model, exclude: List[str]) -> List[str]:
"""
返回 LoRA target_modules 需要的“叶子模块名后缀”集合(去重)。
默认遍历 nn.Linear / bitsandbytes 的 Linear4bit/8bit 等线性类。
"""
linear_like = [torch.nn.Linear]
try:
import bitsandbytes as bnb
import bitsandbytes.nn as bnbnn
        # also cover bitsandbytes' linear wrappers
for cls_name in ("Linear4bit", "Linear8bitLt"):
if hasattr(bnbnn, cls_name):
linear_like.append(getattr(bnbnn, cls_name))
except Exception:
pass
suffixes=set()
for full_name, module in model.named_modules():
if any(isinstance(module, cls) for cls in linear_like):
last = full_name.split(".")[-1]
if last not in exclude:
suffixes.add(last)
targets = sorted(suffixes)
if not targets:
raise RuntimeError("未发现任何线性层可用于 LoRA。请检查模型结构或放宽排除列表。")
return targets
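# A minimal usage sketch (the model name and resulting suffixes below are
# illustrative, not taken from this script's configuration):
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
#   targets = discover_all_linear_leaf_names(base, exclude=["lm_head"])
#   # -> e.g. ["down_proj", "gate_proj", "k_proj", "o_proj", "q_proj", "up_proj", "v_proj"]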
# ----------------- Main -----------------
def main():
args = parse_args()
    # Only rank 0 reports to wandb
if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
args.report_to = "none"
set_seed(args.seed)
# DeepSpeed
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
try:
from transformers import HfDeepSpeedConfig
src = "transformers"
except Exception as e:
raise RuntimeError("当前 transformers 版本未提供 HfDeepSpeedConfig请升级/降级 transformers") from e
dschf = HfDeepSpeedConfig(args.deepspeed)
dbg(f"HfDeepSpeedConfig loaded from {src}")
# W&Brank0
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0","-1")
if is_rank0:
import wandb
try:
os.environ.pop("WANDB_RUN_ID", None)
extra={}
if os.getenv("WANDB_NAME"): extra["name"]=os.getenv("WANDB_NAME")
if os.getenv("WANDB_GROUP"): extra["group"]=os.getenv("WANDB_GROUP")
if os.getenv("WANDB_RESUME"): extra["resume"]=os.getenv("WANDB_RESUME")
run = wandb.init(
project=args.wandb_project,
entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
settings=wandb.Settings(
base_url=os.getenv("WANDB_BASE_URL","https://wandb.szaiai.com"),
init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT","300")),
),
**extra,
)
print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
except Exception as e:
print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
os.environ["WANDB_DISABLED"]="true"
args.report_to="none"
else:
os.environ["WANDB_DISABLED"]="true"
import transformers as hf
try:
import deepspeed as ds
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
os.environ.get("WORLD_SIZE"),
os.environ.get("RANK"),
os.environ.get("LOCAL_RANK", str(args.local_rank)),
os.environ.get("MASTER_ADDR"),
os.environ.get("MASTER_PORT"),
os.environ.get("CUDA_VISIBLE_DEVICES"),
))
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
    # Distributed init
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
dbg("no cuda or invalid local_rank; not calling set_device")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
else:
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
if dist.is_available() and dist.is_initialized():
try:
dbg(f"dist.get_backend()={dist.get_backend()} "
f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
except Exception as e:
dbg(f"dist query error: {e}")
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
try:
if getattr(tokenizer, "padding_side", None) != "left":
tokenizer.padding_side = "left"
except Exception:
pass
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
# dtype
def _bf16_supported():
if not torch.cuda.is_available():
return False
if hasattr(torch.cuda, "is_bf16_supported"):
return torch.cuda.is_bf16_supported()
major, minor = torch.cuda.get_device_capability()
return (major, minor) >= (8, 0)
use_bf16 = bool(args.bf16 and _bf16_supported())
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
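    # TF32 speeds up float32 matmuls on Ampere+ GPUs at negligible precision cost.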
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
except Exception:
pass
    # Base model
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="sdpa",
)
    # pad token / alibi etc.
model.config.pad_token_id = tokenizer.pad_token_id
if getattr(model, "generation_config", None) is not None:
model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False  # cache must be disabled for training
    # ============ Key change: inject LoRA ============
    # 1) Decide the LoRA target modules: default is every linear layer in the model
exclude = [x.strip() for x in args.lora_exclude.split(",") if x.strip()]
target_modules = discover_all_linear_leaf_names(model, exclude)
if is_main_process():
print(f"[lora] target_modules (auto, all-linear minus exclude) = {target_modules}", flush=True)
    # 2) Build the LoRA config and inject it
from peft import LoraConfig, get_peft_model
lora_cfg = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.lora_bias,
task_type="CAUSAL_LM",
target_modules=target_modules,
)
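    # With standard (non-rsLoRA) scaling, the adapter update is scaled by
    # lora_alpha / r, i.e. 32 / 16 = 2.0 with the defaults above.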
model = get_peft_model(model, lora_cfg)
    # 3) Configure gradient checkpointing again (calling it after injection is more robust)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    # 4) Print the trainable-parameter ratio
    try:
        trainable, total = 0, 0
for n, p in model.named_parameters():
total += p.numel()
if p.requires_grad:
trainable += p.numel()
pct = (trainable / total * 100.0) if total else 0.0
if is_main_process():
print(f"[lora] trainable params: {trainable} / {total} ({pct:.2f}%)", flush=True)
except Exception:
pass
    # ============ End of LoRA injection ============
dbg(f"post-config: use_cache={model.config.use_cache} "
f"model.pad_token_id={model.config.pad_token_id} "
f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
assert model.config.pad_token_id == tokenizer.pad_token_id, \
f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
    # ===== Data sanity checks =====
host = socket.gethostname()
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "Every machine must place the data under the same local path."
        )
if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_probe():
for ex in ds_stream_probe:
yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
sample = next(iter(train_stream_probe))
except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] Data files matched, but no trainable samples were produced."
        )
ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"]
assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample"
assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
train_stream = QwenChatSFTDataset(ds_stream2, tokenizer, seq_len=args.seq_len)
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
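    # With an IterableDataset, a rank that runs dry mid optimizer step would hang
    # the other ranks in collective ops, so verify that every rank can supply at
    # least gradient_accumulation_steps micro-batches before training starts.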
def has_at_least(stream, n: int):
it = iter(stream)
for _ in range(n):
try:
next(it)
except StopIteration:
return 0
return 1
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print(
f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 ",
flush=True
)
dist.barrier()
sys.exit(2)
else:
if local_ok == 0:
print(
f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。",
flush=True
)
sys.exit(2)
# ---- Eval ----
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream:
yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = []
for sample in eval_iterable:
eval_items.append(sample)
if len(eval_items) == 0:
raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2():
for ex in tmp_stream:
yield ex
eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
eval_samples = []
it = iter(eval_stream)
for _ in range(desired_eval_batches):
try:
eval_samples.append(next(it))
except StopIteration:
break
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
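    # Pad the eval set to a multiple of the global eval batch so every rank sees
    # the same number of batches; uneven shards can stall distributed evaluation.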
if eval_dataset is not None:
ws = max(world_size, 1)
be = max(1, args.per_device_eval_batch_size)
global_bs = ws * be
r = len(eval_dataset) % global_bs
if r != 0:
pad_need = global_bs - r
eval_dataset.items += eval_dataset.items[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)} "
f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
flush=True)
assert len(eval_dataset) % global_bs == 0
data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
# ---- TrainingArguments ----
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
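    # transformers renamed evaluation_strategy to eval_strategy in newer releases,
    # so probe the signature and set whichever keyword exists.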
if eval_dataset is not None:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no"
    ta_sig = sig  # same TrainingArguments.__init__ signature probed above
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
run_name=f"lora-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
do_train=True,
do_eval=(eval_dataset is not None),
eval_steps=(args.eval_steps if eval_dataset is not None else None),
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
label_smoothing_factor=0.0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
**ta_kwargs,
)
if "dataloader_pin_memory" in ta_sig:
ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in ta_sig:
ta_kwargs2["torch_compile"] = False
ta_kwargs2.update({"bf16": (dtype==torch.bfloat16), "fp16": (dtype==torch.float16)})
training_args = TrainingArguments(**ta_kwargs2)
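    # Newer Trainer versions take processing_class= in place of the deprecated
    # tokenizer= argument; probe the signature to support both.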
trainer_kwargs = {}
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
trainer_kwargs["processing_class"] = tokenizer
else:
trainer_kwargs["tokenizer"] = tokenizer
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
data_collator=data_collator,
**trainer_kwargs,
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
    # ==== Checkpoint-resume decision ====
def last_step(path: str) -> int:
ck = get_last_checkpoint(path)
if ck is None:
return -1
base = os.path.basename(ck)
try:
return int(base.split("-")[-1])
except Exception:
return -1
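    # Resume only from a checkpoint step that exists on every rank: gather each
    # rank's latest step and take the minimum so all ranks load identical state.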
local_last = last_step(args.output_dir)
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
resume_flag = None
if dist.is_available() and dist.is_initialized():
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
if has_local.item() == 1:
ts = torch.tensor(local_last, device=device)
world = dist.get_world_size()
buf = [torch.zeros_like(ts) for _ in range(world)]
dist.all_gather(buf, ts)
steps = [b.item() for b in buf]
k = min(steps)
if k >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
if is_main_process():
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
else:
if local_last >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
if dist.is_available() and dist.is_initialized():
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
dist.all_reduce(present, op=dist.ReduceOp.MIN)
if present.item() == 0:
if is_main_process():
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
resume_flag = None
dist.barrier()
else:
if resume_flag is not None and not os.path.isdir(resume_flag):
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting LoRA training *****")
dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB")
train_result = trainer.train(resume_from_checkpoint=resume_flag)
    # Save: this writes only the LoRA adapter (not the merged full weights)
    trainer.save_model()  # saved into output_dir, includes adapter_model.bin & adapter_config.json
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
    # (Optional) merge LoRA and save it separately
if args.merge_lora_and_save and is_main_process():
print("[lora] merging LoRA into base weights ...", flush=True)
        merged = model.merge_and_unload()  # needs enough GPU/CPU memory
merge_dir = os.path.join(args.output_dir, "merged")
os.makedirs(merge_dir, exist_ok=True)
merged.save_pretrained(merge_dir, safe_serialization=True)
tokenizer.save_pretrained(merge_dir)
print(f"[lora] merged model saved to: {merge_dir}", flush=True)
print_once("Done.")
if __name__ == "__main__":
main()