From 3232d61c7166cca0264dba2b43346b2d67d39a0b Mon Sep 17 00:00:00 2001 From: hailin Date: Tue, 9 Sep 2025 12:17:32 +0800 Subject: [PATCH] . --- train_sft_ds.py | 151 +++++++++++++++++++++++++++++------------------- 1 file changed, 92 insertions(+), 59 deletions(-) diff --git a/train_sft_ds.py b/train_sft_ds.py index 1d7d1ee..b4dca77 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -179,56 +179,6 @@ class CsvLossLogger(TrainerCallback): from typing import List, Tuple, Iterable, Iterator, Dict -# ----------------- 工具:提取 assistant 字符区间 ----------------- -# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]: -# spans: List[Tuple[int, int]] = [] -# open_tag = "<|im_start|>assistant\n" -# close_token = "<|im_end|>" -# close_tag = close_token + "\n" # 常见模板带换行 - -# pos = 0 -# while True: -# a = rendered.find(open_tag, pos) -# if a == -1: -# break -# start = a + len(open_tag) - -# # 先找含换行版本,找不到再退化找不带换行的 -# b = rendered.find(close_tag, start) -# if b == -1: -# b = rendered.find(close_token, start) -# if b == -1: -# break - -# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督 -# spans.append((start, end)) - -# # pos 跳过这一轮结束标记(带换行就多跳一格) -# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token)) -# return spans - -# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]: -# """ -# 返回需要忽略监督的区间(仅 ... 
# ==== colored dbg helpers (robust default to "info") ====
# Purpose: rank-aware, colorized debug logging for multi-node training runs.
# Coloring honors NO_COLOR / FORCE_COLOR and falls back to a TTY check.
try:
    # Optional dependency: enables ANSI escapes on legacy Windows consoles.
    import colorama
    colorama.just_fix_windows_console()
except Exception:
    pass  # best-effort — plain (uncolored) output still works without it


def _use_color() -> bool:
    """Decide whether to emit ANSI color codes.

    Precedence: NO_COLOR (force off) > FORCE_COLOR (force on) > stdout TTY.
    """
    if os.environ.get("NO_COLOR"):
        return False
    if os.environ.get("FORCE_COLOR"):
        return True
    return sys.stdout.isatty()


class _C:
    """ANSI SGR escape codes used by the dbg helpers."""
    reset = "\033[0m"
    gray = "\033[90m"
    green = "\033[32m"
    yellow = "\033[33m"
    red = "\033[31m"
    cyan = "\033[36m"


def _paint(s, color):
    """Wrap *s* in *color* escapes when coloring is enabled, else return s unchanged."""
    return f"{color}{s}{_C.reset}" if _use_color() else s


# String aliases -> canonical level names. Unknown aliases fall back to "info".
# FIX: removed the dead `None: "info"` key — _norm_level returns before any
# dict lookup when level is None, so that entry could never be hit.
_LEVEL_ALIAS = {
    "": "info",
    "ok": "ok", "success": "ok", "pass": "ok",
    "warn": "warn", "warning": "warn",
    "err": "err", "error": "err", "fatal": "err", "fail": "err",
    "info": "info", "information": "info",
}

_LEVEL_COLOR = {
    "ok": _C.green,
    "warn": _C.yellow,
    "err": _C.red,
    "info": _C.cyan,
}


def _norm_level(level) -> str:
    """Normalize *level* to one of "ok" / "warn" / "err" / "info".

    Accepts None (-> "info"), logging-style numeric levels (>=40 -> "err",
    >=30 -> "warn", else "info"), or a case-insensitive string alias.
    Anything unrecognized degrades to "info" rather than raising.
    """
    if level is None:
        return "info"
    if isinstance(level, (int, float)):
        # numeric levels compatible with the stdlib logging module
        if level >= 40:
            return "err"
        if level >= 30:
            return "warn"
        return "info"
    if isinstance(level, str):
        return _LEVEL_ALIAS.get(level.strip().lower(), "info")
    return "info"


def _dbg_prefix() -> str:
    """Build the shared '[dbg][host=... RANK=... LOCAL_RANK=...] ' tag."""
    host = socket.gethostname()
    rank = os.environ.get("RANK", "0")
    lrank = os.environ.get("LOCAL_RANK", "-1")
    return f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "


def dbg(msg, level=None):
    """Print a colorized debug line tagged with host / RANK / LOCAL_RANK.

    *level* may be None (defaults to "info"), a logging-style number, or a
    string alias; see _norm_level.
    """
    lvl = _norm_level(level)  # unspecified / invalid -> "info"
    color = _LEVEL_COLOR.get(lvl, _C.cyan)
    print(_paint(_dbg_prefix(), _C.gray) + _paint(str(msg), color), flush=True)


# convenience aliases
def dbg_ok(m):
    dbg(m, "ok")


def dbg_warn(m):
    dbg(m, "warn")


def dbg_err(m):
    # FIX: previously printed a bare "[dbg]{m}" to stderr, silently dropping
    # the host/RANK/LOCAL_RANK prefix that every other level carries — which
    # makes error lines impossible to attribute in multi-node logs. Keep the
    # stderr destination but use the same prefixed, colorized format.
    print(_paint(_dbg_prefix(), _C.gray) + _paint(str(m), _C.red),
          flush=True, file=sys.stderr)
sorted(glob.glob(args.data_glob)) if len(files) == 0: