#!/usr/bin/env python3
import os
import glob
import socket
import argparse
import inspect
from typing import Dict, List, Iterable, Iterator, Tuple, Optional

import torch
from torch.utils.data import IterableDataset, Dataset

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    set_seed
)
from transformers.trainer_callback import TrainerCallback


# ----------------- Process utilities -----------------
def is_main_process():
    return int(os.environ.get("RANK", "0")) == 0

def print_once(*args, **kwargs):
    if is_main_process():
        print(*args, **kwargs, flush=True)


# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            # Guard against a bare filename: os.path.dirname("") would make makedirs fail.
            os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not is_main_process() or logs is None:
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")


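# Example loss.csv contents (illustrative values only):
#   step,loss,lr,total_flos
#   10,2.4315,1.9e-05,1.2e+15
#   20,2.1087,1.8e-05,2.4e+15
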
# ----------------- Dataset that supervises only assistant tokens -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
    Return the character spans [start, end) of every assistant message in text
    rendered by apply_chat_template.
    """
    spans: List[Tuple[int, int]] = []
    open_tag = "<|im_start|>assistant\n"
    close_tag = "<|im_end|>\n"
    pos = 0
    while True:
        a = rendered.find(open_tag, pos)
        if a == -1:
            break
        start = a + len(open_tag)
        b = rendered.find(close_tag, start)
        if b == -1:
            break
        spans.append((start, b))
        pos = b + len(close_tag)
    return spans

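# Illustration (a sketch, assuming the ChatML markers used above):
#   rendered = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n"
#   _assistant_char_spans(rendered)  # -> [(52, 58)], the span covering "Hello!"
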
class QwenChatSFTDataset(IterableDataset):
    """
    Expects each jsonl line to look like:
      {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
    Optionally with tools:
      {"messages":[...], "tools":[{...}]}

    Pipeline:
    - render with tokenizer.apply_chat_template
    - compute loss only on assistant segments (label = -100 for all other tokens)
    - keep the tail of overlong sequences (it usually contains the answer)
    """
    def __init__(self,
                 ex_iter: Iterable[dict],
                 tokenizer: AutoTokenizer,
                 seq_len: int = 4096):
        self.ex_iter = ex_iter
        self.tok = tokenizer
        self.seq_len = seq_len

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        for ex in self.ex_iter:
            msgs = ex.get("messages", None)
            if not msgs or not isinstance(msgs, list):
                # Strictly require the messages format; skip legacy "text" records.
                continue

            # Optional: drop samples carrying a non-empty <think>…</think> block
            # (avoid training on real chain-of-thought text).
            bad = False
            for m in msgs:
                if m.get("role") == "assistant" and isinstance(m.get("content"), str):
                    c = m["content"]
                    if "<think>" in c and "</think>" in c:
                        inner = c.split("<think>")[-1].split("</think>")[0].strip()
                        if inner:
                            bad = True
                            break
            if bad:
                continue

            tools = ex.get("tools", None)

            # 1) Render with the model's own chat template (do not hand-roll it).
            rendered: str = self.tok.apply_chat_template(
                msgs,
                tools=tools,
                add_generation_prompt=False,  # training includes the assistant answer
                tokenize=False
            )
            if not isinstance(rendered, str) or not rendered.strip():
                continue

            # 2) Locate the character spans of the assistant segments.
            spans = _assistant_char_spans(rendered)
            if not spans:
                continue

            # 3) Tokenize and align characters to tokens via the offset mapping.
            enc = self.tok(
                rendered,
                add_special_tokens=False,
                return_offsets_mapping=True
            )
            input_ids: List[int] = enc["input_ids"]
            offsets: List[Tuple[int, int]] = enc["offset_mapping"]

            # 4) Compute loss only on assistant tokens.
            labels = [-100] * len(input_ids)

            def in_any_span(lo: int, hi: int) -> bool:
                # A token is supervised if its character range overlaps any assistant span.
                for s, e in spans:
                    if not (hi <= s or lo >= e):
                        return True
                return False

            for i, (lo, hi) in enumerate(offsets):
                if in_any_span(lo, hi):
                    labels[i] = input_ids[i]

            # 5) Truncate overlong sequences (keep the tail).
            if len(input_ids) > self.seq_len:
                input_ids = input_ids[-self.seq_len:]
                labels = labels[-self.seq_len:]

            yield {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long)
            }


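# Minimal smoke test (a sketch; the model path is illustrative):
#   tok = AutoTokenizer.from_pretrained("/home/test/Qwen3-8B", trust_remote_code=True)
#   ds = QwenChatSFTDataset(iter([{"messages": [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"}]}]), tok, seq_len=64)
#   sample = next(iter(ds))  # labels == -100 everywhere except the assistant reply tokens
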
# ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
class SFTDataCollator:
    def __init__(self, tokenizer: AutoTokenizer):
        self.tok = tokenizer
        assert self.tok.pad_token_id is not None, "tokenizer.pad_token must not be None; main() falls back to eos_token"

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Pad variable-length samples to the longest sequence in the batch;
        # labels are padded with -100 so padding never contributes to the loss.
        def _to_list(x):
            return x.tolist() if isinstance(x, torch.Tensor) else list(x)
        input_ids = [_to_list(f["input_ids"]) for f in features]
        attn_masks = [_to_list(f["attention_mask"]) for f in features]
        labels_list = [_to_list(f["labels"]) for f in features]

        max_len = max(len(x) for x in input_ids)
        pad_id = self.tok.pad_token_id

        batch_inp, batch_attn, batch_lab = [], [], []
        for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
            pad_len = max_len - len(inp)
            batch_inp.append(torch.tensor(inp + [pad_id] * pad_len, dtype=torch.long))
            batch_attn.append(torch.tensor(msk + [0] * pad_len, dtype=torch.long))
            batch_lab.append(torch.tensor(lab + [-100] * pad_len, dtype=torch.long))

        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
            "labels": torch.stack(batch_lab, dim=0),
        }


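# Shape sketch: two samples of lengths 3 and 5 are padded to 5, e.g.
#   input_ids[0]      -> [t0, t1, t2, pad_id, pad_id]
#   attention_mask[0] -> [1, 1, 1, 0, 0]
#   labels[0]         -> [l0, l1, l2, -100, -100]
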
# ----------------- Arguments -----------------
def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="Local weight directory or HF name (e.g. /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="Local jsonl glob (every machine needs the data at the same path; each line should contain messages and optionally tools)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="Local output directory (each node writes to its own local disk)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)  # fallback sampled evaluation
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
                    help="bf16 is available on 3090/A100/H100 etc.; enable it in the DS config as well")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard",
                    choices=["none", "tensorboard", "wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    ap.add_argument("--eval_data_glob", type=str, default=None,
                    help="(Optional) test/validation jsonl glob; takes precedence when provided")
    ap.add_argument("--local_rank", type=int, default=-1,
                    help="for deepspeed/torchrun launcher; ignored by user code")
    return ap.parse_args()


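# Illustrative launch (a sketch; the script name and launcher flags depend on your setup):
#   deepspeed --hostfile hostfile train_sft.py \
#       --model_name_or_path /home/test/Qwen3-8B \
#       --data_glob "/data/sft/*.jsonl" \
#       --output_dir ./out \
#       --bf16 --gradient_checkpointing \
#       --deepspeed ds_config_zero3.json
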
# ----------------- Main -----------------
def main():
    args = parse_args()
    set_seed(args.seed)

    # Tokenizer/Model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # used for padding

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    model.config.use_cache = False  # disable the KV cache during training
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # ===== Data robustness checks (run independently on every machine) =====
    host = socket.gethostname()
    rank = int(os.environ.get("RANK", "0"))

    files = sorted(glob.glob(args.data_glob))
    if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "Every machine must have the data at the same local path; "
            "override via DATA_GLOB=<your_glob> ./run_ds.sh."
        )
    if is_main_process():
        print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)

    # Streaming line-by-line reads (messages/tools structure)
    ds_stream = load_dataset(
        "json",
        data_files={"train": files},
        split="train",
        streaming=True
    )

    def ex_iter():
        for ex in ds_stream:
            yield ex

    train_stream_probe = QwenChatSFTDataset(ex_iter(), tokenizer, seq_len=args.seq_len)
    # Probe: make sure the stream yields at least one trainable sample.
    _probe_it = iter(train_stream_probe)
    try:
        _ = next(_probe_it)
    except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] Data files matched, but no trainable samples were produced.\n"
            "Check that every JSON line contains at least a 'messages' field (a list with user/assistant turns); "
            "if samples contain <think>…</think>, make sure they hold no real reasoning text, or remove it.\n"
            "Also check whether seq_len is so small that everything gets truncated away."
        )

    # The probe consumed the stream; rebuild it once for the actual training run.
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    def ex_iter2():
        for ex in ds_stream2:
            yield ex
    train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

    # ---- Eval construction: prefer --eval_data_glob; otherwise fall back to eval_ratio sampling ----
    eval_dataset: Optional[Dataset] = None

    class ListDataset(Dataset):
        def __init__(self, items): self.items = items
        def __len__(self): return len(self.items)
        def __getitem__(self, idx): return self.items[idx]

    if args.eval_data_glob:
        eval_files = sorted(glob.glob(args.eval_data_glob))
        if len(eval_files) == 0:
            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
        if is_main_process():
            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)

        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
        def ex_iter_eval():
            for ex in ds_eval_stream:
                yield ex

        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
        eval_items: List[Dict[str, torch.Tensor]] = []
        for sample in eval_iterable:
            eval_items.append(sample)

        if len(eval_items) == 0:
            raise RuntimeError("[eval] eval_data_glob yielded 0 valid samples; check the messages structure.")

        eval_dataset = ListDataset(eval_items)

    elif args.eval_ratio and args.eval_ratio > 0:
        # Simple head-of-stream sampling (only a rough eval under streaming).
        desired_eval_batches = 200
        tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
        def ex_iter_eval2():
            for ex in tmp_stream:
                yield ex
        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
        eval_samples = []
        it = iter(eval_stream)
        for _ in range(desired_eval_batches):
            try:
                eval_samples.append(next(it))
            except StopIteration:
                break
        if len(eval_samples) > 0:
            eval_dataset = ListDataset(eval_samples)

    data_collator = SFTDataCollator(tokenizer)

    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")
    os.makedirs(logging_dir, exist_ok=True)

    # ---- Compatibility shim: 4.51 uses eval_strategy; older releases use evaluation_strategy ----
    ta_kwargs = {}
    sig = inspect.signature(TrainingArguments.__init__).parameters
    if eval_dataset is not None:
        if "eval_strategy" in sig:
            ta_kwargs["eval_strategy"] = "steps"
        elif "evaluation_strategy" in sig:
            ta_kwargs["evaluation_strategy"] = "steps"
    else:
        if "eval_strategy" in sig:
            ta_kwargs["eval_strategy"] = "no"
        elif "evaluation_strategy" in sig:
            ta_kwargs["evaluation_strategy"] = "no"

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        do_train=True,
        do_eval=(eval_dataset is not None),
        eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        deepspeed=args.deepspeed,
        dataloader_drop_last=True,
        dataloader_num_workers=0,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        bf16=args.bf16,
        fp16=(not args.bf16),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        torch_compile=False,
        save_on_each_node=False,
        logging_first_step=True,
        **ta_kwargs,
    )

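    # Effective global batch size = per_device_train_batch_size
    #   * gradient_accumulation_steps * world_size; e.g. with the defaults
    #   (1 and 64) on 16 GPUs: 1 * 64 * 16 = 1024 sequences per optimizer step.
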
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # No shared disk: check each node's local output_dir for an existing checkpoint-*.
    ckpt_exists = (os.path.isdir(args.output_dir)
                   and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
    resume_flag = True if ckpt_exists else None

    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
    print_once("***** Starting training *****")
    train_result = trainer.train(resume_from_checkpoint=resume_flag)
    trainer.save_model()  # With DeepSpeed stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank0

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if eval_dataset is not None:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    print_once("Done.")


if __name__ == "__main__":
    main()