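"""Multi-node SFT training script for Qwen-style chat models.

Streams JSONL chat data (one ``messages`` list per line, optional ``tools``),
builds loss masks so that only assistant turns are supervised (optionally
hiding ``<think>...</think>`` spans), and trains with the Hugging Face Trainer,
optionally under DeepSpeed. Environment setup for ninja/nvcc, the JIT cache,
and W&B happens at import time so broken nodes fail fast with a clear log line.
"""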
import os

os.environ.pop("PYTHONNOUSERSITE", None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")

import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
from torch.optim import AdamW as TorchAdamW

# ==== make sure CLI ninja/nvcc are reachable even in non-interactive ssh ====
import site, shutil

home = os.path.expanduser("~")
want = [f"{home}/.local/bin", "/usr/local/cuda-11.8/bin"]
cur = os.environ.get("PATH", "").split(":")
new = [d for d in want if d and d not in cur] + cur
os.environ["PATH"] = ":".join(new)

# Visibility print, so the logs show whether each node (e.g. tn06) picked these up
print(f"[env] PATH={os.environ['PATH']}", flush=True)
print(f"[env] which ninja={shutil.which('ninja')} which nvcc={shutil.which('nvcc')}", flush=True)


os.environ.setdefault("CUDA_HOME", "/usr/local/cuda-11.8")
ld = os.environ.get("LD_LIBRARY_PATH", "")
cuda_lib = "/usr/local/cuda-11.8/lib64"
if cuda_lib not in ld.split(":"):
    os.environ["LD_LIBRARY_PATH"] = f"{cuda_lib}:{ld}" if ld else cuda_lib

# Visual confirmation
print(f"[env] torch.version.cuda={torch.version.cuda} CUDA_HOME={os.environ['CUDA_HOME']}", flush=True)

# 1) Make sure user site-packages are not shadowed (ninja is installed under ~/.local)
os.environ.pop("DS_BUILD_OPS", None)
os.environ.pop("DS_SKIP_CUDA_BUILD", None)

# 2) Insert the user site dir into sys.path (e.g. /home/test/.local/lib/python3.10/site-packages)
try:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.insert(0, user_site)
except Exception:
    pass

# 3) Unify the JIT cache directory (optional, but more robust; the current logs use ~/.cache/torch_extensions)
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "12")

if shutil.which("ninja") is None:
    os.environ["USE_NINJA"] = "0"
    print("[env] no CLI ninja on PATH -> USE_NINJA=0 fallback", flush=True)

# 4) Verify ninja and the CPUAdam JIT right away (if this fails, the log immediately
#    tells you which host/rank has a broken environment)
try:
    from deepspeed.ops.op_builder import CPUAdamBuilder
    CPUAdamBuilder().load()
    print("[env] CPUAdamBuilder JIT OK", flush=True)
except Exception as e:
    # Fallback when the ninja executable is missing: disable ninja and build with
    # setuptools (slower the first time, but reliable)
    if "Ninja is required to load C++ extensions" in str(e):
        os.environ["USE_NINJA"] = "0"
        print("[env] no CLI ninja, retry with USE_NINJA=0 (fallback build)", flush=True)
        from deepspeed.ops.op_builder import CPUAdamBuilder
        CPUAdamBuilder().load()
        print("[env] CPUAdamBuilder JIT OK (fallback)", flush=True)
    else:
        print(f"[env][host={socket.gethostname()} RANK={os.environ.get('RANK','?')}] PRE-JIT FAILED: {e}", flush=True)
        raise

# ----------------- Process helpers -----------------
def is_main_process():
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs):
    if is_main_process():
        print(*args, **kwargs, flush=True)


class DebugTrainer(Trainer):
    def training_step(self, model, inputs, num_items_in_batch=None):
        # Print devices/shapes of the first batch once per process.
        if not hasattr(self, "_dbg_printed"):
            rank = int(os.environ.get("RANK", "0"))
            host = socket.gethostname()
            ids = inputs["input_ids"]
            msk = inputs["attention_mask"]
            labs = inputs["labels"]
            print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
                  f"supervised={(labs != -100).sum().item()}",
                  flush=True)
            print(
                f"[step0][host={host} RANK={rank}] "
                f"input_ids.shape={tuple(ids.shape)} "
                f"attention_mask.shape={tuple(msk.shape)} "
                f"labels.shape={tuple(labs.shape)} "
                f"num_items_in_batch={num_items_in_batch}",
                flush=True
            )
            self._dbg_printed = True

        # Newer Trainer versions pass num_items_in_batch; fall back for older signatures.
        try:
            return super().training_step(model, inputs, num_items_in_batch=num_items_in_batch)
        except TypeError:
            return super().training_step(model, inputs)

# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_train_begin(self, args, state, control, **kwargs):
        tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
        tot = tmp if isinstance(tmp, int) and tmp > 0 else 0

        rank = os.environ.get("RANK", "?")
        host = socket.gethostname()
        print(f"[{host} rank={rank}] total_steps={tot}", flush=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return

        # ---- Console output: every rank prints current step / total steps ----
        cur = int(getattr(state, "global_step", 0) or 0)

        tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
        tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
        pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")

        # Once tot becomes available, announce the total step count again (only once)
        if tot and not hasattr(self, "_tot_announced"):
            print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
            self._tot_announced = True

        rank = os.environ.get("RANK", "?")
        host = socket.gethostname()
        print(
            f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
            f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
            flush=True
        )

        # ---- Only the main process writes the CSV, to avoid concurrent writes ----
        if not is_main_process():
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(
                f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
            )

# ----------------- Helpers: extract assistant character spans (commented out; superseded by the token-id approach below) -----------------
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
#     spans: List[Tuple[int, int]] = []
#     open_tag = "<|im_start|>assistant\n"
#     close_token = "<|im_end|>"
#     close_tag = close_token + "\n"  # common templates append a newline

#     pos = 0
#     while True:
#         a = rendered.find(open_tag, pos)
#         if a == -1:
#             break
#         start = a + len(open_tag)

#         # Look for the newline variant first, then fall back to the bare token
#         b = rendered.find(close_tag, start)
#         if b == -1:
#             b = rendered.find(close_token, start)
#         if b == -1:
#             break

#         end = b + len(close_token)  # include <|im_end|> itself in supervision
#         spans.append((start, end))

#         # Advance pos past this closing marker (one extra char if the newline variant matched)
#         pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
#     return spans

# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
#     """
#     Return the spans to exclude from supervision (only the *inside* of <think>...</think>);
#     the <think> and </think> tags themselves stay supervised so the model learns to close them.
#     """
#     spans: List[Tuple[int, int]] = []
#     open_tag = "<think>"
#     close_tag = "</think>"
#     pos = 0
#     while True:
#         a = rendered.find(open_tag, pos)
#         if a == -1:
#             break
#         b = rendered.find(close_tag, a + len(open_tag))
#         if b == -1:
#             break
#         # Ignore only the inside, not the surrounding tags
#         spans.append((a + len(open_tag), b))
#         pos = b + len(close_tag)
#     return spans

# ----------------- Supervise only assistant content (token-id level, no offsets) -----------------
class QwenChatSFTDataset(IterableDataset):
    """
    - Tokenize via the chat_template to get token ids
    - Locate assistant segments by special token ids (<|im_start|>assistant\n ... <|im_end|>)
    - Supervise only the assistant content itself; by default mask out <think>...</think> (tags included)
    - When over-long, keep the last assistant segment intact and fill up to seq_len from the left
    """
    def __init__(self,
                 ex_iter: Iterable[dict],
                 tokenizer: AutoTokenizer,
                 seq_len: int = 4096,
                 mask_think_and_tags: bool = True):
        self.ex_iter = ex_iter
        self.tok = tokenizer
        self.seq_len = seq_len
        self.mask_think_and_tags = mask_think_and_tags

        # Token ids / sequences for the key markers
        self.id_START = self.tok.convert_tokens_to_ids("<|im_start|>")
        self.id_END = self.tok.convert_tokens_to_ids("<|im_end|>")
        # Support both common renderings of the role: 'assistant\n' or 'assistant'
        self.ids_ASSISTANT_CANDIDATES = [
            self.tok.encode("assistant\n", add_special_tokens=False),
            self.tok.encode("assistant", add_special_tokens=False),
        ]
        # Filter out empty candidates (extreme tokenizer configurations)
        self.ids_ASSISTANT_CANDIDATES = [c for c in self.ids_ASSISTANT_CANDIDATES if len(c) > 0]

        if not self.ids_ASSISTANT_CANDIDATES:
            raise RuntimeError("[fatal] no valid 'assistant' role token sequence found; check chat template/tokenizer.")

        self.ids_THINK_OPEN = self.tok.encode("<think>", add_special_tokens=False)
        self.ids_THINK_CLOSE = self.tok.encode("</think>", add_special_tokens=False)

        # Fail fast if the model did not register these special token ids
        for name, val in {
            "id_START": self.id_START, "id_END": self.id_END
        }.items():
            if val is None or val == self.tok.unk_token_id:
                raise RuntimeError(f"[fatal] tokenizer missing special token id for {name}")

    @staticmethod
    def _find_subseq(hay: list, needle: list, start: int) -> int:
        n = len(needle)
        if n == 0:
            return start
        for i in range(start, len(hay) - n + 1):
            if hay[i:i+n] == needle:
                return i
        return -1

    def _find_role_after_start(self, ids, j_start: int) -> Optional[Tuple[int, int]]:
        """
        Starting at j_start, try to match any 'assistant' role token sequence.
        Returns (pos, length), or None if nothing matches.
        """
        for cand in self.ids_ASSISTANT_CANDIDATES:
            pos = self._find_subseq(ids, cand, j_start)
            if pos == j_start:
                return (pos, len(cand))
        return None

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Debug switches
        dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
        dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
        seen = 0
        host = socket.gethostname()
        rank = int(os.environ.get("RANK", "0"))
        lrank = int(os.environ.get("LOCAL_RANK", "-1"))

        for ex in self.ex_iter:
            msgs = ex.get("messages")
            if not msgs or not isinstance(msgs, list):
                continue
            tools = ex.get("tools", None)

            # Let the template tokenize directly -> ids (avoids offset pitfalls)
            try:
                ids = self.tok.apply_chat_template(
                    msgs, tools=tools, add_generation_prompt=False,
                    tokenize=True, return_tensors=None
                )
                # Older versions may return a dict
                if isinstance(ids, dict):
                    ids = ids["input_ids"]
            except TypeError:
                # Last-resort fallback: render to a string first, then tokenize manually
                rendered: str = self.tok.apply_chat_template(
                    msgs, add_generation_prompt=False, tokenize=False
                )
                ids = self.tok(rendered, add_special_tokens=False)["input_ids"]

            if not ids:
                continue

            # Build the supervision mask (0/1)
            mask = [0] * len(ids)
            i = 0
            while i < len(ids):
                # Find the next <|im_start|>
                try:
                    a = ids.index(self.id_START, i)
                except ValueError:
                    break

                # The role must be assistant (accept 'assistant\n' or 'assistant')
                j = a + 1
                role_match = self._find_role_after_start(ids, j)
                if role_match is None:
                    i = a + 1
                    continue
                _, role_len = role_match
                content_lo = j + role_len  # skip the role token sequence

                # Find the matching <|im_end|>
                try:
                    b = ids.index(self.id_END, content_lo)
                except ValueError:
                    # Unclosed segment: give up on it
                    i = a + 1
                    continue

                content_hi = b  # END excluded
                # First mark the whole content span as supervised (1)
                for t in range(content_lo, content_hi):
                    mask[t] = 1

                # Optionally mask <think>...</think> (tags included) entirely
                if self.mask_think_and_tags:
                    p = content_lo
                    while True:
                        o = self._find_subseq(ids, self.ids_THINK_OPEN, p)
                        if o == -1 or o >= content_hi:
                            break
                        c = self._find_subseq(ids, self.ids_THINK_CLOSE, o + len(self.ids_THINK_OPEN))
                        if c == -1 or c > content_hi:
                            break
                        x_lo = o                              # includes <think>
                        x_hi = c + len(self.ids_THINK_CLOSE)  # includes </think>
                        for t in range(x_lo, min(x_hi, content_hi)):
                            mask[t] = 0
                        p = x_hi

                # Continue with the next segment
                i = b + 1

            # Skip if nothing is supervised
            if not any(mask):
                continue

            # ======== Truncation: prefer keeping the last supervised token as the end ========
            if len(ids) > self.seq_len:
                last_on = max(idx for idx, v in enumerate(mask) if v == 1)
                end = min(len(ids), last_on + 1)
                start = max(0, end - self.seq_len)
                ids = ids[start:end]
                mask = mask[start:end]

            # ======== Left padding ========
            pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
            L = len(ids)
            if L < self.seq_len:
                pad = self.seq_len - L
                input_ids = [pad_id] * pad + ids
                attention_mask = [0] * pad + [1] * L
                labels = [-100] * pad + [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
            else:
                input_ids = ids
                attention_mask = [1] * self.seq_len
                labels = [tok if m == 1 else -100 for tok, m in zip(ids, mask)]

            # >>> Optional debug printing
            if dbg_on and seen < dbg_limit:
                sup_tok = sum(1 for v in labels if v != -100)
                print(
                    f"[sample][host={host} RANK={rank} LRank={lrank}] "
                    f"toks={len(input_ids)} sup_toks={sup_tok} "
                    f"seq_len={self.seq_len} pad_id={pad_id}",
                    flush=True
                )
                seen += 1

            yield {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
            }

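# Illustrative data format (field values below are made up, not taken from the repo):
# the streaming loader feeds this dataset one JSON object per line, each holding a
# "messages" list and an optional "tools" field, e.g.
#
#   {"messages": [
#       {"role": "user", "content": "What is 2+2?"},
#       {"role": "assistant", "content": "<think></think>4"}
#   ]}
#
# Only assistant turns are supervised; with mask_think_and_tags=True the
# <think>...</think> span (tags included) is excluded from the loss.
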
# ----------------- Collator (consistent with upstream: pad -> label=-100, attn=0) -----------------
class SFTDataCollator:
    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
        self.tok = tokenizer
        self.pad_to_length = pad_to_length
        assert self.tok.pad_token_id is not None, "tokenizer.pad_token_id must be set"

    def __call__(self, features):
        if not features:
            raise RuntimeError("Empty batch passed to collator")

        def _to_list(x):
            return x.tolist() if isinstance(x, torch.Tensor) else list(x)

        input_ids = [_to_list(f["input_ids"]) for f in features]
        attn_masks = [_to_list(f["attention_mask"]) for f in features]
        labels_list = [_to_list(f["labels"]) for f in features]

        max_len_in_batch = max(len(x) for x in input_ids)
        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch

        pad_id = self.tok.pad_token_id
        batch_inp, batch_attn, batch_lab = [], [], []
        for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
            pad_len = target_len - len(inp)
            if pad_len < 0:
                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
                pad_len = 0
            batch_inp.append(torch.tensor(inp + [pad_id] * pad_len, dtype=torch.long))
            batch_attn.append(torch.tensor(msk + [0] * pad_len, dtype=torch.long))
            batch_lab.append(torch.tensor(lab + [-100] * pad_len, dtype=torch.long))

        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
            "labels": torch.stack(batch_lab, dim=0),
        }

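# Note: QwenChatSFTDataset already left-pads/truncates every sample to exactly
# seq_len, so with pad_to_length=None this collator's right-padding branch is
# effectively a safety net rather than the main padding path.
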
# ----------------- Arguments -----------------
def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="Local weights directory or HF model name (e.g. /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="Local jsonl glob (every node must have the data at the same path; each line should contain messages / optional tools)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="Local output directory (each node writes to its own local disk)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
                    help="Enable bf16 on 3090/A100/H100 etc.; also enable it in the DeepSpeed config")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard",
                    choices=["none", "tensorboard", "wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    ap.add_argument("--eval_data_glob", type=str, default=None,
                    help="(optional) test/validation jsonl glob; takes precedence if provided")
    ap.add_argument("--local_rank", type=int, default=-1,
                    help="for deepspeed/torchrun launcher; ignored by user code")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
    ap.add_argument("--deepspeed", type=str, default=None)
    ap.add_argument("--eval_steps", type=int, default=10,
                    help="Evaluate every N optimizer steps when eval_dataset is provided")

    return ap.parse_args()

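# Illustrative launch (a sketch only: the script filename, hostfile, and all paths
# below are placeholders; the run_ds.sh wrapper mentioned in the error messages is
# the usual entry point):
#
#   deepspeed --hostfile hostfile train_sft.py \
#       --model_name_or_path /home/test/Qwen3-8B \
#       --data_glob '/data/sft/*.jsonl' \
#       --output_dir ./out/qwen3-sft \
#       --bf16 --gradient_checkpointing \
#       --deepspeed ds_zero3.json
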
# ----------------- Main -----------------
def main():

    args = parse_args()

    # Only rank 0 reports to wandb; every other rank disables reporting
    if os.environ.get("RANK", "0") != "0" and args.report_to == "wandb":
        print(f"[rank {os.environ.get('RANK')}] force report_to=none", flush=True)
        args.report_to = "none"

    set_seed(args.seed)

    host = socket.gethostname()

    def dbg(msg):
        print(
            f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
            f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
            flush=True
        )

    # DeepSpeed is only actually enabled when a config file is passed and exists
    use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))

    dschf = None
    if use_ds:
        try:
            from transformers.integrations.deepspeed import HfDeepSpeedConfig
            src = "transformers.integrations.deepspeed"
        except Exception:
            try:
                # Fallback: some versions expose it directly from transformers
                from transformers import HfDeepSpeedConfig
                src = "transformers"
            except Exception as e:
                raise RuntimeError(
                    "This transformers version does not provide HfDeepSpeedConfig; "
                    "please upgrade or downgrade transformers") from e
        dschf = HfDeepSpeedConfig(args.deepspeed)
        dbg(f"HfDeepSpeedConfig loaded from {src}")

    if args.report_to == "wandb":
        os.environ.setdefault("WANDB_PROJECT", args.wandb_project)

    # Pre-initialize W&B on rank 0 only
    is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
    if is_rank0:
        import wandb
        try:
            # Avoid a leftover external RUN_ID forcing a resume and hanging init
            os.environ.pop("WANDB_RUN_ID", None)

            # Optional fields injected from the environment (used when present)
            extra = {}
            if os.getenv("WANDB_NAME"): extra["name"] = os.getenv("WANDB_NAME")
            if os.getenv("WANDB_GROUP"): extra["group"] = os.getenv("WANDB_GROUP")
            if os.getenv("WANDB_RESUME"): extra["resume"] = os.getenv("WANDB_RESUME")  # 'allow' is recommended

            run = wandb.init(
                project=args.wandb_project,
                entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
                settings=wandb.Settings(
                    base_url=os.getenv("WANDB_BASE_URL", "https://wandb.szaiai.com"),
                    init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT", "300")),
                ),
                **extra,
            )
            print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
        except Exception as e:
            print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
            os.environ["WANDB_DISABLED"] = "true"
            args.report_to = "none"
    else:
        os.environ["WANDB_DISABLED"] = "true"

    # Versions, launch args, and key environment variables
    import transformers as hf
    try:
        import deepspeed as ds
        ds_ver = ds.__version__
    except Exception:
        ds_ver = "n/a"

    dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
    dbg(f"args={args}")
    dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
        os.environ.get("WORLD_SIZE"),
        os.environ.get("RANK"),
        os.environ.get("LOCAL_RANK", str(args.local_rank)),
        os.environ.get("MASTER_ADDR"),
        os.environ.get("MASTER_PORT"),
        os.environ.get("CUDA_VISIBLE_DEVICES"),
    ))
    dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")

    # ---- Initialize distributed (used by the consistency probes below) ----
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
    dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")

    if torch.cuda.is_available() and local_rank >= 0:
        torch.cuda.set_device(local_rank)
        dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
            f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        dbg("no cuda or invalid local_rank; not calling set_device")

    if world_size > 1 and dist.is_available() and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dbg(f"init_process_group backend={backend} via env://")
        dist.init_process_group(backend=backend, init_method="env://")
    else:
        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")

    if dist.is_available() and dist.is_initialized():
        try:
            dbg(f"dist.get_backend()={dist.get_backend()} "
                f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
        except Exception as e:
            dbg(f"dist query error: {e}")

    # 1) Fix up the tokenizer's pad token first
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Pad on the left, to match the dataset's left-pad policy
    try:
        if getattr(tokenizer, "padding_side", None) != "left":
            tokenizer.padding_side = "left"
    except Exception:
        pass

    # (Previously a *Fast* tokenizer was required for offset_mapping; no longer needed.)
    # A fast tokenizer is still recommended (faster); masks no longer rely on offset_mapping
    if not getattr(tokenizer, "is_fast", False):
        print("[warn] using a slow tokenizer; masks are token-id based and still correct, just slower.", flush=True)

    tokenizer.model_max_length = args.seq_len
    dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
        f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")

    # 2) Decide the dtype before loading the model
    def _bf16_supported():
        if not torch.cuda.is_available():
            return False
        # Compatible across torch versions: prefer the API, fall back to compute capability
        if hasattr(torch.cuda, "is_bf16_supported"):
            return torch.cuda.is_bf16_supported()
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 0)  # Ampere and newer

    use_bf16 = bool(args.bf16 and _bf16_supported())
    dtype = (torch.bfloat16 if use_bf16 else
             (torch.float16 if torch.cuda.is_available() else torch.float32))

    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    # Let the DeepSpeed integration handle ZeRO-Init / sharded loading
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )

    print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
    dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
        f"use_cache={getattr(model.config,'use_cache',None)} "
        f"pad_token_id={getattr(model.config,'pad_token_id',None)}")

    # 3) pad token, cache, and related configuration
    model.config.pad_token_id = tokenizer.pad_token_id

    if getattr(model, "generation_config", None) is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    model.config.use_cache = False
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    try:
        # Let PyTorch choose, or explicitly enable the efficient SDP kernels
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    except Exception:
        pass

    # ===== Data robustness checks (run independently on every node) =====
    files = sorted(glob.glob(args.data_glob))
    if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "Every node must have the data at the same local path; "
            "override with DATA_GLOB=<your_glob> ./run_ds.sh."
        )
    if is_main_process():
        print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)

    # ====== Small probe: sample structure ======
    ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

    def ex_iter_probe():
        for ex in ds_stream_probe:
            yield ex

    train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)

    try:
        sample = next(iter(train_stream_probe))
    except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] Data files matched, but no trainable samples were produced.\n"
            "Make sure each JSON line contains at least a 'messages' list (with user/assistant turns); "
            "if it contains <think>...</think>, make sure there is no real reasoning text inside, or remove it.\n"
            "Also check whether seq_len is so small that everything gets truncated away."
        )

    # A more reliable self-check
    ids, attn, labs = sample["input_ids"], sample["attention_mask"], sample["labels"]
    assert (labs != -100).any(), "[fatal] no supervised tokens in first valid sample"
    # Padded positions must be excluded from supervision
    assert bool((labs[attn == 0] == -100).all()), "[fatal] padded tokens must have label -100"

    # ====== Actual training stream (no manual sharding; leave that to Accelerate/Trainer) ======
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True).shuffle(buffer_size=50000, seed=args.seed)
    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)

    # ====== Consistency probe (unsharded) ======
    ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)

    def has_at_least(stream, n: int):
        it = iter(stream)
        for _ in range(n):
            try:
                next(it)
            except StopIteration:
                return 0
        return 1

    need = max(1, args.gradient_accumulation_steps)

    local_ok = has_at_least(probe_stream, need)

    if dist.is_available() and dist.is_initialized():
        t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        if t.item() == 0:
            if is_main_process():
                print(
                    f"[FATAL] At least one rank cannot supply {need} micro-batches within one optimizer step (GA={need}). "
                    f"Reduce GA or enlarge/clean the data; training will not start.",
                    flush=True
                )
            dist.barrier()
            sys.exit(2)
    else:
        if local_ok == 0:
            print(
                f"[FATAL] This node cannot supply {need} micro-batches within one optimizer step (GA={need}). "
                f"Reduce GA or enlarge/clean the data; training will not start.",
                flush=True
            )
            sys.exit(2)

    # ---- Eval construction: prefer --eval_data_glob; otherwise fall back to eval_ratio sampling ----
    eval_dataset: Optional[Dataset] = None

    class ListDataset(Dataset):
        def __init__(self, items): self.items = items
        def __len__(self): return len(self.items)
        def __getitem__(self, idx): return self.items[idx]

    if args.eval_data_glob:
        eval_files = sorted(glob.glob(args.eval_data_glob))
        if len(eval_files) == 0:
            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
        if is_main_process():
            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)

        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)

        def ex_iter_eval():
            for ex in ds_eval_stream:
                yield ex

        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
        eval_items: List[Dict[str, torch.Tensor]] = []
        for sample in eval_iterable:
            eval_items.append(sample)

        if len(eval_items) == 0:
            raise RuntimeError("[eval] eval_data_glob produced 0 valid samples; check the messages structure.")

        eval_dataset = ListDataset(eval_items)

    elif args.eval_ratio and args.eval_ratio > 0:
        desired_eval_batches = 200
        tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

        def ex_iter_eval2():
            for ex in tmp_stream:
                yield ex

        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
        eval_samples = []
        it = iter(eval_stream)
        for _ in range(desired_eval_batches):
            try:
                eval_samples.append(next(it))
            except StopIteration:
                break
        if len(eval_samples) > 0:
            eval_dataset = ListDataset(eval_samples)

    # ---- Pad the eval set so no rank ever sees an empty batch ----
    if eval_dataset is not None:
        ws = max(world_size, 1)
        be = max(1, args.per_device_eval_batch_size)
        global_bs = ws * be

        r = len(eval_dataset) % global_bs
        if r != 0:
            pad_need = global_bs - r
            eval_dataset.items += eval_dataset.items[:pad_need]

        if is_main_process():
            print(f"[eval] padded eval set to {len(eval_dataset)} "
                  f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
                  flush=True)

        # Sanity check after padding
        assert len(eval_dataset) % global_bs == 0, \
            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"

    # More robust during bring-up: don't force-pad to 4096
    data_collator = SFTDataCollator(tokenizer, pad_to_length=None)

    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")
    os.makedirs(logging_dir, exist_ok=True)

    # ---- Compatible with 4.51 (eval_strategy) and older versions (evaluation_strategy) ----
    ta_kwargs = {}
    sig = inspect.signature(TrainingArguments.__init__).parameters
    if eval_dataset is not None:
        if "eval_strategy" in sig:
            ta_kwargs["eval_strategy"] = "steps"
        elif "evaluation_strategy" in sig:
            ta_kwargs["evaluation_strategy"] = "steps"
    else:
        if "eval_strategy" in sig:
            ta_kwargs["eval_strategy"] = "no"
        elif "evaluation_strategy" in sig:
            ta_kwargs["evaluation_strategy"] = "no"

    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
    ta_kwargs2 = dict(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
        do_train=True,
        do_eval=(eval_dataset is not None),
        # Use the user-specified eval_steps; None when there is no eval set
        eval_steps=(args.eval_steps if eval_dataset is not None else None),
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        deepspeed=(args.deepspeed if use_ds else None),
        dataloader_drop_last=False,
        dataloader_num_workers=0,
        label_smoothing_factor=0.0,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        save_on_each_node=True,
        logging_first_step=True,
        **ta_kwargs,  # the eval_strategy compatibility kwargs built above
    )

    if "dataloader_pin_memory" in ta_sig:
        ta_kwargs2["dataloader_pin_memory"] = False
    if "torch_compile" in ta_sig:
        ta_kwargs2["torch_compile"] = False

    # Reuse the use_bf16 decision from above when building TrainingArguments
    ta_kwargs2.update({
        "bf16": use_bf16,
        "fp16": (torch.cuda.is_available() and not use_bf16),
    })

    training_args = TrainingArguments(**ta_kwargs2)

    trainer_kwargs = {}
    if "processing_class" in inspect.signature(Trainer.__init__).parameters:
        trainer_kwargs["processing_class"] = tokenizer
    else:
        trainer_kwargs["tokenizer"] = tokenizer

    trainer = DebugTrainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **trainer_kwargs,
    )

    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # ==== Resume-from-checkpoint decision (safe without a shared filesystem) ====
    def last_step(path: str) -> int:
        ck = get_last_checkpoint(path)
        if ck is None:
            return -1
        base = os.path.basename(ck)
        try:
            return int(base.split("-")[-1])
        except Exception:
            return -1

    local_last = last_step(args.output_dir)  # -1 means this node has no checkpoint at all
    device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")

    resume_flag = None
    if dist.is_available() and dist.is_initialized():
        # If any rank has no checkpoint -> do not resume
        has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
        dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
        if has_local.item() == 1:
            # Every rank has one: gather each rank's last step and take the common
            # minimum k (guaranteed to exist on every node)
            ts = torch.tensor(local_last, device=device)
            world = dist.get_world_size()
            buf = [torch.zeros_like(ts) for _ in range(world)]
            dist.all_gather(buf, ts)
            steps = [b.item() for b in buf]
            k = min(steps)
            if k >= 0:
                resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
            if is_main_process():
                print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
    else:
        # Single machine or distributed not initialized: resume from the local last step if present
        if local_last >= 0:
            resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")

    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")

    # -- Global consistency check: if any rank is missing this checkpoint, disable resuming --
    if dist.is_available() and dist.is_initialized():
        device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
        present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
        dist.all_reduce(present, op=dist.ReduceOp.MIN)
        if present.item() == 0:
            if is_main_process():
                print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
            resume_flag = None
        dist.barrier()
    else:
        if resume_flag is not None and not os.path.isdir(resume_flag):
            # Single machine: if the checkpoint is missing, simply disable resuming
            print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
            resume_flag = None

    print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
    print_once("***** Starting training *****")

dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
|
||
f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB")
|
||
|
||
|
||
train_result = trainer.train(resume_from_checkpoint=resume_flag)
|
||
trainer.save_model() # DeepSpeed stage3_gather_16bit_weights_on_model_save=true 时,在 rank0 聚合整模型
|
||
|
||
metrics = train_result.metrics
|
||
trainer.log_metrics("train", metrics)
|
||
trainer.save_metrics("train", metrics)
|
||
trainer.save_state()
|
||
|
||
if eval_dataset is not None:
|
||
print_once("***** Running eval *****")
|
||
eval_metrics = trainer.evaluate()
|
||
trainer.log_metrics("eval", eval_metrics)
|
||
trainer.save_metrics("eval", eval_metrics)
|
||
|
||
print_once("Done.")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|