diff --git a/train_sft_ds.py b/train_sft_ds.py
index 52b6454..b395c07 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -5,9 +5,11 @@ import glob
import socket
import argparse
import inspect
+import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
+import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
@@ -91,7 +93,6 @@ class QwenChatSFTDataset(IterableDataset):
for ex in self.ex_iter:
msgs = ex.get("messages", None)
if not msgs or not isinstance(msgs, list):
- # Strictly require the messages format; skip legacy "text" rows
continue
# Optional: skip samples with a non-empty <think>...</think> block (avoid training on real CoT; tags per the Qwen3 chat template)
@@ -108,11 +109,11 @@ class QwenChatSFTDataset(IterableDataset):
tools = ex.get("tools", None)
- # 1) Render with the model's built-in chat template (don't hand-roll it)
+ # 1) Render with the model's built-in chat template
rendered: str = self.tok.apply_chat_template(
msgs,
tools=tools,
- add_generation_prompt=False, # training includes the assistant answer
+ add_generation_prompt=False,
tokenize=False
)
if not isinstance(rendered, str) or not rendered.strip():
@@ -132,6 +133,10 @@ class QwenChatSFTDataset(IterableDataset):
input_ids: List[int] = enc["input_ids"]
offsets: List[Tuple[int, int]] = enc["offset_mapping"]
+ # Guard against empty samples: zero length after tokenization
+ if not input_ids:
+ continue
+
# 4) Compute loss only on assistant tokens
labels = [-100] * len(input_ids)
@@ -150,6 +155,10 @@ class QwenChatSFTDataset(IterableDataset):
input_ids = input_ids[-self.seq_len:]
labels = labels[-self.seq_len:]
+ # Also skip if there is no trainable token (labels are all -100)
+ if all(v == -100 for v in labels):
+ continue
+
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.ones(len(input_ids), dtype=torch.long),
@@ -206,7 +215,7 @@ def parse_args():
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
- ap.add_argument("--eval_ratio", type=float, default=0.0) # 兜底抽样评估
+ ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true")
@@ -229,13 +238,20 @@ def main():
args = parse_args()
set_seed(args.seed)
+ # ---- Initialize distributed (used by the consistency probe) ----
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ rank = int(os.environ.get("RANK", "0"))
+ local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
+ if world_size > 1 and dist.is_available() and not dist.is_initialized():
+ backend = "nccl" if torch.cuda.is_available() else "gloo"
+ dist.init_process_group(backend=backend, init_method="env://")
+ if torch.cuda.is_available() and local_rank >= 0:
+ torch.cuda.set_device(local_rank)
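+ # Pin each process to its own GPU so that later "cuda" tensors (e.g. the
+ # all_reduce consistency probe below) land on the right device for NCCL.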
+
# 1) First fill in the tokenizer's pad token
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
-
if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token # used for padding
-
- # Optional: reduce warnings
+ tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = args.seq_len
# 2) Then load the model
@@ -246,23 +262,20 @@ def main():
trust_remote_code=True
)
- # 3) Finally align the model's pad_token_id
+ # 3) pad / ALiBi and similar config
model.config.pad_token_id = tokenizer.pad_token_id
- model.config.use_cache = False # disable cache during training
-
+ model.config.use_cache = False
if args.gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-
try:
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
- torch.backends.cuda.enable_math_sdp(True) # use the math implementation
+ torch.backends.cuda.enable_math_sdp(True)
except Exception:
pass
# ===== Data robustness checks (each node runs its own) =====
host = socket.gethostname()
- rank = int(os.environ.get("RANK", "0"))
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
@@ -274,23 +287,14 @@ def main():
if is_main_process():
print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
- # streaming line-by-line read (messages/tools structure)
- ds_stream = load_dataset(
- "json",
- data_files={"train": files},
- split="train",
- streaming=True
- )
-
- def ex_iter():
- for ex in ds_stream:
+ # ====== Small probe: sample structure ======
+ ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+ def ex_iter_probe():
+ for ex in ds_stream_probe:
yield ex
-
- train_stream_probe = QwenChatSFTDataset(ex_iter(), tokenizer, seq_len=args.seq_len)
- # Probe: make sure at least one sample can be produced
- _probe_it = iter(train_stream_probe)
+ train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
try:
- _ = next(_probe_it)
+ _ = next(iter(train_stream_probe))
except StopIteration:
raise RuntimeError(
f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
@@ -299,22 +303,42 @@ def main():
"另外检查 seq_len 是否过小导致全部被裁。"
)
- # The probe consumed the stream; rebuild it once for actual training
- ds_stream2 = load_dataset(
- "json",
- data_files={"train": files},
- split="train",
- streaming=True
- )
-
- # Multi-node/multi-GPU sharding (each global rank reads a different substream)
- # world_size = int(os.environ.get("WORLD_SIZE", "1"))
- # ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
+ # ====== Actual training stream + modulo sharding (sample count need not be divisible by world_size) ======
+ ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter2():
- for ex in ds_stream2:
- yield ex
+ for i, ex in enumerate(ds_stream2):
+ if i % max(world_size, 1) == rank:
+ yield ex
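+ # e.g. with world_size=4, rank 1 sees global samples 1, 5, 9, ...; this
+ # round-robin split feeds every rank even when there are fewer input files
+ # than ranks (a streaming .shard() is file-granular).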
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+ # ====== Consistency probe: any rank with no samples -> everyone exits ======
+ def has_one_sample(stream):
+ it = iter(stream)
+ try:
+ next(it); return 1
+ except StopIteration:
+ return 0
+
+ ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+ def ex_iter2_probe():
+ for i, ex in enumerate(ds_stream_probe2):
+ if i % max(world_size, 1) == rank:
+ yield ex
+ probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+ local_ok = has_one_sample(probe_stream)
+
+ if dist.is_available() and dist.is_initialized():
+ t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+ dist.all_reduce(t, op=dist.ReduceOp.MIN)
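+ # ReduceOp.MIN over {0,1} acts as a logical AND across ranks: the result
+ # stays 1 only if every rank produced at least one sample.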
+ if t.item() == 0:
+ if is_main_process():
+ print("[FATAL] 至少有一个 rank 没有任何样本。请减少 WORLD_SIZE 或修正分片;本次训练不会启动。", flush=True)
+ dist.barrier()
+ sys.exit(2)
+ else:
+ if local_ok == 0:
+ print("[FATAL] 本机无样本,退出。", flush=True); sys.exit(2)
+
# ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
eval_dataset: Optional[Dataset] = None
@@ -346,7 +370,6 @@ def main():
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
- # Simple head sampling (only a rough eval under streaming)
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval2():
@@ -363,7 +386,8 @@ def main():
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
- data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
+ # More robust: don't force-pad to 4096 during integration debugging
+ data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
@@ -401,7 +425,7 @@ def main():
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=args.deepspeed,
- dataloader_drop_last=True,
+ dataloader_drop_last=False, # key: don't drop the tail batch, avoid empty batches
dataloader_num_workers=0,
dataloader_prefetch_factor=None,
dataloader_pin_memory=False,
diff --git a/train_sft_ds.py.old b/train_sft_ds.py.old
new file mode 100644
index 0000000..a874b9d
--- /dev/null
+++ b/train_sft_ds.py.old
@@ -0,0 +1,454 @@
+#!/usr/bin/env python3
+import os
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+import glob
+import socket
+import argparse
+import inspect
+from typing import Dict, List, Iterable, Iterator, Tuple, Optional
+
+import torch
+from torch.utils.data import IterableDataset, Dataset
+
+from datasets import load_dataset
+from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ TrainingArguments,
+ Trainer,
+ set_seed
+)
+from transformers.trainer_callback import TrainerCallback
+
+
+# ----------------- Process utilities -----------------
+def is_main_process():
+ return int(os.environ.get("RANK", "0")) == 0
+
+def print_once(*args, **kwargs):
+ if is_main_process():
+ print(*args, **kwargs, flush=True)
+
+
+# ----------------- Logging callback -----------------
+class CsvLossLogger(TrainerCallback):
+ def __init__(self, csv_path: str):
+ self.csv_path = csv_path
+ if is_main_process():
+ os.makedirs(os.path.dirname(csv_path), exist_ok=True)
+ with open(self.csv_path, "w", encoding="utf-8") as f:
+ f.write("step,loss,lr,total_flos\n")
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if not is_main_process() or logs is None:
+ return
+ with open(self.csv_path, "a", encoding="utf-8") as f:
+ f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
+
+
+# ----------------- Dataset that supervises only assistant tokens -----------------
+def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
+ """
+    Return the character spans [start, end) of every assistant segment in the text rendered by apply_chat_template.
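+
+    Example: in "<|im_start|>assistant\nHi<|im_end|>\n" the single span covers
+    just "Hi"; an open tag without a matching close tag ends the scan.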
+ """
+ spans: List[Tuple[int, int]] = []
+ open_tag = "<|im_start|>assistant\n"
+ close_tag = "<|im_end|>\n"
+ pos = 0
+ while True:
+ a = rendered.find(open_tag, pos)
+ if a == -1:
+ break
+ start = a + len(open_tag)
+ b = rendered.find(close_tag, start)
+ if b == -1:
+ break
+ spans.append((start, b))
+ pos = b + len(close_tag)
+ return spans
+
+class QwenChatSFTDataset(IterableDataset):
+ """
+    Expects each jsonl line to look like:
+ {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
+    Optionally including tools:
+ {"messages":[...], "tools":[{...}]}
+
+    Workflow:
+    - render with tokenizer.apply_chat_template
+    - compute loss only on assistant segments (label = -100 for all other tokens)
+    - for over-long sequences, keep the tail (which usually contains the answer)
+ """
+ def __init__(self,
+ ex_iter: Iterable[dict],
+ tokenizer: AutoTokenizer,
+ seq_len: int = 4096):
+ self.ex_iter = ex_iter
+ self.tok = tokenizer
+ self.seq_len = seq_len
+
+ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+ for ex in self.ex_iter:
+ msgs = ex.get("messages", None)
+ if not msgs or not isinstance(msgs, list):
+            # Strictly require the messages format; skip legacy "text" rows
+ continue
+
+            # Optional: skip samples with a non-empty <think>...</think> block
+            # (avoid training on real CoT; tags per the Qwen3 chat template)
+            bad = False
+            for m in msgs:
+                if m.get("role") == "assistant" and isinstance(m.get("content"), str):
+                    c = m["content"]
+                    if "<think>" in c and "</think>" in c:
+                        inner = c.split("<think>")[-1].split("</think>")[0].strip()
+ if inner:
+ bad = True; break
+ if bad:
+ continue
+
+ tools = ex.get("tools", None)
+
+            # 1) Render with the model's built-in chat template (don't hand-roll it)
+ rendered: str = self.tok.apply_chat_template(
+ msgs,
+ tools=tools,
+                add_generation_prompt=False, # training includes the assistant answer
+ tokenize=False
+ )
+ if not isinstance(rendered, str) or not rendered.strip():
+ continue
+
+            # 2) Find the character spans of the assistant segments
+ spans = _assistant_char_spans(rendered)
+ if not spans:
+ continue
+
+            # 3) Tokenize + align characters to tokens
+ enc = self.tok(
+ rendered,
+ add_special_tokens=False,
+ return_offsets_mapping=True
+ )
+ input_ids: List[int] = enc["input_ids"]
+ offsets: List[Tuple[int, int]] = enc["offset_mapping"]
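+            # offset_mapping gives each token's [start, end) character range in
+            # `rendered`; it is used below to map assistant char spans to labels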
+
+            # 4) Compute loss only on assistant tokens
+ labels = [-100] * len(input_ids)
+
+ def in_any_span(lo: int, hi: int) -> bool:
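+                # half-open intervals [lo, hi) and [s, e) overlap iff neither
+                # lies entirely before the other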
+ for s, e in spans:
+ if not (hi <= s or lo >= e):
+ return True
+ return False
+
+ for i, (lo, hi) in enumerate(offsets):
+ if in_any_span(lo, hi):
+ labels[i] = input_ids[i]
+
+            # 5) Truncate over-long sequences (keep the tail)
+ if len(input_ids) > self.seq_len:
+ input_ids = input_ids[-self.seq_len:]
+ labels = labels[-self.seq_len:]
+
+ yield {
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
+ "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
+ "labels": torch.tensor(labels, dtype=torch.long)
+ }
+
+
+# ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
+class SFTDataCollator:
+ def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
+ self.tok = tokenizer
+ self.pad_to_length = pad_to_length
+ assert self.tok.pad_token_id is not None
+
+ def __call__(self, features):
+ def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
+ input_ids = [_to_list(f["input_ids"]) for f in features]
+ attn_masks = [_to_list(f["attention_mask"]) for f in features]
+ labels_list = [_to_list(f["labels"]) for f in features]
+
+ max_len_in_batch = max(len(x) for x in input_ids)
+ target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
+
+ pad_id = self.tok.pad_token_id
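+        # pad inputs with pad_id, attention masks with 0, and labels with -100
+        # so that padded positions are ignored by the cross-entropy loss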
+ batch_inp, batch_attn, batch_lab = [], [], []
+ for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
+ pad_len = target_len - len(inp)
+ if pad_len < 0:
+ inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
+ pad_len = 0
+ batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
+ batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
+ batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
+ return {
+ "input_ids": torch.stack(batch_inp, dim=0),
+ "attention_mask": torch.stack(batch_attn, dim=0),
+ "labels": torch.stack(batch_lab, dim=0),
+ }
+
+# ----------------- Arguments -----------------
+def parse_args():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--model_name_or_path", type=str, required=True,
+                    help="local weights directory or HF name (e.g. /home/test/Qwen3-8B)")
+ ap.add_argument("--data_glob", type=str, required=True,
+                    help="local jsonl glob (every node must have the data at the same path; each line should contain messages and optional tools)")
+ ap.add_argument("--output_dir", type=str, required=True,
+                    help="local output directory (each node writes to its own local disk)")
+ ap.add_argument("--seq_len", type=int, default=4096)
+ ap.add_argument("--learning_rate", type=float, default=2e-5)
+ ap.add_argument("--weight_decay", type=float, default=0.1)
+ ap.add_argument("--warmup_ratio", type=float, default=0.02)
+ ap.add_argument("--num_train_epochs", type=float, default=1.0)
+ ap.add_argument("--max_steps", type=int, default=-1)
+ ap.add_argument("--log_interval", type=int, default=10)
+ ap.add_argument("--save_steps", type=int, default=500)
+ ap.add_argument("--eval_ratio", type=float, default=0.0) # 兜底抽样评估
+ ap.add_argument("--seed", type=int, default=1337)
+ ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
+ ap.add_argument("--gradient_checkpointing", action="store_true")
+ ap.add_argument("--bf16", action="store_true",
+                    help="bf16 can be enabled on 3090/A100/H100 etc.; also enable it in the DS config")
+ ap.add_argument("--per_device_train_batch_size", type=int, default=1)
+ ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
+ ap.add_argument("--report_to", type=str, default="tensorboard",
+ choices=["none","tensorboard","wandb"])
+ ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
+ ap.add_argument("--eval_data_glob", type=str, default=None,
+                    help="(optional) test/validation set jsonl glob; takes priority when provided")
+ ap.add_argument("--local_rank", type=int, default=-1,
+ help="for deepspeed/torchrun launcher; ignored by user code")
+ return ap.parse_args()
+
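+# Illustrative launch (paths, hostfile, and flag values are placeholders):
+#   deepspeed --hostfile hostfile train_sft_ds.py \
+#     --model_name_or_path /home/test/Qwen3-8B \
+#     --data_glob './data/*.jsonl' --output_dir ./out \
+#     --bf16 --gradient_checkpointing --deepspeed ds_config_zero3.json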
+
+# ----------------- Main -----------------
+def main():
+ args = parse_args()
+ set_seed(args.seed)
+
+    # 1) First fill in the tokenizer's pad token
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+
+ if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token # used for padding
+
+    # Optional: reduce warnings
+ tokenizer.model_max_length = args.seq_len
+
+    # 2) Then load the model
+ model = AutoModelForCausalLM.from_pretrained(
+ args.model_name_or_path,
+ torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
+ low_cpu_mem_usage=True,
+ trust_remote_code=True
+ )
+
+    # 3) Finally align the model's pad_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+    model.config.use_cache = False # disable cache during training
+
+ if args.gradient_checkpointing:
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
+ try:
+ torch.backends.cuda.enable_flash_sdp(False)
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
+        torch.backends.cuda.enable_math_sdp(True) # use the math implementation
+ except Exception:
+ pass
+
+    # ===== Data robustness checks (each node runs its own) =====
+ host = socket.gethostname()
+ rank = int(os.environ.get("RANK", "0"))
+
+ files = sorted(glob.glob(args.data_glob))
+ if len(files) == 0:
+ raise FileNotFoundError(
+ f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
+ "每台机器都必须在相同本地路径下放置数据;"
+ "可通过 DATA_GLOB= ./run_ds.sh 覆写。"
+ )
+ if is_main_process():
+ print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
+
+    # streaming line-by-line read (messages/tools structure)
+ ds_stream = load_dataset(
+ "json",
+ data_files={"train": files},
+ split="train",
+ streaming=True
+ )
+
+ def ex_iter():
+ for ex in ds_stream:
+ yield ex
+
+ train_stream_probe = QwenChatSFTDataset(ex_iter(), tokenizer, seq_len=args.seq_len)
+    # Probe: make sure at least one sample can be produced
+ _probe_it = iter(train_stream_probe)
+ try:
+ _ = next(_probe_it)
+ except StopIteration:
+ raise RuntimeError(
+ f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n"
+ "请确认每行 JSON 至少包含 'messages'(列表,含 user/assistant)字段;"
+ "若含 … 请确保不包含真实思维文本,或移除。\n"
+ "另外检查 seq_len 是否过小导致全部被裁。"
+ )
+
+    # The probe consumed the stream; rebuild it once for actual training
+ ds_stream2 = load_dataset(
+ "json",
+ data_files={"train": files},
+ split="train",
+ streaming=True
+ )
+
+    # Multi-node/multi-GPU sharding (each global rank reads a different substream)
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
+ def ex_iter2():
+ for ex in ds_stream2:
+ yield ex
+ train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+
+    # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
+ eval_dataset: Optional[Dataset] = None
+
+ class ListDataset(Dataset):
+ def __init__(self, items): self.items = items
+ def __len__(self): return len(self.items)
+ def __getitem__(self, idx): return self.items[idx]
+
+ if args.eval_data_glob:
+ eval_files = sorted(glob.glob(args.eval_data_glob))
+ if len(eval_files) == 0:
+ raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
+ if is_main_process():
+ print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
+
+ ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
+ def ex_iter_eval():
+ for ex in ds_eval_stream:
+ yield ex
+
+ eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+ eval_items: List[Dict[str, torch.Tensor]] = []
+ for sample in eval_iterable:
+ eval_items.append(sample)
+
+ if len(eval_items) == 0:
+ raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
+
+ eval_dataset = ListDataset(eval_items)
+
+ elif args.eval_ratio and args.eval_ratio > 0:
+        # Simple head sampling (only a rough eval under streaming)
+ desired_eval_batches = 200
+ tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+ def ex_iter_eval2():
+ for ex in tmp_stream:
+ yield ex
+ eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
+ eval_samples = []
+ it = iter(eval_stream)
+ for _ in range(desired_eval_batches):
+ try:
+ eval_samples.append(next(it))
+ except StopIteration:
+ break
+ if len(eval_samples) > 0:
+ eval_dataset = ListDataset(eval_samples)
+
+ data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ logging_dir = os.path.join(args.output_dir, "logs")
+ os.makedirs(logging_dir, exist_ok=True)
+
+    # ---- Compatible with transformers 4.51 (eval_strategy) and older versions (evaluation_strategy) ----
+ ta_kwargs = {}
+ sig = inspect.signature(TrainingArguments.__init__).parameters
+ if eval_dataset is not None:
+ if "eval_strategy" in sig:
+ ta_kwargs["eval_strategy"] = "steps"
+ elif "evaluation_strategy" in sig:
+ ta_kwargs["evaluation_strategy"] = "steps"
+ else:
+ if "eval_strategy" in sig:
+ ta_kwargs["eval_strategy"] = "no"
+ elif "evaluation_strategy" in sig:
+ ta_kwargs["evaluation_strategy"] = "no"
+
+ training_args = TrainingArguments(
+ output_dir=args.output_dir,
+ logging_dir=logging_dir,
+ do_train=True,
+ do_eval=(eval_dataset is not None),
+ eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
+ per_device_train_batch_size=args.per_device_train_batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ learning_rate=args.learning_rate,
+ weight_decay=args.weight_decay,
+ warmup_ratio=args.warmup_ratio,
+ num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
+ max_steps=args.max_steps if args.max_steps > 0 else -1,
+ lr_scheduler_type="cosine",
+ logging_steps=args.log_interval,
+ save_steps=args.save_steps,
+ save_total_limit=2,
+ deepspeed=args.deepspeed,
+ dataloader_drop_last=True,
+ dataloader_num_workers=0,
+ dataloader_prefetch_factor=None,
+ dataloader_pin_memory=False,
+ report_to=([] if args.report_to == "none" else [args.report_to]),
+ bf16=args.bf16,
+ fp16=(not args.bf16),
+ gradient_checkpointing=args.gradient_checkpointing,
+ remove_unused_columns=False,
+ torch_compile=False,
+ save_on_each_node=False,
+ logging_first_step=True,
+ **ta_kwargs,
+ )
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_stream,
+ eval_dataset=eval_dataset,
+ processing_class=tokenizer,
+ data_collator=data_collator
+ )
+ trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
+
+    # No shared disk: check whether this node's local output_dir already has checkpoint-*
+ ckpt_exists = (os.path.isdir(args.output_dir)
+ and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
+ resume_flag = True if ckpt_exists else None
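+    # resume_from_checkpoint=True makes the Trainer pick up the most recent
+    # checkpoint-* under output_dir; None starts training from scratch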
+
+ print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
+ print_once("***** Starting training *****")
+ train_result = trainer.train(resume_from_checkpoint=resume_flag)
+    trainer.save_model() # with DeepSpeed stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank0
+
+ metrics = train_result.metrics
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ if eval_dataset is not None:
+ print_once("***** Running eval *****")
+ eval_metrics = trainer.evaluate()
+ trainer.log_metrics("eval", eval_metrics)
+ trainer.save_metrics("eval", eval_metrics)
+
+ print_once("Done.")
+
+
+if __name__ == "__main__":
+ main()