This commit is contained in:
hailin 2025-09-09 10:27:05 +08:00
parent b130f9e682
commit f7b0eca381
3 changed files with 46 additions and 3 deletions

View File

@ -1,7 +1,9 @@
WANDB_BASE_URL=https://wandb.szaiai.com WANDB_BASE_URL=https://wandb.szaiai.com
WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1 WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
WANDB_PROJECT=ds-qwen3 WANDB_PROJECT=ds-qwen3
WANDB_GROUP=q3-1.7b-ds4-2025-09-05 WANDB_ENTITY=hailin
WANDB_RUN_ID=q3-31.7b-lr2e-5-train2 WANDB_GROUP=q3-32b-ds4-2025-09-05
WANDB_NAME=q3-32b-lr2e-5-train2
WANDB_RESUME=allow WANDB_RESUME=allow
export WANDB_DIR=/tmp/$USER/wandb WANDB_INIT_TIMEOUT=300
WANDB_DIR=/tmp/$USER/wandb

View File

@ -41,6 +41,7 @@ deepspeed --hostfile hostfile \
--bf16 \ --bf16 \
--deepspeed /home/test/jd_train/ds_config_zero3.json \ --deepspeed /home/test/jd_train/ds_config_zero3.json \
--report_to wandb \ --report_to wandb \
--wandb_project ds-qwen3 \
--eval_steps 10 \ --eval_steps 10 \
--eval_data_glob "/home/test/datasets/my_corpus/test.jsonl" --eval_data_glob "/home/test/datasets/my_corpus/test.jsonl"

View File

@ -9,6 +9,12 @@ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("WANDB_START_METHOD", "thread") os.environ.setdefault("WANDB_START_METHOD", "thread")
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb") os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
# ★ 新增:自建服务的 base_url避免走默认的 cloud
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
# (可选)某些版本支持这个 env真正生效仍以下面的 Settings(init_timeout=...) 为准
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")
import glob import glob
import socket import socket
import argparse import argparse
@ -1000,6 +1006,38 @@ def main():
if args.report_to == "wandb": if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project) os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
# 仅在 rank0 预初始化 W&B
is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
if is_rank0:
import wandb
try:
# 避免外部遗留的 RUN_ID 强制续跑导致卡住
os.environ.pop("WANDB_RUN_ID", None)
# 可选字段从环境注入(有就用)
extra = {}
if os.getenv("WANDB_NAME"): extra["name"] = os.getenv("WANDB_NAME")
if os.getenv("WANDB_GROUP"): extra["group"] = os.getenv("WANDB_GROUP")
if os.getenv("WANDB_RESUME"): extra["resume"] = os.getenv("WANDB_RESUME") # 建议 'allow'
run = wandb.init(
project=args.wandb_project,
entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
settings=wandb.Settings(
base_url=os.getenv("WANDB_BASE_URL", "https://wandb.szaiai.com"),
init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT", "300")),
),
**extra,
)
print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
except Exception as e:
print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
os.environ["WANDB_DISABLED"] = "true"
args.report_to = "none"
else:
os.environ["WANDB_DISABLED"] = "true"
# 版本 & 启动参数 & 关键环境变量 # 版本 & 启动参数 & 关键环境变量
import transformers as hf import transformers as hf
try: try:
@ -1304,6 +1342,8 @@ def main():
ta_kwargs2 = dict( ta_kwargs2 = dict(
output_dir=args.output_dir, output_dir=args.output_dir,
logging_dir=logging_dir, logging_dir=logging_dir,
# ★ 新增:自定义 run_name避免等于 output_dir 的 warning
run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
do_train=True, do_train=True,
do_eval=(eval_dataset is not None), do_eval=(eval_dataset is not None),
# eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, # eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,