This commit is contained in:
parent
f273231200
commit
ea93b3c067
|
|
@ -63,11 +63,62 @@ class CsvLossLogger(TrainerCallback):
|
|||
with open(self.csv_path, "w", encoding="utf-8") as f:
|
||||
f.write("step,loss,lr,total_flos\n")
|
||||
|
||||
# def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
# if not is_main_process() or logs is None:
|
||||
# return
|
||||
# with open(self.csv_path, "a", encoding="utf-8") as f:
|
||||
# f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
|
||||
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
    """Print the planned total number of training steps when training starts.

    Every process prints one line tagged with its hostname and the value of
    the ``RANK`` environment variable (``?`` when unset).

    Args:
        args: Training arguments; only ``max_steps`` is read, if present.
        state: Trainer state; only ``max_steps`` is read, if present.
        control: Unused.
        **kwargs: Unused.
    """
    # Take the first *positive* integer rather than the first truthy value:
    # a negative placeholder (e.g. -1) on ``state`` must not shadow a valid
    # ``args.max_steps``. Falls back to 0 when neither is usable.
    candidates = (getattr(state, "max_steps", 0), getattr(args, "max_steps", 0))
    tot = next((v for v in candidates if isinstance(v, int) and v > 0), 0)

    rank = os.environ.get("RANK", "?")
    host = socket.gethostname()
    print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
|
||||
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
    """Print training progress on every rank and append one CSV row on the main process.

    Args:
        args: Training arguments; only ``max_steps`` is read, if present.
        state: Trainer state; ``global_step`` and ``max_steps`` are read, if present.
        control: Unused.
        logs: Mapping of logged values; ``loss``, ``learning_rate`` and
            ``total_flos`` are looked up. Nothing happens when it is ``None``.
        **kwargs: Unused.
    """
    if logs is None:
        return

    # ---- Console output: every rank prints current step / total steps ----
    cur = int(getattr(state, "global_step", 0) or 0)

    # Take the first *positive* integer rather than the first truthy value,
    # so a negative placeholder (e.g. -1) cannot shadow a usable total.
    candidates = (getattr(state, "max_steps", 0), getattr(args, "max_steps", 0))
    tot = next((v for v in candidates if isinstance(v, int) and v > 0), 0)
    pct = f"{(cur / tot * 100):.1f}%" if tot else "n/a"

    rank = os.environ.get("RANK", "?")
    host = socket.gethostname()

    # As soon as the total is known, announce it once more (printed only once
    # per callback instance, tracked via the ``_tot_announced`` attribute).
    if tot and not hasattr(self, "_tot_announced"):
        print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
        self._tot_announced = True

    print(
        f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
        f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
        flush=True,
    )

    # ---- Only the main process writes the CSV, to avoid concurrent writes ----
    if not is_main_process():
        return
    with open(self.csv_path, "a", encoding="utf-8") as f:
        # Exactly one row per log event.
        f.write(
            f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
        )
|
||||
|
||||
|
||||
|
||||
|
||||
# ----------------- Dataset that supervises only the assistant -----------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue