This commit is contained in:
parent
b130f9e682
commit
f7b0eca381
|
|
@ -1,7 +1,9 @@
|
|||
WANDB_BASE_URL=https://wandb.szaiai.com
|
||||
WANDB_API_KEY=local-701636f51b4741d3862007df5cf7f12cca53d8d1
|
||||
WANDB_PROJECT=ds-qwen3
|
||||
WANDB_GROUP=q3-1.7b-ds4-2025-09-05
|
||||
WANDB_RUN_ID=q3-31.7b-lr2e-5-train2
|
||||
WANDB_ENTITY=hailin
|
||||
WANDB_GROUP=q3-32b-ds4-2025-09-05
|
||||
WANDB_NAME=q3-32b-lr2e-5-train2
|
||||
WANDB_RESUME=allow
|
||||
export WANDB_DIR=/tmp/$USER/wandb
|
||||
WANDB_INIT_TIMEOUT=300
|
||||
WANDB_DIR=/tmp/$USER/wandb
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ deepspeed --hostfile hostfile \
|
|||
--bf16 \
|
||||
--deepspeed /home/test/jd_train/ds_config_zero3.json \
|
||||
--report_to wandb \
|
||||
--wandb_project ds-qwen3 \
|
||||
--eval_steps 10 \
|
||||
--eval_data_glob "/home/test/datasets/my_corpus/test.jsonl"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,12 @@ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
|||
os.environ.setdefault("WANDB_START_METHOD", "thread")
|
||||
os.environ.setdefault("WANDB_DIR", f"/tmp/{os.environ.get('USER','user')}/wandb")
|
||||
|
||||
# ★ New: base_url of the self-hosted W&B service (avoids going to the default cloud)
|
||||
os.environ.setdefault("WANDB_BASE_URL", "https://wandb.szaiai.com")
|
||||
# (Optional) Some versions support this env var; what actually takes effect is still the Settings(init_timeout=...) below
|
||||
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")
|
||||
|
||||
|
||||
import glob
|
||||
import socket
|
||||
import argparse
|
||||
|
|
@ -1000,6 +1006,38 @@ def main():
|
|||
if args.report_to == "wandb":
|
||||
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
||||
|
||||
|
||||
# Pre-initialize W&B on rank 0 only
|
||||
is_rank0 = os.environ.get("RANK", "0") == "0" and os.environ.get("LOCAL_RANK", "-1") in ("0", "-1")
|
||||
if is_rank0:
|
||||
import wandb
|
||||
try:
|
||||
# Avoid a leftover external RUN_ID forcing a resume and causing a hang
|
||||
os.environ.pop("WANDB_RUN_ID", None)
|
||||
|
||||
# Optional fields injected from the environment (used when present)
|
||||
extra = {}
|
||||
if os.getenv("WANDB_NAME"): extra["name"] = os.getenv("WANDB_NAME")
|
||||
if os.getenv("WANDB_GROUP"): extra["group"] = os.getenv("WANDB_GROUP")
|
||||
if os.getenv("WANDB_RESUME"): extra["resume"] = os.getenv("WANDB_RESUME") # 建议 'allow'
|
||||
|
||||
run = wandb.init(
|
||||
project=args.wandb_project,
|
||||
entity=os.getenv("WANDB_ENTITY") or os.getenv("WB_ENTITY") or "hailin",
|
||||
settings=wandb.Settings(
|
||||
base_url=os.getenv("WANDB_BASE_URL", "https://wandb.szaiai.com"),
|
||||
init_timeout=int(os.getenv("WANDB_INIT_TIMEOUT", "300")),
|
||||
),
|
||||
**extra,
|
||||
)
|
||||
print(f"[wandb] run url: {getattr(run, 'url', '(n/a)')}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[wandb] init failed -> disable logging, reason={e}", flush=True)
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
args.report_to = "none"
|
||||
else:
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
# Versions, launch arguments, and key environment variables
|
||||
import transformers as hf
|
||||
try:
|
||||
|
|
@ -1304,6 +1342,8 @@ def main():
|
|||
ta_kwargs2 = dict(
|
||||
output_dir=args.output_dir,
|
||||
logging_dir=logging_dir,
|
||||
# ★ New: custom run_name, to avoid the warning about it being equal to output_dir
|
||||
run_name=f"sft-{os.path.basename(args.output_dir)}-{socket.gethostname()}",
|
||||
do_train=True,
|
||||
do_eval=(eval_dataset is not None),
|
||||
# eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue