diff --git a/train_sft_ds.py b/train_sft_ds.py
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -63,11 +63,52 @@ class CsvLossLogger(TrainerCallback):
             with open(self.csv_path, "w", encoding="utf-8") as f:
                 f.write("step,loss,lr,total_flos\n")
 
+    @staticmethod
+    def _total_steps(args, state):
+        """Return the planned total number of training steps, or 0 if unknown."""
+        steps = getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0
+        return steps if isinstance(steps, int) and steps > 0 else 0
+
+    @staticmethod
+    def _rank_tag():
+        """Return the ``[host rank=N]`` prefix put on every console line."""
+        return f"[{socket.gethostname()} rank={os.environ.get('RANK', '?')}]"
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        """Print the total step count from every rank when training starts."""
+        tot = self._total_steps(args, state)
+        print(f"{self._rank_tag()} total_steps={tot}", flush=True)
+        # A real (non-zero) total was shown, so on_log must not repeat it.
+        self._tot_announced = tot > 0
+
     def on_log(self, args, state, control, logs=None, **kwargs):
-        if not is_main_process() or logs is None:
+        """Print progress from every rank; append a CSV row on the main process."""
+        if logs is None:
+            return
+
+        cur = int(getattr(state, "global_step", 0) or 0)
+        tot = self._total_steps(args, state)
+        pct = f"{cur / tot * 100:.1f}%" if tot else "n/a"
+        tag = self._rank_tag()
+
+        # The total can be unknown at train begin; announce it once it appears.
+        if tot and not getattr(self, "_tot_announced", False):
+            print(f"{tag} total_steps={tot}", flush=True)
+            self._tot_announced = True
+
+        print(
+            f"{tag} step {cur}/{tot} ({pct}) "
+            f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
+            flush=True,
+        )
+
+        # Only the main process writes the CSV, so ranks never append concurrently.
+        if not is_main_process():
             return
         with open(self.csv_path, "a", encoding="utf-8") as f:
-            f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
+            f.write(
+                f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
+            )
 
 
 # ----------------- 仅监督 assistant 的数据集 -----------------