This commit is contained in:
hailin 2025-08-26 08:19:40 +08:00
parent a273c4e2d3
commit c02f76c3ba
2 changed files with 523 additions and 45 deletions

View File

@ -5,9 +5,11 @@ import glob
import socket import socket
import argparse import argparse
import inspect import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset from datasets import load_dataset
@ -91,7 +93,6 @@ class QwenChatSFTDataset(IterableDataset):
for ex in self.ex_iter: for ex in self.ex_iter:
msgs = ex.get("messages", None) msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list): if not msgs or not isinstance(msgs, list):
# 严格要求 messages 格式;发现旧的 "text" 数据直接跳过
continue continue
# 可选:过滤掉带有非空 <think>…</think> 的样本(避免训练真实 COT # 可选:过滤掉带有非空 <think>…</think> 的样本(避免训练真实 COT
@ -108,11 +109,11 @@ class QwenChatSFTDataset(IterableDataset):
tools = ex.get("tools", None) tools = ex.get("tools", None)
# 1) 按模型自带模板渲染(不要手写) # 1) 按模型自带模板渲染
rendered: str = self.tok.apply_chat_template( rendered: str = self.tok.apply_chat_template(
msgs, msgs,
tools=tools, tools=tools,
add_generation_prompt=False, # 训练包含 assistant 答案 add_generation_prompt=False,
tokenize=False tokenize=False
) )
if not isinstance(rendered, str) or not rendered.strip(): if not isinstance(rendered, str) or not rendered.strip():
@ -132,6 +133,10 @@ class QwenChatSFTDataset(IterableDataset):
input_ids: List[int] = enc["input_ids"] input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"] offsets: List[Tuple[int, int]] = enc["offset_mapping"]
# 空样本防御:分词后长度为 0
if not input_ids:
continue
# 4) 仅 assistant 计损失 # 4) 仅 assistant 计损失
labels = [-100] * len(input_ids) labels = [-100] * len(input_ids)
@ -150,6 +155,10 @@ class QwenChatSFTDataset(IterableDataset):
input_ids = input_ids[-self.seq_len:] input_ids = input_ids[-self.seq_len:]
labels = labels[-self.seq_len:] labels = labels[-self.seq_len:]
# 若没有任何可训练 tokenlabels 全 -100也跳过
if all(v == -100 for v in labels):
continue
yield { yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long), "input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.ones(len(input_ids), dtype=torch.long), "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
@ -206,7 +215,7 @@ def parse_args():
ap.add_argument("--max_steps", type=int, default=-1) ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10) ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500) ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0) # 兜底抽样评估 ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337) ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json") ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true") ap.add_argument("--gradient_checkpointing", action="store_true")
@ -229,13 +238,20 @@ def main():
args = parse_args() args = parse_args()
set_seed(args.seed) set_seed(args.seed)
# ---- 初始化分布式(供一致性探针使用)----
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
if world_size > 1 and dist.is_available() and not dist.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method="env://")
if torch.cuda.is_available() and local_rank >= 0:
torch.cuda.set_device(local_rank)
# 1) 先补 tokenizer 的 pad # 1) 先补 tokenizer 的 pad
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None: if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token # 供 padding 使用 tokenizer.pad_token = tokenizer.eos_token
# 可选:让警告更少
tokenizer.model_max_length = args.seq_len tokenizer.model_max_length = args.seq_len
# 2) 再加载模型 # 2) 再加载模型
@ -246,23 +262,20 @@ def main():
trust_remote_code=True trust_remote_code=True
) )
# 3) 最后对齐模型的 pad_token_id # 3) pad/alibi 等配置
model.config.pad_token_id = tokenizer.pad_token_id model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # 训练时禁用 cache model.config.use_cache = False
if args.gradient_checkpointing: if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try: try:
torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True) # 走 math 实现 torch.backends.cuda.enable_math_sdp(True)
except Exception: except Exception:
pass pass
# ===== 数据鲁棒性检查(多机各自执行)===== # ===== 数据鲁棒性检查(多机各自执行)=====
host = socket.gethostname() host = socket.gethostname()
rank = int(os.environ.get("RANK", "0"))
files = sorted(glob.glob(args.data_glob)) files = sorted(glob.glob(args.data_glob))
if len(files) == 0: if len(files) == 0:
@ -274,23 +287,14 @@ def main():
if is_main_process(): if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True) print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
# streaming 逐行读取messages/tools 结构) # ====== 小探针:样本结构 ======
ds_stream = load_dataset( ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
"json", def ex_iter_probe():
data_files={"train": files}, for ex in ds_stream_probe:
split="train",
streaming=True
)
def ex_iter():
for ex in ds_stream:
yield ex yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
train_stream_probe = QwenChatSFTDataset(ex_iter(), tokenizer, seq_len=args.seq_len)
# 探针:确保能产出至少一个样本
_probe_it = iter(train_stream_probe)
try: try:
_ = next(_probe_it) _ = next(iter(train_stream_probe))
except StopIteration: except StopIteration:
raise RuntimeError( raise RuntimeError(
f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n" f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
@ -299,22 +303,42 @@ def main():
"另外检查 seq_len 是否过小导致全部被裁。" "另外检查 seq_len 是否过小导致全部被裁。"
) )
# 探针已消耗流;为正式训练重建一次 # ====== 正式训练流 + 模数分片(不要求样本数整除 world_size ======
ds_stream2 = load_dataset( ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
"json",
data_files={"train": files},
split="train",
streaming=True
)
# 多机/多卡分片(让每个全局 rank 读不同子流)
# world_size = int(os.environ.get("WORLD_SIZE", "1"))
# ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
def ex_iter2(): def ex_iter2():
for ex in ds_stream2: for i, ex in enumerate(ds_stream2):
yield ex if i % max(world_size, 1) == rank:
yield ex
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len) train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
# ====== 一致性探针:任意 rank 无样本 -> 全体退出 ======
def has_one_sample(stream):
it = iter(stream)
try:
next(it); return 1
except StopIteration:
return 0
ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter2_probe():
for i, ex in enumerate(ds_stream_probe2):
if i % max(world_size, 1) == rank:
yield ex
probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
local_ok = has_one_sample(probe_stream)
if dist.is_available() and dist.is_initialized():
t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
dist.all_reduce(t, op=dist.ReduceOp.MIN)
if t.item() == 0:
if is_main_process():
print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
dist.barrier()
sys.exit(2)
else:
if local_ok == 0:
print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
# ---- Eval 构造:优先使用 --eval_data_glob否则才用 eval_ratio 抽样 ---- # ---- Eval 构造:优先使用 --eval_data_glob否则才用 eval_ratio 抽样 ----
eval_dataset: Optional[Dataset] = None eval_dataset: Optional[Dataset] = None
@ -346,7 +370,6 @@ def main():
eval_dataset = ListDataset(eval_items) eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0: elif args.eval_ratio and args.eval_ratio > 0:
# 简易头部抽样(流式下仅作粗评)
desired_eval_batches = 200 desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True) tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2(): def ex_iter_eval2():
@ -363,7 +386,8 @@ def main():
if len(eval_samples) > 0: if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples) eval_dataset = ListDataset(eval_samples)
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len) # 更稳:联调阶段不强行 pad 到 4096
data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs") logging_dir = os.path.join(args.output_dir, "logs")
@ -401,7 +425,7 @@ def main():
save_steps=args.save_steps, save_steps=args.save_steps,
save_total_limit=2, save_total_limit=2,
deepspeed=args.deepspeed, deepspeed=args.deepspeed,
dataloader_drop_last=True, dataloader_drop_last=False, # 关键:别丢尾,避免空 batch
dataloader_num_workers=0, dataloader_num_workers=0,
dataloader_prefetch_factor=None, dataloader_prefetch_factor=None,
dataloader_pin_memory=False, dataloader_pin_memory=False,

454
train_sft_ds.py.old Normal file
View File

@ -0,0 +1,454 @@
#!/usr/bin/env python3
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import glob
import socket
import argparse
import inspect
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
set_seed
)
from transformers.trainer_callback import TrainerCallback
# ----------------- 进程工具 -----------------
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
# ----------------- 日志回调 -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_log(self, args, state, control, logs=None, **kwargs):
if not is_main_process() or logs is None:
return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
# ----------------- 仅监督 assistant 的数据集 -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
apply_chat_template 渲染后的文本中返回所有 assistant 内容的字符区间 [start, end)
"""
spans: List[Tuple[int, int]] = []
open_tag = "<|im_start|>assistant\n"
close_tag = "<|im_end|>\n"
pos = 0
while True:
a = rendered.find(open_tag, pos)
if a == -1:
break
start = a + len(open_tag)
b = rendered.find(close_tag, start)
if b == -1:
break
spans.append((start, b))
pos = b + len(close_tag)
return spans
class QwenChatSFTDataset(IterableDataset):
"""
期望 jsonl 每行形如
{"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
可选包含工具
{"messages":[...], "tools":[{...}]}
工作流
- 使用 tokenizer.apply_chat_template 渲染
- 仅对 assistant 片段计损失其他 token label = -100
- 超长序列保留尾部通常包含回答
"""
def __init__(self,
ex_iter: Iterable[dict],
tokenizer: AutoTokenizer,
seq_len: int = 4096):
self.ex_iter = ex_iter
self.tok = tokenizer
self.seq_len = seq_len
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list):
# 严格要求 messages 格式;发现旧的 "text" 数据直接跳过
continue
# 可选:过滤掉带有非空 <think>…</think> 的样本(避免训练真实 COT
bad = False
for m in msgs:
if m.get("role") == "assistant" and isinstance(m.get("content"), str):
c = m["content"]
if "<think>" in c and "</think>" in c:
inner = c.split("<think>")[-1].split("</think>")[0].strip()
if inner:
bad = True; break
if bad:
continue
tools = ex.get("tools", None)
# 1) 按模型自带模板渲染(不要手写)
rendered: str = self.tok.apply_chat_template(
msgs,
tools=tools,
add_generation_prompt=False, # 训练包含 assistant 答案
tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip():
continue
# 2) 找出 assistant 片段的字符区间
spans = _assistant_char_spans(rendered)
if not spans:
continue
# 3) 分词 + 字符/Token 对齐
enc = self.tok(
rendered,
add_special_tokens=False,
return_offsets_mapping=True
)
input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
# 4) 仅 assistant 计损失
labels = [-100] * len(input_ids)
def in_any_span(lo: int, hi: int) -> bool:
for s, e in spans:
if not (hi <= s or lo >= e):
return True
return False
for i, (lo, hi) in enumerate(offsets):
if in_any_span(lo, hi):
labels[i] = input_ids[i]
# 5) 超长裁剪(保留尾部)
if len(input_ids) > self.seq_len:
input_ids = input_ids[-self.seq_len:]
labels = labels[-self.seq_len:]
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.ones(len(input_ids), dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long)
}
# ----------------- 专用 Collatorpad inputs, pad labels=-100 -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
self.tok = tokenizer
self.pad_to_length = pad_to_length
assert self.tok.pad_token_id is not None
def __call__(self, features):
def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
input_ids = [_to_list(f["input_ids"]) for f in features]
attn_masks = [_to_list(f["attention_mask"]) for f in features]
labels_list = [_to_list(f["labels"]) for f in features]
max_len_in_batch = max(len(x) for x in input_ids)
target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
pad_id = self.tok.pad_token_id
batch_inp, batch_attn, batch_lab = [], [], []
for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
pad_len = target_len - len(inp)
if pad_len < 0:
inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
pad_len = 0
batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
return {
"input_ids": torch.stack(batch_inp, dim=0),
"attention_mask": torch.stack(batch_attn, dim=0),
"labels": torch.stack(batch_lab, dim=0),
}
# ----------------- 参数 -----------------
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True,
help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B")
ap.add_argument("--data_glob", type=str, required=True,
help="本地 jsonl 通配符(每台机器都需有同路径数据;每行应含 messages/可选 tools")
ap.add_argument("--output_dir", type=str, required=True,
help="本地输出目录(各节点各自本地写)")
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=2e-5)
ap.add_argument("--weight_decay", type=float, default=0.1)
ap.add_argument("--warmup_ratio", type=float, default=0.02)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0) # 兜底抽样评估
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16同时在 DS 配置里也要开")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
ap.add_argument("--local_rank", type=int, default=-1,
help="for deepspeed/torchrun launcher; ignored by user code")
return ap.parse_args()
# ----------------- 主函数 -----------------
def main():
args = parse_args()
set_seed(args.seed)
# 1) 先补 tokenizer 的 pad
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token # 供 padding 使用
# 可选:让警告更少
tokenizer.model_max_length = args.seq_len
# 2) 再加载模型
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
low_cpu_mem_usage=True,
trust_remote_code=True
)
# 3) 最后对齐模型的 pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # 训练时禁用 cache
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
try:
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True) # 走 math 实现
except Exception:
pass
# ===== 数据鲁棒性检查(多机各自执行)=====
host = socket.gethostname()
rank = int(os.environ.get("RANK", "0"))
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
raise FileNotFoundError(
f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
"每台机器都必须在相同本地路径下放置数据;"
"可通过 DATA_GLOB=<your_glob> ./run_ds.sh 覆写。"
)
if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
# streaming 逐行读取messages/tools 结构)
ds_stream = load_dataset(
"json",
data_files={"train": files},
split="train",
streaming=True
)
def ex_iter():
for ex in ds_stream:
yield ex
train_stream_probe = QwenChatSFTDataset(ex_iter(), tokenizer, seq_len=args.seq_len)
# 探针:确保能产出至少一个样本
_probe_it = iter(train_stream_probe)
try:
_ = next(_probe_it)
except StopIteration:
raise RuntimeError(
f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
"请确认每行 JSON 至少包含 'messages'(列表,含 user/assistant字段"
"若含 <think>…</think> 请确保不包含真实思维文本,或移除。\n"
"另外检查 seq_len 是否过小导致全部被裁。"
)
# 探针已消耗流;为正式训练重建一次
ds_stream2 = load_dataset(
"json",
data_files={"train": files},
split="train",
streaming=True
)
# 多机/多卡分片(让每个全局 rank 读不同子流)
world_size = int(os.environ.get("WORLD_SIZE", "1"))
ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
def ex_iter2():
for ex in ds_stream2:
yield ex
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
# ---- Eval 构造:优先使用 --eval_data_glob否则才用 eval_ratio 抽样 ----
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream:
yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = []
for sample in eval_iterable:
eval_items.append(sample)
if len(eval_items) == 0:
raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
# 简易头部抽样(流式下仅作粗评)
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2():
for ex in tmp_stream:
yield ex
eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
eval_samples = []
it = iter(eval_stream)
for _ in range(desired_eval_batches):
try:
eval_samples.append(next(it))
except StopIteration:
break
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
# ---- 兼容 4.51eval_strategy与旧版evaluation_strategy ----
ta_kwargs = {}
sig = inspect.signature(TrainingArguments.__init__).parameters
if eval_dataset is not None:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "steps"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "steps"
else:
if "eval_strategy" in sig:
ta_kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in sig:
ta_kwargs["evaluation_strategy"] = "no"
training_args = TrainingArguments(
output_dir=args.output_dir,
logging_dir=logging_dir,
do_train=True,
do_eval=(eval_dataset is not None),
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=args.deepspeed,
dataloader_drop_last=True,
dataloader_num_workers=0,
dataloader_prefetch_factor=None,
dataloader_pin_memory=False,
report_to=([] if args.report_to == "none" else [args.report_to]),
bf16=args.bf16,
fp16=(not args.bf16),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
torch_compile=False,
save_on_each_node=False,
logging_first_step=True,
**ta_kwargs,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
processing_class=tokenizer,
data_collator=data_collator
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
# 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-*
ckpt_exists = (os.path.isdir(args.output_dir)
and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
resume_flag = True if ckpt_exists else None
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
print_once("***** Starting training *****")
train_result = trainer.train(resume_from_checkpoint=resume_flag)
trainer.save_model() # DeepSpeed stage3_gather_16bit_weights_on_model_save=true 时,在 rank0 聚合整模型
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
print_once("Done.")
if __name__ == "__main__":
main()