This commit is contained in:
parent
dc82fcfab8
commit
3232d61c71
151
train_sft_ds.py
151
train_sft_ds.py
|
|
@ -179,56 +179,6 @@ class CsvLossLogger(TrainerCallback):
|
|||
|
||||
from typing import List, Tuple, Iterable, Iterator, Dict
|
||||
|
||||
# ----------------- 工具:提取 assistant 字符区间 -----------------
|
||||
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
||||
# spans: List[Tuple[int, int]] = []
|
||||
# open_tag = "<|im_start|>assistant\n"
|
||||
# close_token = "<|im_end|>"
|
||||
# close_tag = close_token + "\n" # 常见模板带换行
|
||||
|
||||
# pos = 0
|
||||
# while True:
|
||||
# a = rendered.find(open_tag, pos)
|
||||
# if a == -1:
|
||||
# break
|
||||
# start = a + len(open_tag)
|
||||
|
||||
# # 先找含换行版本,找不到再退化找不带换行的
|
||||
# b = rendered.find(close_tag, start)
|
||||
# if b == -1:
|
||||
# b = rendered.find(close_token, start)
|
||||
# if b == -1:
|
||||
# break
|
||||
|
||||
# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
|
||||
# spans.append((start, end))
|
||||
|
||||
# # pos 跳过这一轮结束标记(带换行就多跳一格)
|
||||
# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
|
||||
# return spans
|
||||
|
||||
# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
||||
# """
|
||||
# 返回需要忽略监督的区间(仅 <think>...</think> 的“内部”),
|
||||
# 标签本身 <think> 与 </think> 仍参与监督,以便模型学会闭合。
|
||||
# """
|
||||
# spans: List[Tuple[int, int]] = []
|
||||
# open_tag = "<think>"
|
||||
# close_tag = "</think>"
|
||||
# pos = 0
|
||||
# while True:
|
||||
# a = rendered.find(open_tag, pos)
|
||||
# if a == -1:
|
||||
# break
|
||||
# b = rendered.find(close_tag, a + len(open_tag))
|
||||
# if b == -1:
|
||||
# break
|
||||
# # 只忽略内部,不忽略两侧标签
|
||||
# spans.append((a + len(open_tag), b))
|
||||
# pos = b + len(close_tag)
|
||||
# return spans
|
||||
|
||||
|
||||
# ----------------- 仅监督 assistant 内容(token-id 级,不用 offsets) -----------------
|
||||
class QwenChatSFTDataset(IterableDataset):
|
||||
"""
|
||||
|
|
@ -510,12 +460,85 @@ def main():
|
|||
set_seed(args.seed)
|
||||
|
||||
host = socket.gethostname()
|
||||
def dbg(msg):
|
||||
print(
|
||||
f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
|
||||
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
|
||||
flush=True
|
||||
)
|
||||
|
||||
|
||||
# ==== colored dbg (robust default to info) ====
|
||||
try:
|
||||
import colorama
|
||||
colorama.just_fix_windows_console()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _use_color() -> bool:
|
||||
if os.environ.get("NO_COLOR"): return False
|
||||
if os.environ.get("FORCE_COLOR"): return True
|
||||
return sys.stdout.isatty()
|
||||
|
||||
class _C:
|
||||
reset = "\033[0m"
|
||||
gray = "\033[90m"
|
||||
green = "\033[32m"
|
||||
yellow = "\033[33m"
|
||||
red = "\033[31m"
|
||||
cyan = "\033[36m"
|
||||
|
||||
def _paint(s, color):
|
||||
return f"{color}{s}{_C.reset}" if _use_color() else s
|
||||
|
||||
_LEVEL_ALIAS = {
|
||||
"": "info", None: "info",
|
||||
"ok": "ok", "success": "ok", "pass": "ok",
|
||||
"warn": "warn", "warning": "warn",
|
||||
"err": "err", "error": "err", "fatal": "err", "fail": "err",
|
||||
"info": "info", "information": "info"
|
||||
}
|
||||
|
||||
_LEVEL_COLOR = {
|
||||
"ok": _C.green,
|
||||
"warn": _C.yellow,
|
||||
"err": _C.red,
|
||||
"info": _C.cyan,
|
||||
}
|
||||
|
||||
def _norm_level(level) -> str:
|
||||
# 默认 info
|
||||
if level is None:
|
||||
return "info"
|
||||
# 数字等级兼容(类似 logging)
|
||||
if isinstance(level, (int, float)):
|
||||
if level >= 40: return "err"
|
||||
if level >= 30: return "warn"
|
||||
return "info"
|
||||
# 字符串别名
|
||||
if isinstance(level, str):
|
||||
key = level.strip().lower()
|
||||
return _LEVEL_ALIAS.get(key, "info")
|
||||
return "info"
|
||||
|
||||
def dbg(msg, level=None):
|
||||
lvl = _norm_level(level) # 未指定/非法 -> "info"
|
||||
host = socket.gethostname()
|
||||
rank = os.environ.get("RANK", "0")
|
||||
lrank = os.environ.get("LOCAL_RANK", "-1")
|
||||
prefix = f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "
|
||||
color = _LEVEL_COLOR.get(lvl, _C.cyan)
|
||||
print(_paint(prefix, _C.gray) + _paint(str(msg), color), flush=True)
|
||||
|
||||
# 便捷别名(可选)
|
||||
def dbg_ok(m): dbg(m, "ok")
|
||||
def dbg_warn(m): dbg(m, "warn")
|
||||
def dbg_err(m):
|
||||
s = _paint(f"[dbg]{m}", _C.red)
|
||||
print(s, flush=True, file=sys.stderr)
|
||||
|
||||
|
||||
|
||||
# def dbg(msg):
|
||||
# print(
|
||||
# f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
|
||||
# f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
|
||||
# flush=True
|
||||
# )
|
||||
|
||||
# 是否真的启用 DeepSpeed(传了配置文件且文件存在)
|
||||
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
|
||||
|
|
@ -682,9 +705,9 @@ def main():
|
|||
)
|
||||
|
||||
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
|
||||
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
|
||||
f"use_cache={getattr(model.config,'use_cache',None)} "
|
||||
f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
|
||||
# dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
|
||||
# f"use_cache={getattr(model.config,'use_cache',None)} "
|
||||
# f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
|
||||
|
||||
# 3) pad/alibi 等配置
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
|
@ -703,6 +726,16 @@ def main():
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# ✅ 放在这里打印“修正后”的值
|
||||
dbg(f"post-config: use_cache={model.config.use_cache} "
|
||||
f"model.pad_token_id={model.config.pad_token_id} "
|
||||
f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
|
||||
f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
|
||||
|
||||
assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
|
||||
assert model.config.pad_token_id == tokenizer.pad_token_id, \
|
||||
f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
|
||||
|
||||
# ===== 数据鲁棒性检查(多机各自执行)=====
|
||||
files = sorted(glob.glob(args.data_glob))
|
||||
if len(files) == 0:
|
||||
|
|
|
|||
Loading…
Reference in New Issue