From f7b0eca3811ff0d986e0e06417d836c2007ca72c Mon Sep 17 00:00:00 2001 From: hailin Date: Tue, 9 Sep 2025 10:27:05 +0800 Subject: [PATCH] . --- .deepspeed_env | 8 +++++--- mm-zero3.sh | 1 + train_sft_ds.py | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/.deepspeed_env b/.deepspeed_env index 3a81de0..84b51fe 100644 --- a/.deepspeed_env +++ b/.deepspeed_env @@ -1,7 +1,9 @@ WANDB_BASE_URL=https://wandb.szaiai.com WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1 WANDB_PROJECT=ds-qwen3 -WANDB_GROUP=q3-1.7b-ds4-2025-09-05 -WANDB_RUN_ID=q3-31.7b-lr2e-5-train2 +WANDB_ENTITY=hailin +WANDB_GROUP=q3-32b-ds4-2025-09-05 +WANDB_NAME=q3-32b-lr2e-5-train2 WANDB_RESUME=allow -export WANDB_DIR=/tmp/$USER/wandb +WANDB_INIT_TIMEOUT=300 +WANDB_DIR=/tmp/$USER/wandb diff --git a/mm-zero3.sh b/mm-zero3.sh index 054bcca..7e47cc4 100755 --- a/mm-zero3.sh +++ b/mm-zero3.sh @@ -41,6 +41,7 @@ deepspeed --hostfile hostfile \ --bf16 \ --deepspeed /home/test/jd_train/ds_config_zero3.json \ --report_to wandb \ + --wandb_project ds-qwen3 \ --eval_steps 10 \ --eval_data_glob "/home/test/datasets/my_corpus/test.jsonl" diff --git a/train_sft_ds.py b/train_sft_ds.py index a4c07df..e3abcd7 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -9,6 +9,12 @@ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") os.environ.setdefault("WANDB_START_METHOD", "thread") os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb") +# Added: base_url of the self-hosted server (avoids going through the default cloud) +os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com") +# (Optional) some versions support this env var; what actually takes effect is still the Settings(init_timeout=...) passed below
+os.environ.setdefault("WANDB_INIT_TIMEOUT", "300") + + import glob import socket import argparse @@ -1000,6 +1006,38 @@ def main(): if args.report_to == "wandb": os.environ.setdefault("WANDB_PROJECT", args.wandb_project) + + # Pre-initialize W&B on rank 0 only + is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1") + if is_rank0: + import wandb + try: + # Drop any RUN_ID left over in the environment so a forced resume cannot make init hang + os.environ.pop("WANDB_RUN_ID", None) + + # Optional fields are injected from the environment (used only when set) + extra = {} + if os.getenv("WANDB_NAME"): extra["name"] = os.getenv("WANDB_NAME") + if os.getenv("WANDB_GROUP"): extra["group"] = os.getenv("WANDB_GROUP") + if os.getenv("WANDB_RESUME"): extra["resume"] = os.getenv("WANDB_RESUME") # 'allow' is recommended + + run = wandb.init( + project=args.wandb_project, + entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin", + settings=wandb.Settings( + base_url=os.getenv("WANDB_BASE_URL", "https://wandb.szaiai.com"), + init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT", "300")), + ), + **extra, + ) + print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True) + except Exception as e: # best-effort: any init failure turns W&B reporting off instead of aborting + print(f"[wandb] init failed -> disable logging, reason={e}", flush=True) + os.environ["WANDB_DISABLED"] = "true" + args.report_to = "none" # NOTE(review): appears to run on rank 0 only; other ranks keep report_to unchanged - confirm downstream code tolerates that + else: + os.environ["WANDB_DISABLED"] = "true" + # 版本 & 启动参数 & 关键环境变量 import transformers as hf try: @@ -1304,6 +1342,8 @@ def main(): ta_kwargs2 = dict( output_dir=args.output_dir, logging_dir=logging_dir, + # Added: custom run_name, avoids the warning about run_name being equal to output_dir + run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}", do_train=True, do_eval=(eval_dataset is not None), # eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,