This commit is contained in:
parent
30de9a5e5c
commit
e22e569303
|
|
@ -0,0 +1,16 @@
|
|||
deepspeed --hostfile hostfile \
|
||||
--num_nodes 6 --num_gpus 4 \
|
||||
train_sft_lora.py \
|
||||
--model_name_or_path /home/test/Qwen3-32B \
|
||||
--data_glob "/home/test/datasets/my_corpus/train*.jsonl" \
|
||||
--output_dir /home/test/checkpoints/q3-32b-lora \
|
||||
--seq_len 4096 \
|
||||
--bf16 \
|
||||
--gradient_accumulation_steps 64 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--learning_rate 1e-4 \
|
||||
--warmup_ratio 0.03 \
|
||||
--lora_r 16 --lora_alpha 32 --lora_dropout 0.05 \
|
||||
--lora_target auto \
|
||||
--deepspeed /home/test/jd_train/ds_config_zero3.json \
|
||||
--report_to wandb --wandb_project ds-qwen3-lora
|
||||
275
train_sft_ds.py
275
train_sft_ds.py
|
|
@ -194,10 +194,174 @@ class CsvLossLogger(TrainerCallback):
|
|||
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
|
||||
)
|
||||
|
||||
# ----------------- 仅监督 assistant 的数据集 -----------------
|
||||
# # ----------------- 仅监督 assistant 的数据集 -----------------
|
||||
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
||||
# """
|
||||
# 在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
|
||||
# """
|
||||
# spans: List[Tuple[int, int]] = []
|
||||
# open_tag = "<|im_start|>assistant\n"
|
||||
# close_tag = "<|im_end|>\n"
|
||||
# pos = 0
|
||||
# while True:
|
||||
# a = rendered.find(open_tag, pos)
|
||||
# if a == -1:
|
||||
# break
|
||||
# start = a + len(open_tag)
|
||||
# b = rendered.find(close_tag, start)
|
||||
# if b == -1:
|
||||
# break
|
||||
# spans.append((start, b))
|
||||
# pos = b + len(close_tag)
|
||||
# return spans
|
||||
|
||||
# class QwenChatSFTDataset(IterableDataset):
|
||||
# """
|
||||
# 期望 jsonl 每行形如:
|
||||
# {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
|
||||
# 可选包含工具:
|
||||
# {"messages":[...], "tools":[{...}]}
|
||||
|
||||
# 工作流:
|
||||
# - 使用 tokenizer.apply_chat_template 渲染
|
||||
# - 仅对 assistant 片段计损失(其他 token 的 label = -100)
|
||||
# - 超长序列保留尾部(通常包含回答)
|
||||
# """
|
||||
# def __init__(self,
|
||||
# ex_iter: Iterable[dict],
|
||||
# tokenizer: AutoTokenizer,
|
||||
# seq_len: int = 4096):
|
||||
# self.ex_iter = ex_iter
|
||||
# self.tok = tokenizer
|
||||
# self.seq_len = seq_len
|
||||
|
||||
# def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
||||
|
||||
# # >>> DEBUG BEGIN
|
||||
# dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
|
||||
# if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
|
||||
# dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
|
||||
# rank = int(os.environ.get("RANK", "0"))
|
||||
# lrank = int(os.environ.get("LOCAL_RANK", "-1"))
|
||||
# host = socket.gethostname()
|
||||
# # >>> DEBUG END
|
||||
|
||||
# for ex in self.ex_iter:
|
||||
# msgs = ex.get("messages", None)
|
||||
# if not msgs or not isinstance(msgs, list):
|
||||
# continue
|
||||
|
||||
# # 可选过滤 think
|
||||
# bad = False
|
||||
# for m in msgs:
|
||||
# if m.get("role") == "assistant" and isinstance(m.get("content"), str):
|
||||
# c = m["content"]
|
||||
# if "<think>" in c and "</think>" in c:
|
||||
# inner = c.split("<think>")[-1].split("</think>")[0].strip()
|
||||
# if inner:
|
||||
# bad = True; break
|
||||
# # 注销这里就可以确保<think></think>参与计算监督微调,打开就表示跳过
|
||||
# if bad:
|
||||
# continue
|
||||
|
||||
# tools = ex.get("tools", None)
|
||||
|
||||
# # 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况
|
||||
# try:
|
||||
# rendered: str = self.tok.apply_chat_template(
|
||||
# msgs, tools=tools, add_generation_prompt=False, tokenize=False
|
||||
# )
|
||||
# except TypeError:
|
||||
# rendered: str = self.tok.apply_chat_template(
|
||||
# msgs, add_generation_prompt=False, tokenize=False
|
||||
# )
|
||||
|
||||
|
||||
# if not isinstance(rendered, str) or not rendered.strip():
|
||||
# continue
|
||||
|
||||
# spans = _assistant_char_spans(rendered)
|
||||
# if not spans:
|
||||
# continue
|
||||
|
||||
# enc = self.tok(
|
||||
# rendered,
|
||||
# add_special_tokens=False,
|
||||
# return_offsets_mapping=True
|
||||
# )
|
||||
# input_ids: List[int] = enc["input_ids"]
|
||||
# offsets: List[Tuple[int, int]] = enc["offset_mapping"]
|
||||
|
||||
# if not input_ids:
|
||||
# continue
|
||||
|
||||
# labels = [-100] * len(input_ids)
|
||||
|
||||
# def in_any_span(lo: int, hi: int) -> bool:
|
||||
# for s, e in spans:
|
||||
# if not (hi <= s or lo >= e):
|
||||
# return True
|
||||
# return False
|
||||
|
||||
# for i, (lo, hi) in enumerate(offsets):
|
||||
# if in_any_span(lo, hi):
|
||||
# labels[i] = input_ids[i]
|
||||
|
||||
# # —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len ——
|
||||
# # 1) 截断到 seq_len(保留尾部)
|
||||
# if len(input_ids) > self.seq_len:
|
||||
# input_ids = input_ids[-self.seq_len:]
|
||||
# labels = labels[-self.seq_len:]
|
||||
|
||||
# # 2) 左侧补齐到 seq_len(保证所有样本长度一致)
|
||||
# pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
|
||||
# L = len(input_ids)
|
||||
# if L < self.seq_len:
|
||||
# pad = self.seq_len - L
|
||||
# input_ids = ([pad_id] * pad) + input_ids
|
||||
# labels = ([-100] * pad) + labels
|
||||
# attn_mask = [0] * pad + [1] * L
|
||||
# else:
|
||||
# # 恰好等于 seq_len
|
||||
# attn_mask = [1] * self.seq_len
|
||||
|
||||
# # 若没有任何可训练 token(labels 全 -100),跳过
|
||||
# if all(v == -100 for v in labels):
|
||||
# continue
|
||||
|
||||
# assert len(input_ids) == self.seq_len
|
||||
# assert len(labels) == self.seq_len
|
||||
# assert len(attn_mask) == self.seq_len
|
||||
|
||||
# # >>> DEBUG PRINT(此时变量已定义)
|
||||
# if dbg_on and self._dbg_seen < dbg_limit:
|
||||
# sup_tok = sum(1 for v in labels if v != -100)
|
||||
# print(
|
||||
# f"[sample][host={host} RANK={rank} LRank={lrank}] "
|
||||
# f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
|
||||
# f"seq_len={self.seq_len} pad_id={pad_id}",
|
||||
# flush=True
|
||||
# )
|
||||
# if sup_tok == 0:
|
||||
# print(
|
||||
# f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
|
||||
# flush=True
|
||||
# )
|
||||
# self._dbg_seen += 1
|
||||
# # <<< DEBUG PRINT
|
||||
|
||||
# yield {
|
||||
# "input_ids": torch.tensor(input_ids, dtype=torch.long),
|
||||
# "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
|
||||
# "labels": torch.tensor(labels, dtype=torch.long),
|
||||
# }
|
||||
|
||||
|
||||
# ----------------- 工具:提取 assistant 字符区间 -----------------
|
||||
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
|
||||
在 apply_chat_template 渲染后的纯文本中,返回所有 assistant 段的字符区间 [start, end)
|
||||
这些区间覆盖了 assistant 的全部内容(包括 <think>...</think> 标签与正文)。
|
||||
"""
|
||||
spans: List[Tuple[int, int]] = []
|
||||
open_tag = "<|im_start|>assistant\n"
|
||||
|
|
@ -207,14 +371,16 @@ def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
|||
a = rendered.find(open_tag, pos)
|
||||
if a == -1:
|
||||
break
|
||||
start = a + len(open_tag)
|
||||
b = rendered.find(close_tag, start)
|
||||
s = a + len(open_tag)
|
||||
b = rendered.find(close_tag, s)
|
||||
if b == -1:
|
||||
break
|
||||
spans.append((start, b))
|
||||
spans.append((s, b))
|
||||
pos = b + len(close_tag)
|
||||
return spans
|
||||
|
||||
|
||||
# ----------------- 数据集:SFT(监督 assistant 全段,含 <think> 标签与内容) -----------------
|
||||
class QwenChatSFTDataset(IterableDataset):
|
||||
"""
|
||||
期望 jsonl 每行形如:
|
||||
|
|
@ -225,7 +391,7 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
工作流:
|
||||
- 使用 tokenizer.apply_chat_template 渲染
|
||||
- 仅对 assistant 片段计损失(其他 token 的 label = -100)
|
||||
- 超长序列保留尾部(通常包含回答)
|
||||
- 截断时“优先确保最后一个 assistant 不被截断”;若其长度 > seq_len,则保留其“结尾”以避免切尾
|
||||
"""
|
||||
def __init__(self,
|
||||
ex_iter: Iterable[dict],
|
||||
|
|
@ -251,18 +417,7 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
if not msgs or not isinstance(msgs, list):
|
||||
continue
|
||||
|
||||
# 可选过滤 think
|
||||
bad = False
|
||||
for m in msgs:
|
||||
if m.get("role") == "assistant" and isinstance(m.get("content"), str):
|
||||
c = m["content"]
|
||||
if "<think>" in c and "</think>" in c:
|
||||
inner = c.split("<think>")[-1].split("</think>")[0].strip()
|
||||
if inner:
|
||||
bad = True; break
|
||||
if bad:
|
||||
continue
|
||||
|
||||
# —— 不再过滤 <think>:显式允许其参与监督(包括标签与正文)
|
||||
tools = ex.get("tools", None)
|
||||
|
||||
# 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况
|
||||
|
|
@ -275,7 +430,6 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
msgs, add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
|
||||
|
||||
if not isinstance(rendered, str) or not rendered.strip():
|
||||
continue
|
||||
|
||||
|
|
@ -283,6 +437,7 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
if not spans:
|
||||
continue
|
||||
|
||||
# 编码并拿到字符偏移,确保与 rendered 对齐
|
||||
enc = self.tok(
|
||||
rendered,
|
||||
add_special_tokens=False,
|
||||
|
|
@ -294,10 +449,12 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
if not input_ids:
|
||||
continue
|
||||
|
||||
# 先对“所有 assistant 片段”打标签;包含 <think> 标签与内容、以及回答正文
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
def in_any_span(lo: int, hi: int) -> bool:
|
||||
for s, e in spans:
|
||||
# 与任一 [s, e) 有交集即监督
|
||||
if not (hi <= s or lo >= e):
|
||||
return True
|
||||
return False
|
||||
|
|
@ -306,13 +463,68 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
if in_any_span(lo, hi):
|
||||
labels[i] = input_ids[i]
|
||||
|
||||
# —— 固定长度策略:先截尾,再在 Dataset 层补到固定 seq_len ——
|
||||
# 1) 截断到 seq_len(保留尾部)
|
||||
if len(input_ids) > self.seq_len:
|
||||
input_ids = input_ids[-self.seq_len:]
|
||||
labels = labels[-self.seq_len:]
|
||||
# 若没有任何可训练 token(labels 全 -100),跳过
|
||||
if all(v == -100 for v in labels):
|
||||
continue
|
||||
|
||||
# 2) 左侧补齐到 seq_len(保证所有样本长度一致)
|
||||
# ======== Assistant 感知的截断策略(保证“最后一个 assistant 不被截掉”)========
|
||||
if len(input_ids) > self.seq_len:
|
||||
# 取“最后一个 assistant”的字符区间
|
||||
s_last, e_last = spans[-1]
|
||||
|
||||
# 将字符区间映射到 token 索引区间 [j, k_excl)
|
||||
# j: 第一个 token,其右端 hi > s_last
|
||||
j = 0
|
||||
while j < len(offsets) and offsets[j][1] <= s_last:
|
||||
j += 1
|
||||
# k_excl: 第一个 token,其左端 lo >= e_last(即不再与 [s_last, e_last) 相交)
|
||||
k_excl = j
|
||||
while k_excl < len(offsets) and offsets[k_excl][0] < e_last:
|
||||
k_excl += 1
|
||||
|
||||
A = max(0, k_excl - j) # 最后一个 assistant 覆盖的 token 数
|
||||
|
||||
if A >= self.seq_len:
|
||||
# 单个 assistant 本身超过窗口 —— 保“结尾”,避免被切尾
|
||||
start = max(0, k_excl - self.seq_len)
|
||||
end = start + self.seq_len
|
||||
else:
|
||||
# 有空间容纳整个 assistant:尽量把窗口对齐到包括完整 assistant
|
||||
# 先试图把窗口从 j 开始,但要保证 k_excl 也在窗口内
|
||||
start = max(0, min(j, len(input_ids) - self.seq_len))
|
||||
end = start + self.seq_len
|
||||
if end < k_excl:
|
||||
# 还没覆盖到 assistant 末尾,则右移窗口到恰好覆盖末尾
|
||||
end = k_excl
|
||||
start = end - self.seq_len
|
||||
if start < 0:
|
||||
start = 0
|
||||
end = self.seq_len
|
||||
|
||||
# 可选:尝试“居中”一点(留部分历史上下文),但仍需包含完整 [j, k_excl)
|
||||
leftover = self.seq_len - A
|
||||
# 把剩余的一半尽量分配给左侧上下文(不越界)
|
||||
left_wish = leftover // 2
|
||||
start = max(0, min(j - left_wish, start))
|
||||
end = start + self.seq_len
|
||||
if end < k_excl:
|
||||
# 若居中导致末尾又被排除,再纠正一次
|
||||
end = k_excl
|
||||
start = end - self.seq_len
|
||||
if start < 0:
|
||||
start = 0
|
||||
end = self.seq_len
|
||||
|
||||
# 真正切片
|
||||
input_ids = input_ids[start:end]
|
||||
labels = labels[start:end]
|
||||
# 注意:offsets 后续不再使用(只为确定切片窗口),无需同步裁剪
|
||||
|
||||
# 训练注意:这里的策略保证:
|
||||
# - 若最后一个 assistant <= seq_len:完整保留;
|
||||
# - 若 > seq_len:至少保证 assistant 的“结尾”在窗口内,不会“切尾”。
|
||||
|
||||
# ======== 统一长度:左侧补齐到 seq_len ========
|
||||
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
|
||||
L = len(input_ids)
|
||||
if L < self.seq_len:
|
||||
|
|
@ -321,24 +533,19 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
labels = ([-100] * pad) + labels
|
||||
attn_mask = [0] * pad + [1] * L
|
||||
else:
|
||||
# 恰好等于 seq_len
|
||||
attn_mask = [1] * self.seq_len
|
||||
|
||||
# 若没有任何可训练 token(labels 全 -100),跳过
|
||||
if all(v == -100 for v in labels):
|
||||
continue
|
||||
|
||||
# Sanity
|
||||
assert len(input_ids) == self.seq_len
|
||||
assert len(labels) == self.seq_len
|
||||
assert len(attn_mask) == self.seq_len
|
||||
|
||||
# >>> DEBUG PRINT(此时变量已定义)
|
||||
# >>> DEBUG PRINT
|
||||
if dbg_on and self._dbg_seen < dbg_limit:
|
||||
sup_tok = sum(1 for v in labels if v != -100)
|
||||
print(
|
||||
f"[sample][host={host} RANK={rank} LRank={lrank}] "
|
||||
f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
|
||||
f"seq_len={self.seq_len} pad_id={pad_id}",
|
||||
f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}",
|
||||
flush=True
|
||||
)
|
||||
if sup_tok == 0:
|
||||
|
|
@ -355,6 +562,8 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
"labels": torch.tensor(labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
|
||||
# ----------------- 专用 Collator:pad inputs, pad labels=-100 -----------------
|
||||
class SFTDataCollator:
|
||||
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,730 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
os.environ.pop("PYTHONNOUSERSITE", None)
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("WANDB_START_METHOD", "thread")
|
||||
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
|
||||
|
||||
import glob
|
||||
import socket
|
||||
import argparse
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset, Dataset
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
set_seed
|
||||
)
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
# ---------- PATH / CUDA utils ----------
|
||||
import site, shutil
|
||||
home = os.path.expanduser("~")
|
||||
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
|
||||
cur = os.environ.get("PATH", "").split(":")
|
||||
new = [d for d in want if d and d not in cur] + cur
|
||||
os.environ["PATH"] = ":".join(new)
|
||||
print(f"[env] PATH={os.environ['PATH']}", flush=True)
|
||||
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)
|
||||
|
||||
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
|
||||
ld = os.environ.get("LD_LIBRARY_PATH", "")
|
||||
cuda_lib = "/usr/local/cuda-11.8/lib64"
|
||||
if cuda_lib not in ld.split(":"):
|
||||
os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib
|
||||
|
||||
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)
|
||||
|
||||
os.environ.pop("DS_BUILD_OPS", None)
|
||||
os.environ.pop("DS_SKIP_CUDA_BUILD", None)
|
||||
try:
|
||||
user_site = site.getusersitepackages()
|
||||
if user_site and user_site not in sys.path:
|
||||
sys.path.insert(0, user_site)
|
||||
except Exception:
|
||||
pass
|
||||
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
|
||||
os.environ.setdefault("MAX_JOBS", "12")
|
||||
if shutil.which("ninja") is None:
|
||||
os.environ["USE_NINJA"] = "0"
|
||||
print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)
|
||||
try:
|
||||
from deepspeed.ops.op_builder import CPUAdamBuilder
|
||||
CPUAdamBuilder().load()
|
||||
print("[env] CPUAdamBuilder JIT OK", flush=True)
|
||||
except Exception as e:
|
||||
if "Ninja is required to load C++ extensions" in str(e):
|
||||
os.environ["USE_NINJA"] = "0"
|
||||
print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
|
||||
from deepspeed.ops.op_builder import CPUAdamBuilder
|
||||
CPUAdamBuilder().load()
|
||||
print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
|
||||
else:
|
||||
print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
|
||||
# 不致命:LoRA 不依赖这个算子,继续运行
|
||||
pass
|
||||
|
||||
# ---------- helpers ----------
|
||||
def is_main_process():
|
||||
return int(os.environ.get("RANK", "0")) == 0
|
||||
|
||||
def print_once(*args, **kwargs):
|
||||
if is_main_process():
|
||||
print(*args, **kwargs, flush=True)
|
||||
|
||||
class DebugTrainer(Trainer):
|
||||
def training_step(self, model, inputs, num_items_in_batch=None):
|
||||
if not hasattr(self, "_dbg_printed"):
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
host = socket.gethostname()
|
||||
ids = inputs["input_ids"]; msk = inputs["attention_mask"]; labs = inputs["labels"]
|
||||
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
|
||||
f"supervised={(labs!=-100).sum().item()}", flush=True)
|
||||
print(f"[step0][host={host} RANK={rank}] "
|
||||
f"input_ids.shape={tuple(ids.shape)} "
|
||||
f"attention_mask.shape={tuple(msk.shape)} "
|
||||
f"labels.shape={tuple(labs.shape)} "
|
||||
f"num_items_in_batch={num_items_in_batch}", flush=True)
|
||||
self._dbg_printed = True
|
||||
return super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
class CsvLossLogger(TrainerCallback):
|
||||
def __init__(self, csv_path: str):
|
||||
self.csv_path = csv_path
|
||||
if is_main_process():
|
||||
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
|
||||
with open(self.csv_path, "w", encoding="utf-8") as f:
|
||||
f.write("step,loss,lr,total_flos\n")
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
|
||||
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
|
||||
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if logs is None: return
|
||||
cur = int(getattr(state, "global_step", 0) or 0)
|
||||
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
|
||||
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
|
||||
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
|
||||
if tot and not hasattr(self, "_tot_announced"):
|
||||
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
|
||||
self._tot_announced = True
|
||||
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] step {cur}/{tot} ({pct}) "
|
||||
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}", flush=True)
|
||||
if not is_main_process(): return
|
||||
with open(self.csv_path, "a", encoding="utf-8") as f:
|
||||
f.write(f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
|
||||
|
||||
# ---------- assistant span detection ----------
|
||||
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
||||
spans: List[Tuple[int, int]] = []
|
||||
open_tag = "<|im_start|>assistant\n"
|
||||
close_tag = "<|im_end|>\n"
|
||||
pos = 0
|
||||
while True:
|
||||
a = rendered.find(open_tag, pos)
|
||||
if a == -1: break
|
||||
s = a + len(open_tag)
|
||||
b = rendered.find(close_tag, s)
|
||||
if b == -1: break
|
||||
spans.append((s, b))
|
||||
pos = b + len(close_tag)
|
||||
return spans
|
||||
|
||||
# ---------- Dataset (supervise assistant incl. <think> tags) ----------
|
||||
class QwenChatSFTDataset(IterableDataset):
|
||||
def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096):
|
||||
self.ex_iter = ex_iter
|
||||
self.tok = tokenizer
|
||||
self.seq_len = seq_len
|
||||
|
||||
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
||||
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
|
||||
if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
|
||||
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
|
||||
host = socket.gethostname()
|
||||
|
||||
for ex in self.ex_iter:
|
||||
msgs = ex.get("messages", None)
|
||||
if not msgs or not isinstance(msgs, list): continue
|
||||
|
||||
tools = ex.get("tools", None)
|
||||
try:
|
||||
rendered: str = self.tok.apply_chat_template(
|
||||
msgs, tools=tools, add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
except TypeError:
|
||||
rendered: str = self.tok.apply_chat_template(
|
||||
msgs, add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
if not isinstance(rendered, str) or not rendered.strip(): continue
|
||||
|
||||
spans = _assistant_char_spans(rendered)
|
||||
if not spans: continue
|
||||
|
||||
enc = self.tok(rendered, add_special_tokens=False, return_offsets_mapping=True)
|
||||
input_ids: List[int] = enc["input_ids"]
|
||||
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
|
||||
if not input_ids: continue
|
||||
|
||||
labels = [-100] * len(input_ids)
|
||||
|
||||
def in_any_span(lo: int, hi: int) -> bool:
|
||||
for s, e in spans:
|
||||
if not (hi <= s or lo >= e):
|
||||
return True
|
||||
return False
|
||||
|
||||
for i, (lo, hi) in enumerate(offsets):
|
||||
if in_any_span(lo, hi):
|
||||
labels[i] = input_ids[i]
|
||||
|
||||
if all(v == -100 for v in labels): # 无监督 token
|
||||
continue
|
||||
|
||||
# ---- assistant-aware truncation: keep last assistant not cut off
|
||||
if len(input_ids) > self.seq_len:
|
||||
s_last, e_last = spans[-1]
|
||||
j = 0
|
||||
while j < len(offsets) and offsets[j][1] <= s_last: j += 1
|
||||
k_excl = j
|
||||
while k_excl < len(offsets) and offsets[k_excl][0] < e_last: k_excl += 1
|
||||
A = max(0, k_excl - j)
|
||||
if A >= self.seq_len:
|
||||
start = max(0, k_excl - self.seq_len); end = start + self.seq_len
|
||||
else:
|
||||
start = max(0, min(j, len(input_ids) - self.seq_len))
|
||||
end = start + self.seq_len
|
||||
if end < k_excl:
|
||||
end = k_excl; start = end - self.seq_len
|
||||
if start < 0: start = 0; end = self.seq_len
|
||||
leftover = self.seq_len - A
|
||||
left_wish = leftover // 2
|
||||
start = max(0, min(j - left_wish, start))
|
||||
end = start + self.seq_len
|
||||
if end < k_excl:
|
||||
end = k_excl; start = end - self.seq_len
|
||||
if start < 0: start = 0; end = self.seq_len
|
||||
input_ids = input_ids[start:end]
|
||||
labels = labels[start:end]
|
||||
|
||||
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
|
||||
L = len(input_ids)
|
||||
if L < self.seq_len:
|
||||
pad = self.seq_len - L
|
||||
input_ids = ([pad_id]*pad) + input_ids
|
||||
labels = ([-100]*pad) + labels
|
||||
attn_mask = [0]*pad + [1]*L
|
||||
else:
|
||||
attn_mask = [1]*self.seq_len
|
||||
|
||||
assert len(input_ids) == self.seq_len
|
||||
assert len(labels) == self.seq_len
|
||||
assert len(attn_mask) == self.seq_len
|
||||
|
||||
if dbg_on and self._dbg_seen < dbg_limit:
|
||||
sup_tok = sum(1 for v in labels if v != -100)
|
||||
print(f"[sample][host={host} RANK={rank} LRank={lrank}] "
|
||||
f"toks={len(input_ids)} sup_toks={sup_tok} seq_len={self.seq_len} pad_id={pad_id}", flush=True)
|
||||
if sup_tok == 0:
|
||||
print(f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> skipped", flush=True)
|
||||
self._dbg_seen += 1
|
||||
|
||||
yield {
|
||||
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
||||
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
|
||||
"labels": torch.tensor(labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
# ---------- Collator ----------
|
||||
class SFTDataCollator:
|
||||
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
|
||||
self.tok = tokenizer
|
||||
self.pad_to_length = pad_to_length
|
||||
assert self.tok.pad_token_id is not None
|
||||
|
||||
def __call__(self, features):
|
||||
if not features:
|
||||
raise RuntimeError(f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator.")
|
||||
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
|
||||
input_ids = [_to_list(f["input_ids"]) for f in features]
|
||||
attn_masks = [_to_list(f["attention_mask"]) for f in features]
|
||||
labels_list = [_to_list(f["labels"]) for f in features]
|
||||
max_len_in_batch = max(len(x) for x in input_ids)
|
||||
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
|
||||
pad_id = self.tok.pad_token_id
|
||||
batch_inp, batch_attn, batch_lab = [], [], []
|
||||
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
|
||||
pad_len = target_len - len(inp)
|
||||
if pad_len < 0:
|
||||
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
|
||||
pad_len = 0
|
||||
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
|
||||
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
|
||||
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
|
||||
if os.environ.get("DBG_COLLATE","0") == "1":
|
||||
print(f"[collate][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] "
|
||||
f"features={len(features)} target_len={target_len}", flush=True)
|
||||
return {
|
||||
"input_ids": torch.stack(batch_inp, dim=0),
|
||||
"attention_mask": torch.stack(batch_attn, dim=0),
|
||||
"labels": torch.stack(batch_lab, dim=0),
|
||||
}
|
||||
|
||||
# ---------- Args ----------
|
||||
def parse_args():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model_name_or_path", type=str, required=True)
|
||||
ap.add_argument("--data_glob", type=str, required=True)
|
||||
ap.add_argument("--output_dir", type=str, required=True)
|
||||
ap.add_argument("--seq_len", type=int, default=4096)
|
||||
ap.add_argument("--learning_rate", type=float, default=1e-4) # LoRA 通常可更大学习率
|
||||
ap.add_argument("--weight_decay", type=float, default=0.0) # LoRA 常设 0 或很小
|
||||
ap.add_argument("--warmup_ratio", type=float, default=0.03)
|
||||
ap.add_argument("--num_train_epochs", type=float, default=1.0)
|
||||
ap.add_argument("--max_steps", type=int, default=-1)
|
||||
ap.add_argument("--log_interval", type=int, default=10)
|
||||
ap.add_argument("--save_steps", type=int, default=500)
|
||||
ap.add_argument("--eval_ratio", type=float, default=0.0)
|
||||
ap.add_argument("--seed", type=int, default=1337)
|
||||
ap.add_argument("--gradient_checkpointing", action="store_true")
|
||||
ap.add_argument("--bf16", action="store_true")
|
||||
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
|
||||
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
|
||||
ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"])
|
||||
ap.add_argument("--wandb_project", type=str, default="ds-qwen3-lora")
|
||||
ap.add_argument("--eval_data_glob", type=str, default=None)
|
||||
ap.add_argument("--local_rank", type=int, default=-1)
|
||||
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
|
||||
ap.add_argument("--deepspeed", type=str, default=None)
|
||||
|
||||
# ---- LoRA specific ----
|
||||
ap.add_argument("--lora_r", type=int, default=16)
|
||||
ap.add_argument("--lora_alpha", type=float, default=32)
|
||||
ap.add_argument("--lora_dropout", type=float, default=0.05)
|
||||
ap.add_argument("--lora_target", type=str, default="auto",
|
||||
help='逗号分隔,如 "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj";或 "auto"')
|
||||
|
||||
ap.add_argument("--qlora", action="store_true", help="使用 4bit (NF4) QLoRA(多机 DS 不建议)")
|
||||
ap.add_argument("--merge_lora_and_save", action="store_true",
|
||||
help="训练后在 rank0 合并 LoRA 到基座并另存(注意显存/内存占用)")
|
||||
return ap.parse_args()
|
||||
|
||||
# ---------- LoRA helpers ----------
|
||||
def _auto_lora_targets(model) -> List[str]:
|
||||
"""
|
||||
针对 Qwen/Llama 族,自动挑选常见的线性层名字;仅匹配存在的模块。
|
||||
"""
|
||||
cand = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj",
|
||||
"w1","w2","w3", "W_pack", "o_attn", "o_proj"] # 覆盖不同实现命名
|
||||
present = set()
|
||||
for name, module in model.named_modules():
|
||||
if any(name.endswith(f".{c}") or name == c for c in cand):
|
||||
present.add(name.split(".")[-1])
|
||||
# 回落:若一个都没匹配到,使用“所有 nn.Linear”
|
||||
if not present:
|
||||
return ["all-linear"]
|
||||
# 去重且保序
|
||||
order = []
|
||||
for c in cand:
|
||||
if c in present: order.append(c)
|
||||
return order
|
||||
|
||||
# ---------- main ----------
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if os.environ.get("RANK","0") != "0" and args.report_to == "wandb":
|
||||
print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
|
||||
args.report_to = "none"
|
||||
|
||||
set_seed(args.seed)
|
||||
|
||||
# DeepSpeed enable?
|
||||
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
|
||||
dschf = None
|
||||
if use_ds:
|
||||
try:
|
||||
from transformers.integrations.deepspeed import HfDeepSpeedConfig
|
||||
src = "transformers.integrations.deepspeed"
|
||||
except Exception:
|
||||
from transformers import HfDeepSpeedConfig
|
||||
src = "transformers"
|
||||
dschf = HfDeepSpeedConfig(args.deepspeed)
|
||||
print(f"[dbg] HfDeepSpeedConfig loaded from {src}", flush=True)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
||||
|
||||
import transformers as hf
|
||||
try:
|
||||
import deepspeed as ds
|
||||
ds_ver = ds.__version__
|
||||
except Exception:
|
||||
ds_ver = "n/a"
|
||||
|
||||
def dbg(msg):
|
||||
print(f"[dbg][host={socket.gethostname()} RANK={os.environ.get('RANK','0')} "
|
||||
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}", flush=True)
|
||||
|
||||
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
|
||||
dbg(f"args={args}")
|
||||
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
|
||||
os.environ.get("WORLD_SIZE"), os.environ.get("RANK"),
|
||||
os.environ.get("LOCAL_RANK", str(args.local_rank)),
|
||||
os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"),
|
||||
os.environ.get("CUDA_VISIBLE_DEVICES"),
|
||||
))
|
||||
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
|
||||
|
||||
# init dist
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
|
||||
if torch.cuda.is_available() and local_rank >= 0:
|
||||
torch.cuda.set_device(local_rank)
|
||||
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
|
||||
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
|
||||
if world_size > 1 and dist.is_available() and not dist.is_initialized():
|
||||
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
||||
dbg(f"init_process_group backend={backend} via env://")
|
||||
dist.init_process_group(backend=backend, init_method="env://")
|
||||
|
||||
# tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
try:
|
||||
if getattr(tokenizer, "padding_side", None) != "left":
|
||||
tokenizer.padding_side = "left"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
|
||||
raise RuntimeError("需要 Fast tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14。")
|
||||
tokenizer.model_max_length = args.seq_len
|
||||
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} model_max_length={tokenizer.model_max_length}")
|
||||
|
||||
# dtype
|
||||
def _bf16_supported():
|
||||
if not torch.cuda.is_available(): return False
|
||||
if hasattr(torch.cuda, "is_bf16_supported"):
|
||||
return torch.cuda.is_bf16_supported()
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
return (major, minor) >= (8, 0)
|
||||
use_bf16 = bool(args.bf16 and _bf16_supported())
|
||||
compute_dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)
|
||||
|
||||
# -------- load base model (with/without 4bit) --------
|
||||
quantization_config = None
|
||||
if args.qlora:
|
||||
try:
|
||||
from transformers import BitsAndBytesConfig
|
||||
from peft import prepare_model_for_kbit_training
|
||||
except Exception as e:
|
||||
raise RuntimeError("使用 --qlora 需要安装 bitsandbytes>=0.41 与 peft。") from e
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype
|
||||
)
|
||||
# 4bit 下不要传 attn_implementation="sdpa" 给部分旧版 torch
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
torch_dtype=compute_dtype,
|
||||
trust_remote_code=True,
|
||||
low_cpu_mem_usage=True,
|
||||
quantization_config=quantization_config,
|
||||
device_map=None # 用 DeepSpeed/Trainer 接管
|
||||
)
|
||||
if args.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
torch_dtype=compute_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
if args.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
model.config.use_cache = False
|
||||
|
||||
# -------- wrap with LoRA --------
|
||||
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
|
||||
if args.lora_target.strip().lower() == "auto":
|
||||
targets = _auto_lora_targets(model)
|
||||
else:
|
||||
targets = [x.strip() for x in args.lora_target.split(",") if x.strip()]
|
||||
if not targets:
|
||||
targets = _auto_lora_targets(model)
|
||||
|
||||
lora_cfg = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
target_modules=targets,
|
||||
bias="none",
|
||||
inference_mode=False
|
||||
)
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
|
||||
# 冻结确认
|
||||
if is_main_process():
|
||||
try:
|
||||
model.print_trainable_parameters()
|
||||
except Exception:
|
||||
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
total = sum(p.numel() for p in model.parameters())
|
||||
print(f"[LoRA] trainable={trainable:,} / total={total:,} ({trainable/total:.2%})", flush=True)
|
||||
|
||||
# -------- data streams --------
|
||||
files = sorted(glob.glob(args.data_glob))
|
||||
if len(files) == 0:
|
||||
raise FileNotFoundError(f"No files matched DATA_GLOB={args.data_glob}")
|
||||
if is_main_process():
|
||||
print(f"[data] matched {len(files)} files, example[0]={files[0]}", flush=True)
|
||||
|
||||
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
||||
def ex_iter_probe():
|
||||
for ex in ds_stream_probe: yield ex
|
||||
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
|
||||
try:
|
||||
_ = next(iter(train_stream_probe))
|
||||
except StopIteration:
|
||||
raise RuntimeError("[data] 样本结构不合法或全部被裁切。")
|
||||
|
||||
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
||||
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
|
||||
|
||||
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
|
||||
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
|
||||
|
||||
def has_at_least(stream, n: int):
|
||||
it = iter(stream)
|
||||
for _ in range(n):
|
||||
try: next(it)
|
||||
except StopIteration: return 0
|
||||
return 1
|
||||
|
||||
need = max(1, args.gradient_accumulation_steps)
|
||||
local_ok = has_at_least(probe_stream, need)
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank>=0 else "cpu"))
|
||||
dist.all_reduce(t, op=dist.ReduceOp.MIN)
|
||||
if t.item() == 0:
|
||||
if is_main_process():
|
||||
print(f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批。", flush=True)
|
||||
dist.barrier(); sys.exit(2)
|
||||
else:
|
||||
if local_ok == 0:
|
||||
print(f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批。", flush=True)
|
||||
sys.exit(2)
|
||||
|
||||
# eval
|
||||
eval_dataset: Optional[Dataset] = None
|
||||
class ListDataset(Dataset):
|
||||
def __init__(self, items): self.items = items
|
||||
def __len__(self): return len(self.items)
|
||||
def __getitem__(self, idx): return self.items[idx]
|
||||
|
||||
if args.eval_data_glob:
|
||||
eval_files = sorted(glob.glob(args.eval_data_glob))
|
||||
if len(eval_files) == 0:
|
||||
raise FileNotFoundError(f"No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
|
||||
if is_main_process():
|
||||
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
|
||||
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
|
||||
def ex_iter_eval():
|
||||
for ex in ds_eval_stream: yield ex
|
||||
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
|
||||
eval_items: List[Dict[str, torch.Tensor]] = [s for s in eval_iterable]
|
||||
if len(eval_items) == 0:
|
||||
raise RuntimeError("[eval] 读到了 0 条有效样本。")
|
||||
eval_dataset = ListDataset(eval_items)
|
||||
# pad to global batch size
|
||||
ws = max(int(os.environ.get("WORLD_SIZE","1")), 1)
|
||||
be = max(1, args.per_device_eval_batch_size)
|
||||
global_bs = ws * be
|
||||
r = len(eval_dataset) % global_bs
|
||||
if r != 0:
|
||||
pad_need = global_bs - r
|
||||
eval_dataset.items += eval_dataset.items[:pad_need]
|
||||
if is_main_process():
|
||||
print(f"[eval] padded eval set to {len(eval_dataset)}", flush=True)
|
||||
|
||||
# collator
|
||||
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
|
||||
|
||||
# training args
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
logging_dir = os.path.join(args.output_dir, "logs"); os.makedirs(logging_dir, exist_ok=True)
|
||||
|
||||
ta_kwargs = {}
|
||||
sig = inspect.signature(TrainingArguments.__init__).parameters
|
||||
if eval_dataset is not None:
|
||||
if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "steps"
|
||||
elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "steps"
|
||||
else:
|
||||
if "eval_strategy" in sig: ta_kwargs["eval_strategy"] = "no"
|
||||
elif "evaluation_strategy" in sig: ta_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
ta_kwargs2 = dict(
|
||||
output_dir=args.output_dir,
|
||||
logging_dir=logging_dir,
|
||||
do_train=True,
|
||||
do_eval=(eval_dataset is not None),
|
||||
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
|
||||
max_steps=args.max_steps if args.max_steps > 0 else -1,
|
||||
lr_scheduler_type="cosine",
|
||||
logging_steps=args.log_interval,
|
||||
save_steps=args.save_steps,
|
||||
save_total_limit=2,
|
||||
deepspeed=(args.deepspeed if use_ds else None),
|
||||
dataloader_drop_last=False,
|
||||
dataloader_num_workers=0,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
report_to=([] if args.report_to == "none" else [args.report_to]),
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
remove_unused_columns=False,
|
||||
save_on_each_node=True,
|
||||
logging_first_step=True,
|
||||
**ta_kwargs,
|
||||
)
|
||||
# 精度:QLoRA/LoRA 均按 compute_dtype 设置
|
||||
if "dataloader_pin_memory" in sig: ta_kwargs2["dataloader_pin_memory"] = False
|
||||
if "torch_compile" in sig: ta_kwargs2["torch_compile"] = False
|
||||
ta_kwargs2.update({
|
||||
"bf16": (compute_dtype==torch.bfloat16),
|
||||
"fp16": (compute_dtype==torch.float16),
|
||||
})
|
||||
training_args = TrainingArguments(**ta_kwargs2)
|
||||
|
||||
# pass tokenizer / processing_class
|
||||
trainer_kwargs = {}
|
||||
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
|
||||
trainer_kwargs["processing_class"] = tokenizer
|
||||
else:
|
||||
trainer_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
trainer = DebugTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_stream,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=data_collator,
|
||||
**trainer_kwargs
|
||||
)
|
||||
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
|
||||
|
||||
# resume (per-node local checkpoint agreement)
|
||||
def last_step(path: str) -> int:
|
||||
ck = get_last_checkpoint(path)
|
||||
if ck is None: return -1
|
||||
base = os.path.basename(ck)
|
||||
try: return int(base.split("-")[-1])
|
||||
except Exception: return -1
|
||||
|
||||
local_last = last_step(args.output_dir)
|
||||
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank>=0) else "cpu")
|
||||
resume_flag = None
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
|
||||
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
|
||||
if has_local.item() == 1:
|
||||
ts = torch.tensor(local_last, device=device)
|
||||
world = dist.get_world_size()
|
||||
buf = [torch.zeros_like(ts) for _ in range(world)]
|
||||
dist.all_gather(buf, ts)
|
||||
steps = [b.item() for b in buf]
|
||||
k = min(steps)
|
||||
if k >= 0:
|
||||
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
|
||||
if is_main_process():
|
||||
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
|
||||
else:
|
||||
if local_last >= 0:
|
||||
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
|
||||
|
||||
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
|
||||
dist.all_reduce(present, op=dist.ReduceOp.MIN)
|
||||
if present.item() == 0:
|
||||
if is_main_process():
|
||||
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
|
||||
resume_flag = None
|
||||
dist.barrier()
|
||||
else:
|
||||
if resume_flag is not None and not os.path.isdir(resume_flag):
|
||||
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
|
||||
resume_flag = None
|
||||
|
||||
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
|
||||
print_once("***** Starting LoRA training *****")
|
||||
print(f"[dbg] allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
|
||||
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB", flush=True)
|
||||
|
||||
train_result = trainer.train(resume_from_checkpoint=resume_flag)
|
||||
|
||||
# save adapter (not the full base)
|
||||
trainer.save_model() # 对 PeftModel:只保存 adapter 权重到 output_dir
|
||||
metrics = train_result.metrics
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# eval
|
||||
if eval_dataset is not None:
|
||||
print_once("***** Running eval *****")
|
||||
eval_metrics = trainer.evaluate()
|
||||
trainer.log_metrics("eval", eval_metrics)
|
||||
trainer.save_metrics("eval", eval_metrics)
|
||||
|
||||
# optional merge
|
||||
if args.merge_lora_and_save and is_main_process():
|
||||
print("[merge] Merging LoRA into base model ...", flush=True)
|
||||
try:
|
||||
if isinstance(trainer.model, PeftModel):
|
||||
merged = trainer.model.merge_and_unload()
|
||||
else:
|
||||
merged = trainer.model
|
||||
merge_dir = os.path.join(args.output_dir, "merged-full-model")
|
||||
os.makedirs(merge_dir, exist_ok=True)
|
||||
merged.save_pretrained(merge_dir, safe_serialization=True)
|
||||
tokenizer.save_pretrained(merge_dir)
|
||||
print(f"[merge] Saved merged model to: {merge_dir}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[merge] FAILED: {e}", flush=True)
|
||||
|
||||
print_once("Done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue