hailin 2025-08-25 13:21:30 +08:00
parent 3b14067454
commit 061e51f975
2 changed files with 86 additions and 46 deletions

View File

@@ -1,11 +1,9 @@
#!/usr/bin/env bash
set -euo pipefail
# Optional: debug output
# [ "${DEBUG:-0}" = "1" ] && set -x
export NCCL_DEBUG=INFO
# If going over IB/RoCE, enable according to your actual NICs
# If going over IB/RoCE, enable according to your actual NICs (example)
# export NCCL_IB_HCA="mlx5_0,mlx5_1"
# export NCCL_SOCKET_IFNAME="ib0"
# Pure Ethernet:
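For the pure-Ethernet case above, NCCL is typically pinned to the data-plane interface; a minimal sketch (the interface name eth0 is an assumption; check ip link on your nodes):

# export NCCL_SOCKET_IFNAME="eth0"   # bind NCCL sockets to the data NIC
# export NCCL_IB_DISABLE=1           # force TCP sockets when no IB/RoCE is present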
@@ -15,10 +13,14 @@ export NCCL_DEBUG=INFO
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DS_CONFIG="${DS_CONFIG:-$SCRIPT_DIR/ds_config_zero3.json}"
# ==== Hyperparameters (local paths; override with VAR=xxx ./launch_ds.sh) ====
# ==== Hyperparameters (local paths; override with VAR=xxx ./run_ds.sh) ====
MODEL_NAME_OR_PATH="${MODEL_NAME_OR_PATH:-/home/test/Qwen3-8B}"
DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/*.jsonl}" # same path on every machine
OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}" # each machine writes its own local output
# Explicitly separate train/eval files (change to globs if needed)
DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/train.jsonl}"
EVAL_DATA_GLOB="${EVAL_DATA_GLOB:-$HOME/datasets/my_corpus/test.jsonl}"
OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"
SEQ_LEN="${SEQ_LEN:-4096}"
LR="${LR:-2e-5}"
GAS="${GAS:-64}"
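As the hyperparameter header notes, any of these can be overridden per run without editing the script; a usage sketch (paths and values are placeholders):

MODEL_NAME_OR_PATH=/data/models/Qwen3-8B SEQ_LEN=2048 MAX_STEPS=2000 ./run_ds.sh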
@@ -26,33 +28,50 @@ LOG_STEPS="${LOG_STEPS:-10}"
SAVE_STEPS="${SAVE_STEPS:-500}"
MAX_STEPS="${MAX_STEPS:-10000}"
# Lightweight checks (launcher node only; each remote rank also mkdirs in its own script)
# Lightweight checks (launcher node only; each rank also mkdirs in the script)
[ -d "$MODEL_NAME_OR_PATH" ] || { echo "ERR: model not found at $MODEL_NAME_OR_PATH"; exit 1; }
[ -f "$DS_CONFIG" ] || { echo "ERR: deepspeed config not found at $DS_CONFIG"; exit 1; }
# The data path is a glob; check only once here (launcher node) that at least one file matches; each node must ensure the same path exists
# Check that the training glob matches something
shopt -s nullglob
matches=( $DATA_GLOB )
if [ ${#matches[@]} -eq 0 ]; then
train_matches=( $DATA_GLOB )
if [ ${#train_matches[@]} -eq 0 ]; then
echo "WARN: no files matched by DATA_GLOB=$DATA_GLOB on this node (确保每台机器该路径下有数据)"
fi
# Check the eval set (may be empty; if empty, --eval_data_glob is not passed)
eval_matches=()
if [ -n "${EVAL_DATA_GLOB:-}" ]; then
eval_matches=( $EVAL_DATA_GLOB )
if [ ${#eval_matches[@]} -eq 0 ]; then
echo "WARN: no files matched by EVAL_DATA_GLOB=$EVAL_DATA_GLOB on this node (将不进行评测)"
fi
fi
shopt -u nullglob
mkdir -p "${OUTDIR}"
# ==== Multi-node DeepSpeed ====
deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" \
--model_name_or_path "${MODEL_NAME_OR_PATH}" \
--data_glob "${DATA_GLOB}" \
--output_dir "${OUTDIR}" \
--seq_len "${SEQ_LEN}" \
--learning_rate "${LR}" \
--gradient_accumulation_steps "${GAS}" \
--per_device_train_batch_size 1 \
--warmup_ratio 0.02 \
--weight_decay 0.1 \
--max_steps "${MAX_STEPS}" \
--log_interval "${LOG_STEPS}" \
--save_steps "${SAVE_STEPS}" \
--deepspeed "${DS_CONFIG}" \
--gradient_checkpointing \
# Assemble arguments (append eval flags only when eval files matched)
args=(
--model_name_or_path "${MODEL_NAME_OR_PATH}"
--data_glob "${DATA_GLOB}"
--output_dir "${OUTDIR}"
--seq_len "${SEQ_LEN}"
--learning_rate "${LR}"
--gradient_accumulation_steps "${GAS}"
--per_device_train_batch_size 1
--warmup_ratio 0.02
--weight_decay 0.1
--max_steps "${MAX_STEPS}"
--log_interval "${LOG_STEPS}"
--save_steps "${SAVE_STEPS}"
--deepspeed "${DS_CONFIG}"
--gradient_checkpointing
--bf16
)
if [ ${#eval_matches[@]} -gt 0 ]; then
args+=( --eval_data_glob "${EVAL_DATA_GLOB}" )
fi
# ==== Multi-node DeepSpeed ====
deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" "${args[@]}"
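The launch relies on a DeepSpeed hostfile next to the script; DeepSpeed's hostfile format is one host per line with a slots=N GPU count. A minimal sketch (hostnames and the 8-GPU slot count are assumptions):

worker-1 slots=8
worker-2 slots=8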

View File

@@ -6,7 +6,7 @@ import argparse
from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
from torch.utils.data import IterableDataset
from torch.utils.data import IterableDataset, Dataset
from datasets import load_dataset
from transformers import (
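For context on the new Dataset import: HF Trainer evaluates against a map-style dataset (it needs __len__ and indexing), while training here streams through an IterableDataset. A minimal sketch of the distinction (class names are illustrative):

from torch.utils.data import Dataset, IterableDataset

class StreamingTrain(IterableDataset):
    # consumed lazily; no __len__, so an eval loop cannot be sized over it
    def __iter__(self):
        yield from ({"x": i} for i in range(3))

class MaterializedEval(Dataset):
    # map-style: supports len() and random access, which evaluation relies on
    def __init__(self, items): self.items = items
    def __len__(self): return len(self.items)
    def __getitem__(self, i): return self.items[i]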
@@ -203,7 +203,7 @@ def parse_args():
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0) # 如需 eval请准备 messages/工具同格式的数据
ap.add_argument("--eval_ratio", type=float, default=0.0) # 兜底抽样评估
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true")
@@ -214,6 +214,8 @@ def parse_args():
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
ap.add_argument("--eval_data_glob", type=str, default=None,
help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用")
return ap.parse_args()
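Files matched by --eval_data_glob are expected to carry the same chat-format records as the training data; a hedged one-line jsonl example (the messages field is implied by the error text further down; the role names follow the common chat convention and are an assumption here):

{"messages": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4."}]}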
@@ -283,16 +285,44 @@ def main():
yield ex
train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
# Optional eval: if you prepared eval data in the same messages/template format, a separate glob is recommended; here we keep the "head sampling" close to your original logic
eval_dataset = None
if args.eval_ratio and args.eval_ratio > 0:
# Take a handful of samples as eval (note: under streaming this is only a rough estimate)
# ---- Eval construction: prefer --eval_data_glob; otherwise fall back to eval_ratio sampling ----
eval_dataset: Optional[Dataset] = None
class ListDataset(Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
if args.eval_data_glob:
eval_files = sorted(glob.glob(args.eval_data_glob))
if len(eval_files) == 0:
raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
if is_main_process():
print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
def ex_iter_eval():
for ex in ds_eval_stream:
yield ex
eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_items: List[Dict[str, torch.Tensor]] = []
for sample in eval_iterable:
eval_items.append(sample)
if len(eval_items) == 0:
raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。")
eval_dataset = ListDataset(eval_items)
elif args.eval_ratio and args.eval_ratio > 0:
# Simple head sampling (only a rough evaluation under streaming)
desired_eval_batches = 200
tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter_eval():
def ex_iter_eval2():
for ex in tmp_stream:
yield ex
eval_stream = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
eval_samples = []
it = iter(eval_stream)
for _ in range(desired_eval_batches):
@@ -300,23 +330,14 @@ def main():
eval_samples.append(next(it))
except StopIteration:
break
class ListDataset(torch.utils.data.Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
eval_dataset = ListDataset(eval_samples)
# Rebuild the training stream
ds_stream3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def ex_iter3():
for ex in ds_stream3:
yield ex
train_stream = QwenChatSFTDataset(ex_iter3(), tokenizer, seq_len=args.seq_len)
if len(eval_samples) > 0:
eval_dataset = ListDataset(eval_samples)
data_collator = SFTDataCollator(tokenizer)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
os.makedirs(logging_dir, exist_ok=True)
training_args = TrainingArguments(
output_dir=args.output_dir,