This commit is contained in:
parent
f273231200
commit
ea93b3c067
|
|
@ -63,11 +63,62 @@ class CsvLossLogger(TrainerCallback):
|
||||||
with open(self.csv_path, "w", encoding="utf-8") as f:
|
with open(self.csv_path, "w", encoding="utf-8") as f:
|
||||||
f.write("step,loss,lr,total_flos\n")
|
f.write("step,loss,lr,total_flos\n")
|
||||||
|
|
||||||
|
# def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
# if not is_main_process() or logs is None:
|
||||||
|
# return
|
||||||
|
# with open(self.csv_path, "a", encoding="utf-8") as f:
|
||||||
|
# f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
|
||||||
|
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
|
||||||
|
|
||||||
|
rank = os.environ.get("RANK", "?")
|
||||||
|
host = socket.gethostname()
|
||||||
|
print(f"[{host} rank={rank}] total_steps={tot}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
if not is_main_process() or logs is None:
|
if logs is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# ---- 控制台打印:所有 rank 都打当前步/总步 ----
|
||||||
|
cur = int(getattr(state, "global_step", 0) or 0)
|
||||||
|
|
||||||
|
|
||||||
|
# if getattr(args, "logging_steps", None) and cur % args.logging_steps != 0:
|
||||||
|
# return
|
||||||
|
|
||||||
|
tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
|
||||||
|
tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
|
||||||
|
pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")
|
||||||
|
|
||||||
|
# —— tot 一旦可用,就再宣布一次总步数(只打印一次)
|
||||||
|
if tot and not hasattr(self, "_tot_announced"):
|
||||||
|
print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
|
||||||
|
self._tot_announced = True
|
||||||
|
|
||||||
|
# if not is_main_process():
|
||||||
|
# return
|
||||||
|
|
||||||
|
rank = os.environ.get("RANK", "?")
|
||||||
|
host = socket.gethostname()
|
||||||
|
print(
|
||||||
|
f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
|
||||||
|
f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
|
||||||
|
flush=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- 只在主进程写 CSV,避免并发写 ----
|
||||||
|
if not is_main_process():
|
||||||
return
|
return
|
||||||
with open(self.csv_path, "a", encoding="utf-8") as f:
|
with open(self.csv_path, "a", encoding="utf-8") as f:
|
||||||
f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
|
f.write(
|
||||||
|
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------- 仅监督 assistant 的数据集 -----------------
|
# ----------------- 仅监督 assistant 的数据集 -----------------
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue