jd_train/run_ds.sh

59 lines
2.1 KiB
Bash
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env bash
# Launches train_sft_ds.py through the DeepSpeed launcher using a hostfile that
# sits next to this script. Settings that read the environment with a default
# (e.g. SEQ_LEN, GAS, NCCL_DEBUG) can be overridden per run: VAR=xxx ./script
set -euo pipefail
# Optional: command tracing.
# [ "${DEBUG:-0}" = "1" ] && set -x
# NCCL log verbosity: defaults to INFO, but a value supplied by the caller
# (e.g. NCCL_DEBUG=WARN) is respected instead of being overwritten.
export NCCL_DEBUG="${NCCL_DEBUG:-INFO}"
# For InfiniBand / RoCE, enable these according to the actual NICs:
# export NCCL_IB_HCA="mlx5_0,mlx5_1"
# export NCCL_SOCKET_IFNAME="ib0"
# Plain Ethernet:
# export NCCL_SOCKET_IFNAME="eth0"
# Resolve this script's directory so companion files are found no matter what
# the caller's working directory is. CDPATH is cleared for this one cd: when
# CDPATH is set, `cd` on a relative path can print the resolved directory to
# stdout, which would be captured into SCRIPT_DIR together with pwd's output.
SCRIPT_DIR="$(CDPATH='' cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
# DeepSpeed config; defaults to the ZeRO-3 config next to this script.
DS_CONFIG="${DS_CONFIG:-$SCRIPT_DIR/ds_config_zero3.json}"
# ==== Hyperparameters (local paths; override per run with VAR=xxx before the script name) ====
MODEL_NAME_OR_PATH="${MODEL_NAME_OR_PATH:-/home/test/Qwen3-8B}"
DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/*.jsonl}" # must exist at the same path on every node
OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}" # each node writes to its own local copy of this path
SEQ_LEN="${SEQ_LEN:-4096}"
LR="${LR:-2e-5}"
GAS="${GAS:-64}" # gradient accumulation steps
LOG_STEPS="${LOG_STEPS:-10}"
SAVE_STEPS="${SAVE_STEPS:-500}"
MAX_STEPS="${MAX_STEPS:-10000}"
# ==== Lightweight sanity checks (launching node only) ====
# Nothing here runs on the other nodes. OUTDIR is created below for this node
# only; other nodes presumably create their own local OUTDIR from the training
# script — TODO confirm in train_sft_ds.py.
[[ -d "$MODEL_NAME_OR_PATH" ]] || { echo "ERR: model not found at $MODEL_NAME_OR_PATH" >&2; exit 1; }
[[ -f "$DS_CONFIG" ]] || { echo "ERR: deepspeed config not found at $DS_CONFIG" >&2; exit 1; }
# DATA_GLOB is a wildcard pattern: check once, on this node, that it matches at
# least one existing path. Every other node must provide the same path itself.
# compgen -G expands the pattern as a single word, so spaces in the path are not
# split apart, and a literal (non-glob) path that does not exist counts as "no
# match" instead of being passed through unchanged.
if ! compgen -G "$DATA_GLOB" > /dev/null; then
  echo "WARN: no files matched by DATA_GLOB=$DATA_GLOB on this node (确保每台机器该路径下有数据)" >&2
fi
mkdir -p -- "${OUTDIR}"
# ==== Multi-node DeepSpeed launch ====
# Launch-only knobs. The defaults are the values this script passes to the
# trainer; each can be overridden from the environment like the settings above.
HOSTFILE="${HOSTFILE:-$SCRIPT_DIR/hostfile}"
PER_DEVICE_BS="${PER_DEVICE_BS:-1}"
WARMUP_RATIO="${WARMUP_RATIO:-0.02}"
WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}"
# The DeepSpeed launcher only logs a warning and falls back to local resources
# when the hostfile is missing. Warn here as well so a multi-node run does not
# quietly turn into a single-node one.
if [[ ! -f "$HOSTFILE" ]]; then
  echo "WARN: hostfile not found at $HOSTFILE; deepspeed will likely use this node only" >&2
fi
# Trainer arguments, collected in an array so every value stays one word.
train_args=(
  --model_name_or_path "${MODEL_NAME_OR_PATH}"
  --data_glob "${DATA_GLOB}"
  --output_dir "${OUTDIR}"
  --seq_len "${SEQ_LEN}"
  --learning_rate "${LR}"
  --gradient_accumulation_steps "${GAS}"
  --per_device_train_batch_size "${PER_DEVICE_BS}"
  --warmup_ratio "${WARMUP_RATIO}"
  --weight_decay "${WEIGHT_DECAY}"
  --max_steps "${MAX_STEPS}"
  --log_interval "${LOG_STEPS}"
  --save_steps "${SAVE_STEPS}"
  --deepspeed "${DS_CONFIG}"
  --gradient_checkpointing
  --bf16
)
deepspeed --hostfile "$HOSTFILE" "$SCRIPT_DIR/train_sft_ds.py" "${train_args[@]}"