#!/usr/bin/env python3
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
# ----------------- Process utilities -----------------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class DebugTrainer(Trainer):
def training_step(self, model, inputs, num_items_in_batch=None):
if not hasattr(self, "_dbg_printed"):
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
ids = inputs["input_ids"]
msk = inputs["attention_mask"]
labs = inputs["labels"]
print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
f"supervised={(labs!=-100).sum().item()}",
flush=True)
print(
f"[step0][host={host} RANK={rank}] "
f"input_ids.shape={tuple(ids.shape)} "
f"attention_mask.shape={tuple(msk.shape)} "
f"labels.shape={tuple(labs.shape)} "
f"num_items_in_batch={num_items_in_batch}",
flush=True
)
self._dbg_printed = True
return super().training_step(model, inputs, num_items_in_batch)
# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_train_begin(self, args, state, control, **kwargs):
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
        # ---- Console output: every rank prints current step / total steps ----
cur = int(getattr(state, "global_step", 0) or 0)
# if getattr(args, "logging_steps", None) and cur % args.logging_steps != 0:
# return
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
        # Once tot becomes available, announce the total step count once more (printed only once)
if tot and not hasattr(self, "_tot_announced"):
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
self._tot_announced = True
# if not is_main_process():
# return
rank = os.environ.get("RANK", "?")
host = socket.gethostname()
print(
f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
flush=True
)
        # ---- Write the CSV only on the main process to avoid concurrent writes ----
if not is_main_process():
return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
# ----------------- Dataset that supervises only assistant tokens -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
在 apply_chat_template 渲染后的文本中,返回所有 assistant 内容的字符区间 [start, end)。
"""
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n"
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1:
break
start = a + len(open_tag)
b = rendered.find(close_tag, start)
if b == -1:
break
spans.append((start, b))
pos = b + len(close_tag)
return spans
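# Example (assuming the standard Qwen chat-template markers used above): for
#   rendered = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n"
# the function returns a single span covering exactly the characters of "Hello!".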
class QwenChatSFTDataset(IterableDataset):
"""
期望 jsonl 每行形如:
{"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
可选包含工具:
{"messages":[...], "tools":[{...}]}
工作流:
- 使用 tokenizer.apply_chat_template 渲染
- 仅对 assistant 片段计损失(其他 token 的 label = -100)
- 超长序列保留尾部(通常包含回答)
"""
def __init__(self,
ex_iter: Iterable[dict],
tokenizer: AutoTokenizer,
seq_len: int = 4096):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
# >>> DEBUG BEGIN
dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
rank = int(os.environ.get("RANK", "0"))
lrank = int(os.environ.get("LOCAL_RANK", "-1"))
host = socket.gethostname()
# >>> DEBUG END
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list):
continue
            # Optionally filter out samples whose assistant turns carry real reasoning
            # text inside <think>...</think> tags (an empty tag pair is allowed)
            bad = False
            for m in msgs:
                if m.get("role") == "assistant" and isinstance(m.get("content"), str):
                    c = m["content"]
                    if "<think>" in c and "</think>" in c:
                        inner = c.split("<think>")[-1].split("</think>")[0].strip()
                        if inner:
                            bad = True; break
if bad:
continue
tools = ex.get("tools", None)
            # Fall back for older tokenizers whose apply_chat_template does not accept a tools argument
try:
rendered: str = self.tok.apply_chat_template(
msgs, tools=tools, add_generation_prompt=False, tokenize=False
)
except TypeError:
rendered: str = self.tok.apply_chat_template(
msgs, add_generation_prompt=False, tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip():
continue
spans = _assistant_char_spans(rendered)
if not spans:
continue
enc = self.tok(
rendered,
add_special_tokens=False,
return_offsets_mapping=True
)
input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
if not input_ids:
continue
labels = [-100] * len(input_ids)
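            # A token is supervised iff its character range overlaps some assistant span;
            # template markup and system/user text keep the ignore label -100.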
def in_any_span(lo: int, hi: int) -> bool:
for s, e in spans:
if not (hi <= s or lo >= e):
return True
return False
for i, (lo, hi) in enumerate(offsets):
if in_any_span(lo, hi):
labels[i] = input_ids[i]
            # Fixed-length policy: truncate first, then pad to a fixed seq_len at the Dataset level
            # 1) Truncate to seq_len (keep the tail)
if len(input_ids) > self.seq_len:
input_ids = input_ids[-self.seq_len:]
labels = labels[-self.seq_len:]
            # 2) Left-pad to seq_len (so every sample has the same length)
pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
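            # Left padding matches tokenizer.padding_side = "left" set in main() and keeps the
            # real tokens (including the tail-truncated assistant answer) contiguous at the end.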
L = len(input_ids)
if L < self.seq_len:
pad = self.seq_len - L
input_ids = ([pad_id] * pad) + input_ids
labels = ([-100] * pad) + labels
attn_mask = [0] * pad + [1] * L
else:
                # Exactly seq_len already
attn_mask = [1] * self.seq_len
            # Skip samples with no trainable tokens at all (labels all -100)
if all(v == -100 for v in labels):
continue
assert len(input_ids) == self.seq_len
assert len(labels) == self.seq_len
assert len(attn_mask) == self.seq_len
            # >>> DEBUG PRINT (all variables are defined by this point)
if dbg_on and self._dbg_seen < dbg_limit:
sup_tok = sum(1 for v in labels if v != -100)
print(
f"[sample][host={host} RANK={rank} LRank={lrank}] "
f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
f"seq_len={self.seq_len} pad_id={pad_id}",
flush=True
)
if sup_tok == 0:
print(
f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
flush=True
)
self._dbg_seen += 1
# <<< DEBUG PRINT
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attn_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
# ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None
def __call__(self, features):
if not features:
raise RuntimeError(
f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
f"Check dataset sharding/streaming."
)
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
# >>> DEBUG BEGIN
dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
if dbg_on:
rank = int(os.environ.get("RANK", "0"))
host = socket.gethostname()
bs = len(features)
first_len = len(input_ids[0]) if bs > 0 else None
print(
f"[collate][host={host} RANK={rank}] features={bs} "
f"target_len={target_len} first_len={first_len}",
flush=True
)
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
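# With pad_to_length=seq_len the collator returns three (batch_size, seq_len) LongTensors:
# input_ids, attention_mask, and labels, where ignored label positions are set to -100.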
# ----------------- Arguments -----------------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True,
help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B)")
ap.add_argument("--data_glob", type=str, required=True,
help="本地 jsonl 通配符(每台机器都需有同路径数据;每行应含 messages/可选 tools)")
ap.add_argument("--output_dir", type=str, required=True,
help="本地输出目录(各节点各自本地写)")
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=2e-5)
ap.add_argument("--weight_decay", type=float, default=0.1)
ap.add_argument("--warmup_ratio", type=float, default=0.02)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
ap.add_argument("--local_rank", type=int, default=-1,
help="for deepspeed/torchrun launcher; ignored by user code")
ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
ap.add_argument("--deepspeed", type=str, default=None)
return ap.parse_args()
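# A minimal multi-node launch sketch (script name, paths and node counts are placeholders;
# assumes a torchrun-style launcher that sets RANK/LOCAL_RANK/WORLD_SIZE for each process):
#
#   torchrun --nnodes 2 --node_rank $NODE_RANK --nproc_per_node 8 \
#       --master_addr $MASTER_ADDR --master_port $MASTER_PORT \
#       train_sft.py \
#       --model_name_or_path /home/test/Qwen3-8B \
#       --data_glob "/data/sft/*.jsonl" \
#       --output_dir ./out \
#       --bf16 --gradient_checkpointing \
#       --deepspeed ds_config_zero3.json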
# ----------------- Main -----------------
def main():
args = parse_args()
set_seed(args.seed)
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
    # -------- Debug print helper (every rank prints) --------
host = socket.gethostname()
def dbg(msg):
print(
f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
flush=True
)
    # Versions, launch args, and key environment variables
import transformers as hf
try:
import deepspeed as ds
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
os.environ.get("WORLD_SIZE"),
os.environ.get("RANK"),
os.environ.get("LOCAL_RANK", str(args.local_rank)),
os.environ.get("MASTER_ADDR"),
os.environ.get("MASTER_PORT"),
os.environ.get("CUDA_VISIBLE_DEVICES"),
))
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
    # ---- Initialize distributed (used by the consistency probes) ----
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
dbg("no cuda or invalid local_rank; not calling set_device")
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dbg(f"init_process_group backend={backend} via env://")
dist.init_process_group(backend=backend, init_method="env://")
else:
dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
if dist.is_available() and dist.is_initialized():
try:
dbg(f"dist.get_backend()={dist.get_backend()} "
f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
except Exception as e:
dbg(f"dist query error: {e}")
    # 1) Ensure the tokenizer has a pad token first
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Left padding to match the Dataset's left-pad policy
try:
if getattr(tokenizer, "padding_side", None) != "left":
tokenizer.padding_side = "left"
except Exception:
pass
    # Require a fast tokenizer (offset_mapping depends on it)
from transformers import PreTrainedTokenizerFast
if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
    # 2) Decide the dtype before loading the model
def _bf16_supported():
if not torch.cuda.is_available():
return False
        # Compatible across torch versions: prefer the API, fall back to compute capability
if hasattr(torch.cuda, "is_bf16_supported"):
return torch.cuda.is_bf16_supported()
major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 0)  # Ampere and newer
use_bf16 = bool(args.bf16 and _bf16_supported())
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
attn_implementation="flash_attention_2"
)
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
f"use_cache={getattr(model.config,'use_cache',None)} "
f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
    # 3) pad / alibi and related config
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try:
# torch.backends.cuda.enable_flash_sdp(False)
# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_math_sdp(True)
        # Let PyTorch choose, or explicitly enable the efficient implementations (either is fine):
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
except Exception:
pass
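    # Note: these toggles only affect torch's built-in SDPA backends; since the model was
    # loaded with attn_implementation="flash_attention_2", attention goes through the
    # flash-attn package and is not controlled by them.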
    # ===== Data robustness checks (run independently on each node) =====
host = socket.gethostname()
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
raise FileNotFoundError(
f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
"每台机器都必须在相同本地路径下放置数据;"
"可通过 DATA_GLOB= ./run_ds.sh 覆写。"
)
if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
    # ====== Small probe: check sample structure ======
ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_probe():
for ex in ds_stream_probe:
yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
_ = next(iter(train_stream_probe))
except StopIteration:
raise RuntimeError(
f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
"请确认每行 JSON 至少包含 'messages'(列表,含 user/assistant)字段;"
"若含 … 请确保不包含真实思维文本,或移除。\n"
"另外检查 seq_len 是否过小导致全部被裁。"
)
    # # ====== Training stream (manual sharding) ======
# ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
# if world_size > 1 and len(files) >= world_size:
    # #     # Multiple files: shard contiguously by file
# ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
# train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
# else:
    # #     # Single file or too few files: round-robin samples by index modulo world_size
# def ex_iter2():
# for i, ex in enumerate(ds_stream2):
# if i % max(world_size, 1) == rank:
# yield ex
# train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
    # ====== Training stream (no manual sharding; leave splitting to Accelerate/Trainer) ======
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
    # # ====== Consistency probe (same logic as above) ======
# ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
# if world_size > 1 and len(files) >= world_size:
# ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
# probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
# else:
# def ex_iter2_probe():
# for i, ex in enumerate(ds_stream_probe2):
# if i % max(world_size, 1) == rank:
# yield ex
# probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
    # ====== Consistency probe (no sharding) ======
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
def has_at_least(stream, n: int):
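        # Return 1 if the stream yields at least n samples, else 0; kept as an int
        # so the result can go through dist.all_reduce with ReduceOp.MIN below.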
it = iter(stream)
for _ in range(n):
try:
next(it)
except StopIteration:
return 0
return 1
need = max(1, args.gradient_accumulation_steps)
local_ok = has_at_least(probe_stream, need)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print(
f"[FATAL] 至少有一个 rank 在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
dist.barrier()
sys.exit(2)
else:
if local_ok == 0:
print(
f"[FATAL] 本机在一个优化 step 内供不上 {need} 个微批 (GA={need})。 "
f"请减少 GA 或扩大/清洗数据;本次训练不会启动。",
flush=True
)
sys.exit(2)
    # ---- Eval construction: prefer --eval_data_glob; otherwise sample via eval_ratio ----
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream:
yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = []
for sample in eval_iterable:
eval_items.append(sample)
if len(eval_items) == 0:
raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2():
for ex in tmp_stream:
yield ex
eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
eval_samples = []
it = iter(eval_stream)
for _ in range(desired_eval_batches):
try:
eval_samples.append(next(it))
except StopIteration:
break
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
    # ---- Pad the eval set uniformly (so no rank ever sees an empty batch) ----
if eval_dataset is not None:
ws = max(world_size, 1)
be = max(1, args.per_device_eval_batch_size)
global_bs = ws * be
r = len(eval_dataset) % global_bs
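        # Example: world_size=2 with per_device_eval_batch_size=1 gives global_bs=2;
        # 7 eval samples leave r=1, so 1 sample is duplicated to reach 8.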
if r != 0:
            pad_need = global_bs - r
            # Repeat samples cyclically; a single slice can fall short when pad_need > len(items)
            reps = (pad_need + len(eval_dataset.items) - 1) // len(eval_dataset.items)
            eval_dataset.items += (eval_dataset.items * reps)[:pad_need]
if is_main_process():
print(f"[eval] padded eval set to {len(eval_dataset)} "
f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
flush=True)
        # Sanity check after padding
assert len(eval_dataset) % global_bs == 0, \
f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
    # Pad every batch to the fixed seq_len (the dataset already emits fixed-length samples)
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
    # ---- Compatibility with transformers 4.51 (eval_strategy) and older releases (evaluation_strategy) ----
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no"
ta_sig = inspect.signature(TrainingArguments.__init__).parameters
ta_kwargs2 = dict(
output_dir=args.output_dir,
logging_dir=logging_dir,
do_train=True,
do_eval=(eval_dataset is not None),
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
dataloader_drop_last=False,
dataloader_num_workers=0,
per_device_eval_batch_size=args.per_device_eval_batch_size,
report_to=([] if args.report_to == "none" else [args.report_to]),
bf16=args.bf16,
fp16=(not args.bf16),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
save_on_each_node=True,
logging_first_step=True,
        **ta_kwargs,  # the eval_strategy compatibility entries built above
)
# if "dataloader_prefetch_factor" in ta_sig:
# ta_kwargs2["dataloader_prefetch_factor"] = None
if "dataloader_pin_memory" in ta_sig:
ta_kwargs2["dataloader_pin_memory"] = False
if "torch_compile" in ta_sig:
ta_kwargs2["torch_compile"] = False
    # Reuse the use_bf16 decision from above when constructing TrainingArguments
ta_kwargs2.update({
"bf16": use_bf16,
"fp16": (torch.cuda.is_available() and not use_bf16),
})
training_args = TrainingArguments(**ta_kwargs2)
trainer_kwargs = {}
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
trainer_kwargs["processing_class"] = tokenizer
else:
trainer_kwargs["tokenizer"] = tokenizer
trainer = DebugTrainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
#tokenizer=tokenizer,
#processing_class=tokenizer,
data_collator=data_collator,
**trainer_kwargs,
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
    # ==== Checkpoint-resume decision (safe for non-shared disks) ====
def last_step(path: str) -> int:
ck = get_last_checkpoint(path)
if ck is None:
return -1
base = os.path.basename(ck)
try:
return int(base.split("-")[-1])
except Exception:
return -1
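    # Relies on the Trainer's default "checkpoint-<global_step>" directory naming inside output_dir.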
    local_last = last_step(args.output_dir)  # -1 means this node has no checkpoint at all
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
resume_flag = None
if dist.is_available() and dist.is_initialized():
        # If any rank has no checkpoint -> do not resume
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
if has_local.item() == 1:
            # Every rank has one: gather each rank's last step and take the common minimum k (guaranteed to exist on every node)
ts = torch.tensor(local_last, device=device)
world = dist.get_world_size()
buf = [torch.zeros_like(ts) for _ in range(world)]
dist.all_gather(buf, ts)
steps = [b.item() for b in buf]
k = min(steps)
if k >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
if is_main_process():
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
else:
        # Single node or distributed not initialized: resume from this node's last step if one exists
if local_last >= 0:
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
    # Global consistency check: if any rank is missing this checkpoint, disable resume
if dist.is_available() and dist.is_initialized():
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
dist.all_reduce(present, op=dist.ReduceOp.MIN)
if present.item() == 0:
if is_main_process():
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
resume_flag = None
dist.barrier()
else:
if resume_flag is not None and not os.path.isdir(resume_flag):
            # Single node: if the checkpoint is missing, just disable resume
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
resume_flag = None
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
print_once("***** Starting training *****")
train_result = trainer.train(resume_from_checkpoint=resume_flag)
    trainer.save_model()  # with DeepSpeed stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank0
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
print_once("Done.")
if __name__ == "__main__":
main()