This commit is contained in:
hailin 2025-08-28 11:10:14 +08:00
parent f273231200
commit ea93b3c067
1 changed files with 53 additions and 2 deletions

View File

@ -63,11 +63,62 @@ class CsvLossLogger(TrainerCallback):
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
# def on_log(self, args, state, control, logs=None, **kwargs):
# if not is_main_process() or logs is None:
# return
# with open(self.csv_path, "a", encoding="utf-8") as f:
# f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
def on_train_begin(self, args, state, control, **kwargs):
    """Print the planned total number of training steps for this process.

    The total is taken from ``state.max_steps`` and, when that is missing
    or falsy, from ``args.max_steps``. Anything that is not a positive
    ``int`` is reported as ``0``.

    Args:
        args: Training arguments; only ``max_steps`` is read (as fallback).
        state: Trainer state; only ``max_steps`` is read.
        control: Unused.
        **kwargs: Ignored.
    """
    planned = 0
    # First truthy ``max_steps`` wins: state is consulted before args.
    for source in (state, args):
        candidate = getattr(source, "max_steps", 0)
        if candidate:
            planned = candidate
            break
    if not (isinstance(planned, int) and planned > 0):
        planned = 0

    rank_label = os.environ.get("RANK", "?")
    hostname = socket.gethostname()
    print(f"[{hostname} rank={rank_label}] total_steps={planned}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
    """Report progress on every rank and append one CSV row on the main process.

    Each call prints ``step cur/total (pct)`` together with the logged loss
    and learning rate, prefixed by hostname and the ``RANK`` environment
    variable. The first call in which a positive total step count is
    available also prints a one-off ``total_steps=...`` line. Only the main
    process then appends a ``step,loss,lr,total_flos`` row to
    ``self.csv_path``.

    Args:
        args: Training arguments; ``max_steps`` is the fallback total.
        state: Trainer state; ``global_step`` and ``max_steps`` are read.
        control: Unused.
        logs: Mapping of logged values; the keys ``loss``,
            ``learning_rate`` and ``total_flos`` are read. When ``None``
            the method returns without doing anything.
        **kwargs: Ignored.
    """
    if logs is None:
        return

    # ---- Console output: every rank prints current step / total steps ----
    cur = int(getattr(state, "global_step", 0) or 0)
    tmp = getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0
    tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
    pct = f"{(cur / tot * 100):.1f}%" if tot else "n/a"

    rank = os.environ.get("RANK", "?")
    host = socket.gethostname()

    # As soon as the total is known, announce it (printed only once per
    # callback instance, tracked through the ``_tot_announced`` attribute).
    if tot and not hasattr(self, "_tot_announced"):
        print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
        self._tot_announced = True

    print(
        f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
        f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
        flush=True
    )

    # ---- Only the main process writes the CSV, to avoid concurrent writes ----
    if not is_main_process():
        return
    # Exactly one row per log event (the row used to be written twice).
    with open(self.csv_path, "a", encoding="utf-8") as f:
        f.write(
            f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
        )
# ----------------- 仅监督 assistant 的数据集 -----------------