From 061e51f975e6612de42f113acc82cf67ac1f2088 Mon Sep 17 00:00:00 2001
From: hailin
Date: Mon, 25 Aug 2025 13:21:30 +0800
Subject: [PATCH] Support a separate eval set via --eval_data_glob

---
 run_ds.sh       | 71 +++++++++++++++++++++++++++++++------------------
 train_sft_ds.py | 61 ++++++++++++++++++++++++++++--------------
 2 files changed, 86 insertions(+), 46 deletions(-)
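Note: the launcher below expects a DeepSpeed hostfile at "$SCRIPT_DIR/hostfile",
which this patch does not add. DeepSpeed hostfiles use MPI-style
"hostname slots=N" lines; a minimal sketch, assuming two nodes with 8 GPUs
each (hostnames and slot counts are placeholders for your cluster):

    # Create the hostfile next to run_ds.sh; one "hostname slots=N" line per node.
    printf '%s\n' 'node-01 slots=8' 'node-02 slots=8' > hostfile
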
--hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" "${args[@]}" diff --git a/train_sft_ds.py b/train_sft_ds.py index c6d5375..a3eaa76 100644 --- a/train_sft_ds.py +++ b/train_sft_ds.py @@ -6,7 +6,7 @@ import argparse from typing import Dict, List, Iterable, Iterator, Tuple, Optional import torch -from torch.utils.data import IterableDataset +from torch.utils.data import IterableDataset, Dataset from datasets import load_dataset from transformers import ( @@ -203,7 +203,7 @@ def parse_args(): ap.add_argument("--max_steps", type=int, default=-1) ap.add_argument("--log_interval", type=int, default=10) ap.add_argument("--save_steps", type=int, default=500) - ap.add_argument("--eval_ratio", type=float, default=0.0) # 如需 eval,请准备 messages/工具同格式的数据 + ap.add_argument("--eval_ratio", type=float, default=0.0) # 兜底抽样评估 ap.add_argument("--seed", type=int, default=1337) ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json") ap.add_argument("--gradient_checkpointing", action="store_true") @@ -214,6 +214,8 @@ def parse_args(): ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"]) ap.add_argument("--wandb_project", type=str, default="ds-qwen3") + ap.add_argument("--eval_data_glob", type=str, default=None, + help="(可选) 测试/验证集 jsonl 通配符;如提供则优先使用") return ap.parse_args() @@ -283,16 +285,44 @@ def main(): yield ex train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len) - # 可选 eval:如果你准备了 messages/同模板的 eval 数据,建议用单独 glob;这里维持与你原逻辑相近的“头部抽样” - eval_dataset = None - if args.eval_ratio and args.eval_ratio > 0: - # 简单抽若干样本作为 eval(注意:streaming 情况下这只是粗略评估) + # ---- Eval 构造:优先使用 --eval_data_glob;否则才用 eval_ratio 抽样 ---- + eval_dataset: Optional[Dataset] = None + + class ListDataset(Dataset): + def __init__(self, items): self.items = items + def __len__(self): return len(self.items) + def __getitem__(self, idx): return self.items[idx] + + if args.eval_data_glob: + eval_files = sorted(glob.glob(args.eval_data_glob)) + if len(eval_files) == 0: + raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}") + if is_main_process(): + print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True) + + ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True) + def ex_iter_eval(): + for ex in ds_eval_stream: + yield ex + + eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len) + eval_items: List[Dict[str, torch.Tensor]] = [] + for sample in eval_iterable: + eval_items.append(sample) + + if len(eval_items) == 0: + raise RuntimeError("[eval] eval_data_glob 读到了 0 条有效样本,请检查 messages 结构。") + + eval_dataset = ListDataset(eval_items) + + elif args.eval_ratio and args.eval_ratio > 0: + # 简易头部抽样(流式下仅作粗评) desired_eval_batches = 200 tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True) - def ex_iter_eval(): + def ex_iter_eval2(): for ex in tmp_stream: yield ex - eval_stream = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len) + eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len) eval_samples = [] it = iter(eval_stream) for _ in range(desired_eval_batches): @@ -300,23 +330,14 @@ def main(): eval_samples.append(next(it)) except StopIteration: break - class ListDataset(torch.utils.data.Dataset): - def __init__(self, items): self.items = items - def __len__(self): return len(self.items) - def 
diff --git a/train_sft_ds.py b/train_sft_ds.py
index c6d5375..a3eaa76 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -6,7 +6,7 @@
 import argparse
 from typing import Dict, List, Iterable, Iterator, Tuple, Optional

 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, Dataset
 from datasets import load_dataset
 from transformers import (
@@ -203,7 +203,7 @@ def parse_args():
     ap.add_argument("--max_steps", type=int, default=-1)
     ap.add_argument("--log_interval", type=int, default=10)
     ap.add_argument("--save_steps", type=int, default=500)
-    ap.add_argument("--eval_ratio", type=float, default=0.0)  # for eval, prepare data in the same messages/tools format
+    ap.add_argument("--eval_ratio", type=float, default=0.0)  # fallback: sampled evaluation
     ap.add_argument("--seed", type=int, default=1337)
     ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
     ap.add_argument("--gradient_checkpointing", action="store_true")
@@ -214,6 +214,8 @@ def parse_args():
     ap.add_argument("--report_to", type=str, default="tensorboard",
                     choices=["none","tensorboard","wandb"])
     ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
+    ap.add_argument("--eval_data_glob", type=str, default=None,
+                    help="(optional) glob for the test/validation jsonl; takes precedence when provided")
     return ap.parse_args()


@@ -283,16 +285,44 @@ def main():
             yield ex
     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

-    # Optional eval: if you prepared eval data in the same messages/template format, a separate glob is recommended; this keeps the "head sampling" close to your original logic
-    eval_dataset = None
-    if args.eval_ratio and args.eval_ratio > 0:
-        # Simply take a few samples for eval (note: under streaming this is only a rough evaluation)
+    # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
+    eval_dataset: Optional[Dataset] = None
+
+    class ListDataset(Dataset):
+        def __init__(self, items): self.items = items
+        def __len__(self): return len(self.items)
+        def __getitem__(self, idx): return self.items[idx]
+
+    if args.eval_data_glob:
+        eval_files = sorted(glob.glob(args.eval_data_glob))
+        if len(eval_files) == 0:
+            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
+        if is_main_process():
+            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
+
+        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
+        def ex_iter_eval():
+            for ex in ds_eval_stream:
+                yield ex
+
+        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_items: List[Dict[str, torch.Tensor]] = []
+        for sample in eval_iterable:
+            eval_items.append(sample)
+
+        if len(eval_items) == 0:
+            raise RuntimeError("[eval] eval_data_glob yielded 0 valid samples; check the messages structure.")
+
+        eval_dataset = ListDataset(eval_items)
+
+    elif args.eval_ratio and args.eval_ratio > 0:
+        # Simple head sampling (only a rough evaluation under streaming)
         desired_eval_batches = 200
         tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter_eval():
+        def ex_iter_eval2():
             for ex in tmp_stream:
                 yield ex
-        eval_stream = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
         eval_samples = []
         it = iter(eval_stream)
         for _ in range(desired_eval_batches):
@@ -300,23 +330,14 @@ def main():
             try:
                 eval_samples.append(next(it))
             except StopIteration:
                 break
-        class ListDataset(torch.utils.data.Dataset):
-            def __init__(self, items): self.items = items
-            def __len__(self): return len(self.items)
-            def __getitem__(self, idx): return self.items[idx]
-        eval_dataset = ListDataset(eval_samples)
-
-        # Rebuild the training stream
-        ds_stream3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter3():
-            for ex in ds_stream3:
-                yield ex
-        train_stream = QwenChatSFTDataset(ex_iter3(), tokenizer, seq_len=args.seq_len)
+        if len(eval_samples) > 0:
+            eval_dataset = ListDataset(eval_samples)

     data_collator = SFTDataCollator(tokenizer)

     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
+    os.makedirs(logging_dir, exist_ok=True)
     training_args = TrainingArguments(
         output_dir=args.output_dir,
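The patch never shows the record schema QwenChatSFTDataset consumes; the new
RuntimeError only hints at a top-level "messages" structure. As a sketch,
assuming the common chat format with role/content turns, one line of
train.jsonl might look like this (field values are placeholders, not the
project's confirmed schema):

    # Append one illustrative chat record to a scratch jsonl file.
    printf '%s\n' '{"messages": [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}]}' >> /tmp/sample.jsonl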