This commit is contained in:
hailin 2025-09-09 12:17:32 +08:00
parent dc82fcfab8
commit 3232d61c71
1 changed file with 92 additions and 59 deletions

View File

@ -179,56 +179,6 @@ class CsvLossLogger(TrainerCallback):
from typing import List, Tuple, Iterable, Iterator, Dict
# ----------------- Helpers: extract assistant character spans -----------------
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
# spans: List[Tuple[int, int]] = []
# open_tag = "<|im_start|>assistant\n"
# close_token = "<|im_end|>"
# close_tag = close_token + "\n" # 常见模板带换行
# pos = 0
# while True:
# a = rendered.find(open_tag, pos)
# if a == -1:
# break
# start = a + len(open_tag)
# # 先找含换行版本,找不到再退化找不带换行的
# b = rendered.find(close_tag, start)
# if b == -1:
# b = rendered.find(close_token, start)
# if b == -1:
# break
# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
# spans.append((start, end))
# # pos 跳过这一轮结束标记(带换行就多跳一格)
# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
# return spans
# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
# """
# 返回需要忽略监督的区间(仅 <think>...</think> 的“内部”),
# 标签本身 <think> 与 </think> 仍参与监督,以便模型学会闭合。
# """
# spans: List[Tuple[int, int]] = []
# open_tag = "<think>"
# close_tag = "</think>"
# pos = 0
# while True:
# a = rendered.find(open_tag, pos)
# if a == -1:
# break
# b = rendered.find(close_tag, a + len(open_tag))
# if b == -1:
# break
# # 只忽略内部,不忽略两侧标签
# spans.append((a + len(open_tag), b))
# pos = b + len(close_tag)
# return spans
# ----------------- Supervise only assistant content (token-id level, no offsets) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
@ -510,12 +460,85 @@ def main():
set_seed(args.seed)
host = socket.gethostname()
def dbg(msg):
    """Print a debug line tagged with host name and distributed-rank info."""
    rank = os.environ.get('RANK', '0')
    local_rank = os.environ.get('LOCAL_RANK', str(args.local_rank))
    print(f"[dbg][host={host} RANK={rank} LOCAL_RANK={local_rank}] {msg}", flush=True)
# ==== colored dbg (robust default to info) ====
# Best-effort: let ANSI escapes render on Windows consoles. colorama is
# optional — any import or runtime failure silently falls back to plain text.
try:
    import colorama
    colorama.just_fix_windows_console()
except Exception:
    pass
def _use_color() -> bool:
if os.environ.get("NO_COLOR"): return False
if os.environ.get("FORCE_COLOR"): return True
return sys.stdout.isatty()
class _C:
reset = "\033[0m"
gray = "\033[90m"
green = "\033[32m"
yellow = "\033[33m"
red = "\033[31m"
cyan = "\033[36m"
def _paint(s, color):
    """Wrap *s* in the given ANSI color (reset at the end) when color is enabled."""
    if _use_color():
        return f"{color}{s}{_C.reset}"
    return s
_LEVEL_ALIAS = {
"": "info", None: "info",
"ok": "ok", "success": "ok", "pass": "ok",
"warn": "warn", "warning": "warn",
"err": "err", "error": "err", "fatal": "err", "fail": "err",
"info": "info", "information": "info"
}
# Canonical level -> ANSI color for the message body.
_LEVEL_COLOR = dict(
    ok=_C.green,
    warn=_C.yellow,
    err=_C.red,
    info=_C.cyan,
)
def _norm_level(level) -> str:
# 默认 info
if level is None:
return "info"
# 数字等级兼容(类似 logging
if isinstance(level, (int, float)):
if level >= 40: return "err"
if level >= 30: return "warn"
return "info"
# 字符串别名
if isinstance(level, str):
key = level.strip().lower()
return _LEVEL_ALIAS.get(key, "info")
return "info"
def dbg(msg, level=None):
    """Print a colored debug line tagged with host and distributed-rank info.

    *level* may be a string alias, a numeric logging level, or None;
    anything unrecognized falls back to "info" (cyan).
    """
    severity = _norm_level(level)  # unspecified/invalid -> "info"
    tag = (
        f"[dbg][host={socket.gethostname()} "
        f"RANK={os.environ.get('RANK', '0')} "
        f"LOCAL_RANK={os.environ.get('LOCAL_RANK', '-1')}] "
    )
    body = _paint(str(msg), _LEVEL_COLOR.get(severity, _C.cyan))
    print(_paint(tag, _C.gray) + body, flush=True)
# Convenience aliases (optional)
def dbg_ok(m):
    """Shortcut for dbg(m, "ok") — green/success level."""
    dbg(m, "ok")
def dbg_warn(m):
    """Shortcut for dbg(m, "warn") — yellow/warning level."""
    dbg(m, "warn")
def dbg_err(m):
    """Log *m* as an error: red text, written to stderr.

    Fix: emit the same "[dbg][host=... RANK=... LOCAL_RANK=...]" prefix that
    dbg() uses for every other level — previously error lines carried only a
    bare "[dbg]" tag, so they could not be matched by the same log-parsing
    pattern as the rest of the debug output.
    """
    prefix = (
        f"[dbg][host={socket.gethostname()} "
        f"RANK={os.environ.get('RANK', '0')} "
        f"LOCAL_RANK={os.environ.get('LOCAL_RANK', '-1')}] "
    )
    print(_paint(prefix, _C.gray) + _paint(str(m), _C.red), flush=True, file=sys.stderr)
# def dbg(msg):
# print(
# f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
# f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
# flush=True
# )
# DeepSpeed is truly enabled only when a config path was passed AND the file exists.
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
@ -682,9 +705,9 @@ def main():
)
# Log gradient-checkpointing status and basic facts about the freshly loaded model.
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
f"use_cache={getattr(model.config,'use_cache',None)} "
f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
# dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
#     f"use_cache={getattr(model.config,'use_cache',None)} "
#     f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
# 3) pad/alibi etc. configuration — keep the model's pad id in sync with the tokenizer's
model.config.pad_token_id = tokenizer.pad_token_id
@ -703,6 +726,16 @@ def main():
except Exception:
pass
# ✅ Print the values *here*, after the fix-ups above, so the log shows the corrected config.
dbg(f"post-config: use_cache={model.config.use_cache} "
f"model.pad_token_id={model.config.pad_token_id} "
f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
# NOTE(review): `assert` statements are stripped under `python -O`; raise
# ValueError instead if these pad-id sanity checks must always run.
assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
assert model.config.pad_token_id == tokenizer.pad_token_id, \
f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
# ===== Data robustness check (each node in a multi-node run performs this itself) =====
files = sorted(glob.glob(args.data_glob))
if len(files) == 0: