#!/usr/bin/env python3
# -*- coding: utf-8 -*-
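"""jd_train/train_sft_lora.py

LoRA supervised fine-tuning (SFT) for Qwen-style chat models, driven by the HF Trainer
with optional DeepSpeed. Training data is streamed from JSON/JSONL files whose records
contain a "messages" list (and optionally "tools"); each record is rendered with the
tokenizer's chat template, and only the tokens inside assistant turns (including <think>
content) are supervised. Also supports optional 4-bit QLoRA, per-node checkpoint resume
agreement, CSV loss logging, and an optional post-training merge of the LoRA adapter
into the base model.
"""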
import os
os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
# ---------- PATH / CUDA utils ----------
import site, shutil
home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
try:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.insert(0, user_site)
except Exception:
pass
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")
if shutil.which("ninja") is None:
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
if "Ninja is required to load C++ extensions" in str(e):
os.environ["USE_NINJA"] = "0"
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load()
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
else:
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
        # Not fatal: LoRA does not depend on this op, so keep running
pass
# ---------- helpers ----------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
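    """Trainer that, on the first training step of each process, prints the devices and shapes
    of the batch tensors and the number of supervised (label != -100) tokens, for debugging."""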
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]; msk = inputs["attention_mask"]; labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}", flush=True)
print(f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}", flush=True)
self._dbg_printed = True
return super().training_step(model, inputs, num_items_in_batch)
class CsvLossLogger(TrainerCallback):
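    """Callback that writes step/loss/lr/total_flos rows to a CSV file on the main process and
    prints per-rank progress (step, percentage, loss, lr) whenever the Trainer logs."""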
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None: return
cur = int(getattr(state, "global_step", 0) or 0)
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
if tot and not hasattr(self, "_tot_announced"):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] step {cur}/{tot} ({pct}) "
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", flush=True)
if not is_main_process(): return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
# ---------- assistant span detection ----------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
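    """Return the (start, end) character spans of every assistant message body in the rendered
    chat-template string, i.e. the text between "<|im_start|>assistant\\n" and the following
    "<|im_end|>\\n"."""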
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n"
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1: break
s = a + len(open_tag)
b = rendered.find(close_tag, s)
if b == -1: break
spans.append((s, b))
pos = b + len(close_tag)
return spans
# ---------- Dataset (supervise assistant incl. <think> tags) ----------
class QwenChatSFTDataset(IterableDataset):
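    """Streaming SFT dataset for Qwen chat-format data.

    For each example, the "messages" list (plus optional "tools") is rendered with the
    tokenizer's chat template and tokenized with offset mappings; labels are set to -100
    everywhere except tokens that overlap an assistant span. Sequences longer than seq_len
    are truncated with a window chosen so that the last assistant turn stays intact when
    possible; shorter sequences are left-padded to exactly seq_len.
    """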
def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
host = socket.gethostname()
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list): continue
tools = ex.get("tools", None)
try:
rendered: str = self.tok.apply_chat_template(
msgs, tools=tools, add_generation_prompt=False, tokenize=False
)
except TypeError:
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip(): continue
spans = _assistant_char_spans(rendered)
if not spans: continue
enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True)
input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
if not input_ids: continue
labels = [-100] * len(input_ids)
def in_any_span(lo: int, hi: int) -> bool:
for s, e in spans:
if not (hi <= s or lo >= e):
return True
return False
for i, (lo, hi) in enumerate(offsets):
if in_any_span(lo, hi):
labels[i] = input_ids[i]
            if all(v == -100 for v in labels):  # no supervised tokens in this sample
continue
# ---- assistant-aware truncation: keep last assistant not cut off
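            # j = index of the first token of the last assistant span, k_excl = one past its last
            # token, A = its length in tokens. If the span alone exceeds seq_len, keep its tail;
            # otherwise try to keep roughly half of the leftover budget as context before the span.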
if len(input_ids) > self.seq_len:
s_last, e_last = spans[-1]
j = 0
while j < len(offsets) and offsets[j][1] <= s_last: j += 1
k_excl = j
while k_excl < len(offsets) and offsets[k_excl][0] < e_last: k_excl += 1
A = max(0, k_excl - j)
if A >= self.seq_len:
start = max(0, k_excl - self.seq_len); end = start + self.seq_len
else:
start = max(0, min(j, len(input_ids) - self.seq_len))
end = start + self.seq_len
if end < k_excl:
end = k_excl; start = end - self.seq_len
if start < 0: start = 0; end = self.seq_len
leftover = self.seq_len - A
left_wish = leftover // 2
start = max(0, min(j - left_wish, start))
end = start + self.seq_len
if end < k_excl:
end = k_excl; start = end - self.seq_len
if start < 0: start = 0; end = self.seq_len
input_ids = input_ids[start:end]
labels = labels[start:end]
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
L = len(input_ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = ([pad_id]*pad) + input_ids
labels = ([-100]*pad) + labels
attn_mask = [0]*pad + [1]*L
else:
attn_mask = [1]*self.seq_len
assert len(input_ids) == self.seq_len
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
if dbg_on and self._dbg_seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True)
if sup_tok == 0:
print(f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", flush=True)
self._dbg_seen += 1
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
# ---------- Collator ----------
class SFTDataCollator:
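    """Pad (or truncate) each feature to a fixed length (pad_to_length if given, otherwise the
    longest sequence in the batch), using pad_token_id for input_ids, 0 for attention_mask and
    -100 for labels, then stack everything into batch tensors."""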
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None
def __call__(self, features):
if not features:
raise RuntimeError(f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator.")
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
if os.environ.get("DBG_COLLATE","0") == "1":
print(f"[collate][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] "
f"features={len(features)} target_len={target_len}", flush=True)
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
# ---------- Args ----------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True)
ap.add_argument("--data_glob", type=str, required=True)
ap.add_argument("--output_dir", type=str, required=True)
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=1e-4) # LoRA 通常可更大学习率
ap.add_argument("--weight_decay", type=float, default=0.0) # LoRA 常设 0 或很小
ap.add_argument("--warmup_ratio", type=float, default=0.03)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3-lora")
ap.add_argument("--eval_data_glob", type=str, default=None)
ap.add_argument("--local_rank", type=int, default=-1)
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
# ---- LoRA specific ----
ap.add_argument("--lora_r", type=int, default=16)
ap.add_argument("--lora_alpha", type=float, default=32)
ap.add_argument("--lora_dropout", type=float, default=0.05)
ap.add_argument("--lora_target", type=str, default="auto",
help='逗号分隔,如 "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj";或 "auto"')
ap.add_argument("--qlora", action="store_true", help="使用 4bit (NF4) QLoRA多机 DS 不建议)")
ap.add_argument("--merge_lora_and_save", action="store_true",
help="训练后在 rank0 合并 LoRA 到基座并另存(注意显存/内存占用)")
return ap.parse_args()
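# Example launch (illustrative only; the model path, data glob, hostfile and DeepSpeed config
# name below are placeholders, not part of this repo):
#   deepspeed --hostfile hostfile train_sft_lora.py \
#     --model_name_or_path /path/to/Qwen-base \
#     --data_glob "/path/to/sft/*.jsonl" \
#     --output_dir /path/to/out \
#     --bf16 --gradient_checkpointing \
#     --deepspeed ds_zero2.json \
#     --lora_r 16 --lora_alpha 32 --lora_target auto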
# ---------- LoRA helpers ----------
def _auto_lora_targets(model) -> List[str]:
"""
针对 Qwen/Llama 族,自动挑选常见的线性层名字;仅匹配存在的模块。
"""
cand = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj",
"w1","w2","w3", "W_pack", "o_attn", "o_proj"] # 覆盖不同实现命名
present = set()
for name, module in model.named_modules():
if any(name.endswith(f".{c}") or name == c for c in cand):
present.add(name.split(".")[-1])
    # Fallback: if nothing matched, target all nn.Linear layers ("all-linear")
if not present:
return ["all-linear"]
    # Deduplicate while preserving order
order = []
for c in cand:
if c in present: order.append(c)
return order
# ---------- main ----------
def main():
args = parse_args()
if os.environ.get("RANK","0") != "0" and args.report_to == "wandb":
print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
args.report_to = "none"
set_seed(args.seed)
# DeepSpeed enable?
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
dschf = None
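    # Instantiate HfDeepSpeedConfig before the model is loaded and keep the object alive:
    # transformers checks for it so that, under ZeRO-3, weights can be partitioned during
    # from_pretrained instead of being fully materialised on every rank.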
if use_ds:
try:
from transformers.integrations.deepspeed import HfDeepSpeedConfig
src = "transformers.integrations.deepspeed"
except Exception:
from transformers import HfDeepSpeedConfig
src = "transformers"
dschf = HfDeepSpeedConfig(args.deepspeed)
print(f"[dbg] HfDeepSpeedConfig loaded from {src}", flush=True)
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
import transformers as hf
try:
import deepspeed as ds
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
def dbg(msg):
print(f"[dbg][host={socket.gethostname()} RANK={os.environ.get('RANK','0')} "
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True)
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
os.environ.get("WORLD_SIZE"), os.environ.get("RANK"),
os.environ.get("LOCAL_RANK", str(args.local_rank)),
os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"),
os.environ.get("CUDA_VISIBLE_DEVICES"),
))
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
# init dist
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
try:
if getattr(tokenizer, "padding_side", None) != "left":
tokenizer.padding_side = "left"
except Exception:
pass
from transformers import PreTrainedTokenizerFast
if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
raise RuntimeError("需要 Fast tokenizer 以获取 offset_mapping请安装 tokenizers>=0.14。")
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} model_max_length={tokenizer.model_max_length}")
# dtype
def _bf16_supported():
if not torch.cuda.is_available(): return False
if hasattr(torch.cuda, "is_bf16_supported"):
return torch.cuda.is_bf16_supported()
major, minor = torch.cuda.get_device_capability()
return (major, minor) >= (8, 0)
use_bf16 = bool(args.bf16 and _bf16_supported())
compute_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)
# -------- load base model (with/without 4bit) --------
quantization_config = None
if args.qlora:
try:
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
except Exception as e:
raise RuntimeError("使用 --qlora 需要安装 bitsandbytes>=0.41 与 peft。") from e
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=compute_dtype
)
        # In 4-bit mode, avoid passing attn_implementation="sdpa" to some older torch versions
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=compute_dtype,
trust_remote_code=True,
low_cpu_mem_usage=True,
quantization_config=quantization_config,
            device_map=None  # let DeepSpeed / the Trainer handle device placement
)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=compute_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="sdpa",
)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
# -------- wrap with LoRA --------
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
if args.lora_target.strip().lower() == "auto":
targets = _auto_lora_targets(model)
else:
targets = [x.strip() for x in args.lora_target.split(",") if x.strip()]
if not targets:
targets = _auto_lora_targets(model)
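    # With PEFT's default (non-rsLoRA) scaling, the LoRA update is scaled by lora_alpha / lora_r
    # (2.0 with the defaults r=16, alpha=32 used here).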
lora_cfg = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=targets,
bias="none",
inference_mode=False
)
model = get_peft_model(model, lora_cfg)
    # Confirm freezing: report trainable vs. total parameter counts
if is_main_process():
try:
model.print_trainable_parameters()
except Exception:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"[LoRA] trainable={trainable:,} / total={total:,} ({trainable/total:.2%})", flush=True)
# -------- data streams --------
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
raise FileNotFoundError(f"No files matched DATA_GLOB={args.data_glob}")
if is_main_process():
print(f"[data] matched {len(files)} files, example[0]={files[0]}", flush=True)
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_probe():
for ex in ds_stream_probe: yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
_ = next(iter(train_stream_probe))
except StopIteration:
raise RuntimeError("[data] 样本结构不合法或全部被裁切。")
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
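    # Probe: every rank checks that its stream can yield at least gradient_accumulation_steps
    # micro-batches, so a rank that would run dry mid-accumulation is detected up front
    # instead of stalling the distributed job.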
def has_at_least(stream, n: int):
it = iter(stream)
for _ in range(n):
try: next(it)
except StopIteration: return 0
return 1
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank>=0 else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print(f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批。", flush=True)
dist.barrier(); sys.exit(2)
else:
if local_ok == 0:
print(f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批。", flush=True)
sys.exit(2)
# eval
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream: yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = [s for s in eval_iterable]
if len(eval_items) == 0:
raise RuntimeError("[eval] 读到了 0 条有效样本。")
eval_dataset = ListDataset(eval_items)
# pad to global batch size
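        # Repeat a few items so the eval set size is a multiple of
        # world_size * per_device_eval_batch_size; every rank then sees the same number of
        # full batches during distributed evaluation.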
ws = max(int(os.environ.get("WORLD_SIZE","1")), 1)
be = max(1, args.per_device_eval_batch_size)
global_bs = ws * be
r = len(eval_dataset) % global_bs
if r != 0:
pad_need = global_bs - r
eval_dataset.items += eval_dataset.items[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)}", flush=True)
# collator
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
# training args
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs"); os.makedirs(logging_dir, exist_ok=True)
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no"
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
do_train=True,
do_eval=(eval_dataset is not None),
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=(args.deepspeed if use_ds else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
**ta_kwargs,
)
    # Precision: both QLoRA and plain LoRA runs follow compute_dtype
if "dataloader_pin_memory" in sig: ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in sig: ta_kwargs2["torch_compile"] = False
ta_kwargs2.update({
"bf16": (compute_dtype==torch.bfloat16),
"fp16": (compute_dtype==torch.float16),
})
training_args = TrainingArguments(**ta_kwargs2)
# pass tokenizer / processing_class
trainer_kwargs = {}
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
trainer_kwargs["processing_class"] = tokenizer
else:
trainer_kwargs["tokenizer"] = tokenizer
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
data_collator=data_collator,
**trainer_kwargs
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
# resume (per-node local checkpoint agreement)
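    # With save_on_each_node=True every node keeps its own checkpoints. Ranks resume from the
    # smallest of the per-rank latest checkpoint steps, so the chosen checkpoint is guaranteed
    # to exist on every node; if any rank has no checkpoint at all, training starts fresh.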
def last_step(path: str) -> int:
ck = get_last_checkpoint(path)
if ck is None: return -1
base = os.path.basename(ck)
try: return int(base.split("-")[-1])
except Exception: return -1
local_last = last_step(args.output_dir)
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank>=0) else "cpu")
resume_flag = None
if dist.is_available() and dist.is_initialized():
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
if has_local.item() == 1:
ts = torch.tensor(local_last, device=device)
world = dist.get_world_size()
buf = [torch.zeros_like(ts) for _ in range(world)]
dist.all_gather(buf, ts)
steps = [b.item() for b in buf]
k = min(steps)
if k >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
if is_main_process():
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
else:
if local_last >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
if dist.is_available() and dist.is_initialized():
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
dist.all_reduce(present, op=dist.ReduceOp.MIN)
if present.item() == 0:
if is_main_process():
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
resume_flag = None
dist.barrier()
else:
if resume_flag is not None and not os.path.isdir(resume_flag):
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting LoRA training *****")
print(f"[dbg] allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB", flush=True)
train_result = trainer.train(resume_from_checkpoint=resume_flag)
# save adapter (not the full base)
    trainer.save_model()  # for a PeftModel this saves only the adapter weights to output_dir
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# eval
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
# optional merge
if args.merge_lora_and_save and is_main_process():
print("[merge] Merging LoRA into base model ...", flush=True)
try:
if isinstance(trainer.model, PeftModel):
merged = trainer.model.merge_and_unload()
else:
merged = trainer.model
merge_dir = os.path.join(args.output_dir, "merged-full-model")
os.makedirs(merge_dir, exist_ok=True)
merged.save_pretrained(merge_dir, safe_serialization=True)
tokenizer.save_pretrained(merge_dir)
print(f"[merge] Saved merged model to: {merge_dir}", flush=True)
except Exception as e:
print(f"[merge] FAILED: {e}", flush=True)
print_once("Done.")
if __name__ == "__main__":
main()