hailin 2025-08-25 13:21:30 +08:00
parent 3b14067454
commit 061e51f975
2 changed files with 86 additions and 46 deletions

View File

@@ -1,11 +1,9 @@
 #!/usr/bin/env bash
 set -euo pipefail
-# Optional: debug output
 # [ "${DEBUG:-0}" = "1" ] && set -x
 export NCCL_DEBUG=INFO
-# If running over IB/RoCE, enable per your actual NICs
+# If running over IB/RoCE, enable per your actual NICs (example)
 # export NCCL_IB_HCA="mlx5_0,mlx5_1"
 # export NCCL_SOCKET_IFNAME="ib0"
 # Pure Ethernet:
@@ -15,10 +13,14 @@ export NCCL_DEBUG=INFO
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 DS_CONFIG="${DS_CONFIG:-$SCRIPT_DIR/ds_config_zero3.json}"
-# ==== Hyperparameters (local paths; override with VAR=xxx ./launch_ds.sh) ====
+# ==== Hyperparameters (local paths; override with VAR=xxx ./run_ds.sh) ====
 MODEL_NAME_OR_PATH="${MODEL_NAME_OR_PATH:-/home/test/Qwen3-8B}"
-DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/*.jsonl}"   # same path on every machine
-OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"           # each machine writes its own local output
+# Explicitly separate train/eval files (switch to globs as needed)
+DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/train.jsonl}"
+EVAL_DATA_GLOB="${EVAL_DATA_GLOB:-$HOME/datasets/my_corpus/test.jsonl}"
+OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"
 SEQ_LEN="${SEQ_LEN:-4096}"
 LR="${LR:-2e-5}"
 GAS="${GAS:-64}"
@@ -26,33 +28,50 @@ LOG_STEPS="${LOG_STEPS:-10}"
 SAVE_STEPS="${SAVE_STEPS:-500}"
 MAX_STEPS="${MAX_STEPS:-10000}"
-# Lightweight validation (launcher node only; remote ranks also mkdir in their own scripts)
+# Lightweight validation (launcher node only; each rank also mkdirs in the script)
 [ -d "$MODEL_NAME_OR_PATH" ] || { echo "ERR: model not found at $MODEL_NAME_OR_PATH"; exit 1; }
 [ -f "$DS_CONFIG" ] || { echo "ERR: deepspeed config not found at $DS_CONFIG"; exit 1; }
-# The data path is a glob; check once here (launcher node) that at least one file matches; each node must ensure the same path exists
+# Check that the training set matches at least one file
 shopt -s nullglob
-matches=( $DATA_GLOB )
-if [ ${#matches[@]} -eq 0 ]; then
+train_matches=( $DATA_GLOB )
+if [ ${#train_matches[@]} -eq 0 ]; then
   echo "WARN: no files matched by DATA_GLOB=$DATA_GLOB on this node (make sure the data exists at this path on every machine)"
 fi
+# Check the eval set (may be empty; if so, --eval_data_glob is not passed)
+eval_matches=()
+if [ -n "${EVAL_DATA_GLOB:-}" ]; then
+  eval_matches=( $EVAL_DATA_GLOB )
+  if [ ${#eval_matches[@]} -eq 0 ]; then
+    echo "WARN: no files matched by EVAL_DATA_GLOB=$EVAL_DATA_GLOB on this node (evaluation will be skipped)"
+  fi
+fi
 shopt -u nullglob
 mkdir -p "${OUTDIR}"
-# ==== Multi-node DeepSpeed ====
-deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" \
-  --model_name_or_path "${MODEL_NAME_OR_PATH}" \
-  --data_glob "${DATA_GLOB}" \
-  --output_dir "${OUTDIR}" \
-  --seq_len "${SEQ_LEN}" \
-  --learning_rate "${LR}" \
-  --gradient_accumulation_steps "${GAS}" \
-  --per_device_train_batch_size 1 \
-  --warmup_ratio 0.02 \
-  --weight_decay 0.1 \
-  --max_steps "${MAX_STEPS}" \
-  --log_interval "${LOG_STEPS}" \
-  --save_steps "${SAVE_STEPS}" \
-  --deepspeed "${DS_CONFIG}" \
-  --gradient_checkpointing \
-  --bf16
+# Assemble arguments (eval is appended only when it matched files)
+args=(
+  --model_name_or_path "${MODEL_NAME_OR_PATH}"
+  --data_glob "${DATA_GLOB}"
+  --output_dir "${OUTDIR}"
+  --seq_len "${SEQ_LEN}"
+  --learning_rate "${LR}"
+  --gradient_accumulation_steps "${GAS}"
+  --per_device_train_batch_size 1
+  --warmup_ratio 0.02
+  --weight_decay 0.1
+  --max_steps "${MAX_STEPS}"
+  --log_interval "${LOG_STEPS}"
+  --save_steps "${SAVE_STEPS}"
+  --deepspeed "${DS_CONFIG}"
+  --gradient_checkpointing
+  --bf16
+)
+if [ ${#eval_matches[@]} -gt 0 ]; then
+  args+=( --eval_data_glob "${EVAL_DATA_GLOB}" )
+fi
+# ==== Multi-node DeepSpeed ====
+deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" "${args[@]}"
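The launcher reads "$SCRIPT_DIR/hostfile", which DeepSpeed expects in its standard one-machine-per-line "<host> slots=<n_gpus>" format. A minimal sketch of preparing and launching a run, assuming two placeholder hosts (node1/node2 with 8 GPUs each are not from this commit) and that the script is saved as run_ds.sh, as its own comment suggests:

# Hypothetical two-node hostfile in the standard DeepSpeed format;
# hostnames and slot counts are placeholders.
cat > hostfile <<'EOF'
node1 slots=8
node2 slots=8
EOF

# Any VAR default above can be overridden per run. Note that EVAL_DATA_GLOB
# stays quoted all the way down: the launching shell must not expand it,
# because train_sft_ds.py expands the pattern itself with glob.glob.
EVAL_DATA_GLOB="$HOME/datasets/my_corpus/test.jsonl" MAX_STEPS=2000 ./run_ds.sh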

View File

@@ -6,7 +6,7 @@ import argparse
 from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, Dataset
 from datasets import load_dataset
 from transformers import (
@@ -203,7 +203,7 @@ def parse_args():
     ap.add_argument("--max_steps", type=int, default=-1)
     ap.add_argument("--log_interval", type=int, default=10)
     ap.add_argument("--save_steps", type=int, default=500)
-    ap.add_argument("--eval_ratio", type=float, default=0.0)  # for eval, prepare data in the same messages/tools format
+    ap.add_argument("--eval_ratio", type=float, default=0.0)  # fallback: sampled evaluation
     ap.add_argument("--seed", type=int, default=1337)
     ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
     ap.add_argument("--gradient_checkpointing", action="store_true")
@@ -214,6 +214,8 @@ def parse_args():
     ap.add_argument("--report_to", type=str, default="tensorboard",
                     choices=["none","tensorboard","wandb"])
     ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
+    ap.add_argument("--eval_data_glob", type=str, default=None,
+                    help="(optional) jsonl glob for the test/validation set; takes precedence if provided")
     return ap.parse_args()
@@ -283,16 +285,44 @@ def main():
             yield ex
     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

-    # Optional eval: if you prepared eval data in the same messages/template format, a separate glob is recommended; this keeps the "head sampling" close to your original logic
-    eval_dataset = None
-    if args.eval_ratio and args.eval_ratio > 0:
-        # sample a handful of examples as eval (note: under streaming this is only a rough estimate)
+    # ---- Eval construction: prefer --eval_data_glob, fall back to eval_ratio sampling ----
+    eval_dataset: Optional[Dataset] = None
+
+    class ListDataset(Dataset):
+        def __init__(self, items): self.items = items
+        def __len__(self): return len(self.items)
+        def __getitem__(self, idx): return self.items[idx]
+
+    if args.eval_data_glob:
+        eval_files = sorted(glob.glob(args.eval_data_glob))
+        if len(eval_files) == 0:
+            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
+        if is_main_process():
+            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
+        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
+        def ex_iter_eval():
+            for ex in ds_eval_stream:
+                yield ex
+        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_items: List[Dict[str, torch.Tensor]] = []
+        for sample in eval_iterable:
+            eval_items.append(sample)
+        if len(eval_items) == 0:
+            raise RuntimeError("[eval] eval_data_glob yielded 0 valid samples; check the messages structure.")
+        eval_dataset = ListDataset(eval_items)
+    elif args.eval_ratio and args.eval_ratio > 0:
+        # simple head sampling (under streaming this is only a rough eval)
         desired_eval_batches = 200
         tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter_eval():
+        def ex_iter_eval2():
             for ex in tmp_stream:
                 yield ex
-        eval_stream = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
         eval_samples = []
         it = iter(eval_stream)
         for _ in range(desired_eval_batches):
@@ -300,23 +330,14 @@ def main():
                 eval_samples.append(next(it))
             except StopIteration:
                 break
-        class ListDataset(torch.utils.data.Dataset):
-            def __init__(self, items): self.items = items
-            def __len__(self): return len(self.items)
-            def __getitem__(self, idx): return self.items[idx]
-        eval_dataset = ListDataset(eval_samples)
-        # rebuild the training stream
-        ds_stream3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter3():
-            for ex in ds_stream3:
-                yield ex
-        train_stream = QwenChatSFTDataset(ex_iter3(), tokenizer, seq_len=args.seq_len)
+        if len(eval_samples) > 0:
+            eval_dataset = ListDataset(eval_samples)

     data_collator = SFTDataCollator(tokenizer)
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
+    os.makedirs(logging_dir, exist_ok=True)
     training_args = TrainingArguments(
         output_dir=args.output_dir,
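For orientation, a minimal hypothetical sketch of how a conditionally built eval_dataset is typically handed to the Hugging Face Trainer. evaluation_strategy, eval_steps, and the other TrainingArguments fields are standard transformers parameters, but the specific wiring, the eval cadence, and names such as model are assumptions layered on this script's args, not this commit's code:

# Hypothetical continuation (assumes model, train_stream, eval_dataset,
# data_collator, logging_dir, and args from the surrounding script).
training_args = TrainingArguments(
    output_dir=args.output_dir,
    logging_dir=logging_dir,
    # Enable eval only when an eval_dataset was actually built above.
    evaluation_strategy="steps" if eval_dataset is not None else "no",
    eval_steps=args.save_steps,  # assumption: evaluate on the save cadence
    per_device_train_batch_size=1,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    max_steps=args.max_steps,
    logging_steps=args.log_interval,
    save_steps=args.save_steps,
    bf16=args.bf16,
    deepspeed=args.deepspeed,
    report_to=args.report_to,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_stream,
    eval_dataset=eval_dataset,  # None is fine; the Trainer then skips eval
    data_collator=data_collator,
)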