This commit is contained in:
parent
dc82fcfab8
commit
3232d61c71
151
train_sft_ds.py
151
train_sft_ds.py
|
|
@ -179,56 +179,6 @@ class CsvLossLogger(TrainerCallback):
|
||||||
|
|
||||||
from typing import List, Tuple, Iterable, Iterator, Dict
|
from typing import List, Tuple, Iterable, Iterator, Dict
|
||||||
|
|
||||||
# ----------------- 工具:提取 assistant 字符区间 -----------------
|
|
||||||
# def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
|
||||||
# spans: List[Tuple[int, int]] = []
|
|
||||||
# open_tag = "<|im_start|>assistant\n"
|
|
||||||
# close_token = "<|im_end|>"
|
|
||||||
# close_tag = close_token + "\n" # 常见模板带换行
|
|
||||||
|
|
||||||
# pos = 0
|
|
||||||
# while True:
|
|
||||||
# a = rendered.find(open_tag, pos)
|
|
||||||
# if a == -1:
|
|
||||||
# break
|
|
||||||
# start = a + len(open_tag)
|
|
||||||
|
|
||||||
# # 先找含换行版本,找不到再退化找不带换行的
|
|
||||||
# b = rendered.find(close_tag, start)
|
|
||||||
# if b == -1:
|
|
||||||
# b = rendered.find(close_token, start)
|
|
||||||
# if b == -1:
|
|
||||||
# break
|
|
||||||
|
|
||||||
# end = b + len(close_token) # 把 <|im_end|> 本体纳入监督
|
|
||||||
# spans.append((start, end))
|
|
||||||
|
|
||||||
# # pos 跳过这一轮结束标记(带换行就多跳一格)
|
|
||||||
# pos = b + (len(close_tag) if rendered.startswith(close_tag, b) else len(close_token))
|
|
||||||
# return spans
|
|
||||||
|
|
||||||
# def _think_char_spans(rendered: str) -> List[Tuple[int, int]]:
|
|
||||||
# """
|
|
||||||
# 返回需要忽略监督的区间(仅 <think>...</think> 的“内部”),
|
|
||||||
# 标签本身 <think> 与 </think> 仍参与监督,以便模型学会闭合。
|
|
||||||
# """
|
|
||||||
# spans: List[Tuple[int, int]] = []
|
|
||||||
# open_tag = "<think>"
|
|
||||||
# close_tag = "</think>"
|
|
||||||
# pos = 0
|
|
||||||
# while True:
|
|
||||||
# a = rendered.find(open_tag, pos)
|
|
||||||
# if a == -1:
|
|
||||||
# break
|
|
||||||
# b = rendered.find(close_tag, a + len(open_tag))
|
|
||||||
# if b == -1:
|
|
||||||
# break
|
|
||||||
# # 只忽略内部,不忽略两侧标签
|
|
||||||
# spans.append((a + len(open_tag), b))
|
|
||||||
# pos = b + len(close_tag)
|
|
||||||
# return spans
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------- 仅监督 assistant 内容(token-id 级,不用 offsets) -----------------
|
# ----------------- 仅监督 assistant 内容(token-id 级,不用 offsets) -----------------
|
||||||
class QwenChatSFTDataset(IterableDataset):
|
class QwenChatSFTDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
|
|
@ -510,12 +460,85 @@ def main():
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
host = socket.gethostname()
|
host = socket.gethostname()
|
||||||
def dbg(msg):
|
|
||||||
print(
|
|
||||||
f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
|
# ==== colored dbg (robust default to info) ====
|
||||||
f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
|
try:
|
||||||
flush=True
|
import colorama
|
||||||
)
|
colorama.just_fix_windows_console()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _use_color() -> bool:
|
||||||
|
if os.environ.get("NO_COLOR"): return False
|
||||||
|
if os.environ.get("FORCE_COLOR"): return True
|
||||||
|
return sys.stdout.isatty()
|
||||||
|
|
||||||
|
class _C:
|
||||||
|
reset = "\033[0m"
|
||||||
|
gray = "\033[90m"
|
||||||
|
green = "\033[32m"
|
||||||
|
yellow = "\033[33m"
|
||||||
|
red = "\033[31m"
|
||||||
|
cyan = "\033[36m"
|
||||||
|
|
||||||
|
def _paint(s, color):
|
||||||
|
return f"{color}{s}{_C.reset}" if _use_color() else s
|
||||||
|
|
||||||
|
_LEVEL_ALIAS = {
|
||||||
|
"": "info", None: "info",
|
||||||
|
"ok": "ok", "success": "ok", "pass": "ok",
|
||||||
|
"warn": "warn", "warning": "warn",
|
||||||
|
"err": "err", "error": "err", "fatal": "err", "fail": "err",
|
||||||
|
"info": "info", "information": "info"
|
||||||
|
}
|
||||||
|
|
||||||
|
_LEVEL_COLOR = {
|
||||||
|
"ok": _C.green,
|
||||||
|
"warn": _C.yellow,
|
||||||
|
"err": _C.red,
|
||||||
|
"info": _C.cyan,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _norm_level(level) -> str:
|
||||||
|
# 默认 info
|
||||||
|
if level is None:
|
||||||
|
return "info"
|
||||||
|
# 数字等级兼容(类似 logging)
|
||||||
|
if isinstance(level, (int, float)):
|
||||||
|
if level >= 40: return "err"
|
||||||
|
if level >= 30: return "warn"
|
||||||
|
return "info"
|
||||||
|
# 字符串别名
|
||||||
|
if isinstance(level, str):
|
||||||
|
key = level.strip().lower()
|
||||||
|
return _LEVEL_ALIAS.get(key, "info")
|
||||||
|
return "info"
|
||||||
|
|
||||||
|
def dbg(msg, level=None):
|
||||||
|
lvl = _norm_level(level) # 未指定/非法 -> "info"
|
||||||
|
host = socket.gethostname()
|
||||||
|
rank = os.environ.get("RANK", "0")
|
||||||
|
lrank = os.environ.get("LOCAL_RANK", "-1")
|
||||||
|
prefix = f"[dbg][host={host} RANK={rank} LOCAL_RANK={lrank}] "
|
||||||
|
color = _LEVEL_COLOR.get(lvl, _C.cyan)
|
||||||
|
print(_paint(prefix, _C.gray) + _paint(str(msg), color), flush=True)
|
||||||
|
|
||||||
|
# 便捷别名(可选)
|
||||||
|
def dbg_ok(m): dbg(m, "ok")
|
||||||
|
def dbg_warn(m): dbg(m, "warn")
|
||||||
|
def dbg_err(m):
|
||||||
|
s = _paint(f"[dbg]{m}", _C.red)
|
||||||
|
print(s, flush=True, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def dbg(msg):
|
||||||
|
# print(
|
||||||
|
# f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
|
||||||
|
# f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
|
||||||
|
# flush=True
|
||||||
|
# )
|
||||||
|
|
||||||
# 是否真的启用 DeepSpeed(传了配置文件且文件存在)
|
# 是否真的启用 DeepSpeed(传了配置文件且文件存在)
|
||||||
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
|
use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
|
||||||
|
|
@ -682,9 +705,9 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
|
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
|
||||||
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
|
# dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
|
||||||
f"use_cache={getattr(model.config,'use_cache',None)} "
|
# f"use_cache={getattr(model.config,'use_cache',None)} "
|
||||||
f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
|
# f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
|
||||||
|
|
||||||
# 3) pad/alibi 等配置
|
# 3) pad/alibi 等配置
|
||||||
model.config.pad_token_id = tokenizer.pad_token_id
|
model.config.pad_token_id = tokenizer.pad_token_id
|
||||||
|
|
@ -703,6 +726,16 @@ def main():
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# ✅ 放在这里打印“修正后”的值
|
||||||
|
dbg(f"post-config: use_cache={model.config.use_cache} "
|
||||||
|
f"model.pad_token_id={model.config.pad_token_id} "
|
||||||
|
f"gen.pad_token_id={getattr(getattr(model,'generation_config',None),'pad_token_id',None)} "
|
||||||
|
f"tok.pad={tokenizer.pad_token}/{tokenizer.pad_token_id}")
|
||||||
|
|
||||||
|
assert tokenizer.pad_token_id is not None, "tokenizer.pad_token_id must not be None"
|
||||||
|
assert model.config.pad_token_id == tokenizer.pad_token_id, \
|
||||||
|
f"model.pad_token_id {model.config.pad_token_id} != tokenizer.pad_token_id {tokenizer.pad_token_id}"
|
||||||
|
|
||||||
# ===== 数据鲁棒性检查(多机各自执行)=====
|
# ===== 数据鲁棒性检查(多机各自执行)=====
|
||||||
files = sorted(glob.glob(args.data_glob))
|
files = sorted(glob.glob(args.data_glob))
|
||||||
if len(files) == 0:
|
if len(files) == 0:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue