diff --git a/train_sft_ds.py b/train_sft_ds.py
index 1d7d1ee..b4dca77 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -179,56 +179,6 @@ class CsvLossLogger(TrainerCallback):
from typing import List, Tuple, Iterable, Iterator, Dict
-# ----------------- 工具:提取 assistant 字符区间 -----------------
-# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
-# spans: List[Tuple[int, int]] = []
-# open_tag = "<|im_start|>assistant\n"
-# close_token = "<|im_end|>"
-# close_tag = close_token + "\n" # 常见模板带换行
-
-# pos = 0
-# while True:
-# a = rendered.find(open_tag, pos)
-# if a == -1:
-# break
-# start = a + len(open_tag)
-
-# # 先找含换行版本,找不到再退化找不带换行的
-# b = rendered.find(close_tag, start)
-# if b == -1:
-# b = rendered.find(close_token, start)
-# if b == -1:
-# break
-
-# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
-# spans.append((start, end))
-
-# # pos 跳过这一轮结束标记(带换行就多跳一格)
-# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
-# return spans
-
-# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
-# """
-# 返回需要忽略监督的区间(仅 ... 的“内部”),
-# 标签本身 与 仍参与监督,以便模型学会闭合。
-# """
-# spans: List[Tuple[int, int]] = []
-# open_tag = ""
-# close_tag = ""
-# pos = 0
-# while True:
-# a = rendered.find(open_tag, pos)
-# if a == -1:
-# break
-# b = rendered.find(close_tag, a + len(open_tag))
-# if b == -1:
-# break
-# # 只忽略内部,不忽略两侧标签
-# spans.append((a + len(open_tag), b))
-# pos = b + len(close_tag)
-# return spans
-
-
# ----------------- 仅监督 assistant 内容(token-id 级,不用 offsets) -----------------
class QwenChatSFTDataset(IterableDataset):
"""
@@ -510,12 +460,85 @@ def main():
set_seed(args.seed)
host = socket.gethostname()
- def dbg(msg):
- print(
- f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
- f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
- flush=True
- )
+
+
+ # ==== colored dbg (robust default to info) ====
+ try:
+ import colorama
+ colorama.just_fix_windows_console()
+ except Exception:
+ pass
+
+ def _use_color() -> bool:
+ if os.environ.get("NO_COLOR"): return False
+ if os.environ.get("FORCE_COLOR"): return True
+ return sys.stdout.isatty()
+
+ class _C:
+ reset = "\033[0m"
+ gray = "\033[90m"
+ green = "\033[32m"
+ yellow = "\033[33m"
+ red = "\033[31m"
+ cyan = "\033[36m"
+
+ def _paint(s, color):
+ return f"{color}{s}{_C.reset}" if _use_color() else s
+
+ _LEVEL_ALIAS = {
+ "": "info", None: "info",
+ "ok": "ok", "success": "ok", "pass": "ok",
+ "warn": "warn", "warning": "warn",
+ "err": "err", "error": "err", "fatal": "err", "fail": "err",
+ "info": "info", "information": "info"
+ }
+
+ _LEVEL_COLOR = {
+ "ok": _C.green,
+ "warn": _C.yellow,
+ "err": _C.red,
+ "info": _C.cyan,
+ }
+
+ def _norm_level(level) -> str:
+ # 默认 info
+ if level is None:
+ return "info"
+ # 数字等级兼容(类似 logging)
+ if isinstance(level, (int, float)):
+ if level >= 40: return "err"
+ if level >= 30: return "warn"
+ return "info"
+ # 字符串别名
+ if isinstance(level, str):
+ key = level.strip().lower()
+ return _LEVEL_ALIAS.get(key, "info")
+ return "info"
+
+ def dbg(msg, level=None):
+ lvl = _norm_level(level) # 未指定/非法 -> "info"
+ host = socket.gethostname()
+ rank = os.environ.get("RANK", "0")
+ lrank = os.environ.get("LOCAL_RANK", "-1")
+ prefix = f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "
+ color = _LEVEL_COLOR.get(lvl, _C.cyan)
+ print(_paint(prefix, _C.gray) + _paint(str(msg), color), flush=True)
+
+ # 便捷别名(可选)
+ def dbg_ok(m): dbg(m, "ok")
+ def dbg_warn(m): dbg(m, "warn")
+ def dbg_err(m):
+ s = _paint(f"[dbg]{m}", _C.red)
+ print(s, flush=True, file=sys.stderr)
+
+
+
+ # def dbg(msg):
+ # print(
+ # f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
+ # f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
+ # flush=True
+ # )
# 是否真的启用 DeepSpeed(传了配置文件且文件存在)
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
@@ -682,9 +705,9 @@ def main():
)
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
- dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
- f"use_cache={getattr(model.config,'use_cache',None)} "
- f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
+ # dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
+ # f"use_cache={getattr(model.config,'use_cache',None)} "
+ # f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
# 3) pad/alibi 等配置
model.config.pad_token_id = tokenizer.pad_token_id
@@ -703,6 +726,16 @@ def main():
except Exception:
pass
+ # ✅ 放在这里打印“修正后”的值
+ dbg(f"post-config: use_cache={model.config.use_cache} "
+ f"model.pad_token_id={model.config.pad_token_id} "
+ f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
+ f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
+
+ assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
+ assert model.config.pad_token_id == tokenizer.pad_token_id, \
+ f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
+
# ===== 数据鲁棒性检查(多机各自执行)=====
files = sorted(glob.glob(args.data_glob))
if len(files) == 0: