parent 3b14067454
commit 061e51f975

run_ds.sh (71 lines changed)
@@ -1,11 +1,9 @@
 #!/usr/bin/env bash
 set -euo pipefail
-
 # Optional: debug output
 # [ "${DEBUG:-0}" = "1" ] && set -x
-
 export NCCL_DEBUG=INFO
-# If using IB/RoCE, enable these for your actual NICs:
+# If using IB/RoCE, enable these for your actual NICs (examples):
 # export NCCL_IB_HCA="mlx5_0,mlx5_1"
 # export NCCL_SOCKET_IFNAME="ib0"
 # Plain Ethernet:
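The hunk ends just before the Ethernet-only settings. For reference, a minimal sketch of what the plain-Ethernet branch typically looks like; "eth0" is an assumed interface name, so verify it with `ip link` on each node:

export NCCL_SOCKET_IFNAME="eth0"   # bind NCCL to the data-plane NIC
export NCCL_IB_DISABLE=1           # force the TCP transport when no IB/RoCE is present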
@@ -15,10 +13,14 @@ export NCCL_DEBUG=INFO
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 DS_CONFIG="${DS_CONFIG:-$SCRIPT_DIR/ds_config_zero3.json}"
 
-# ==== Hyperparameters (local paths; override via VAR=xxx ./launch_ds.sh) ====
+# ==== Hyperparameters (local paths; override via VAR=xxx ./run_ds.sh) ====
 MODEL_NAME_OR_PATH="${MODEL_NAME_OR_PATH:-/home/test/Qwen3-8B}"
-DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/*.jsonl}"   # same path on every machine
-OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"           # each machine writes its own local output
+
+# Explicitly separate train/eval files (change to globs as needed)
+DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/train.jsonl}"
+EVAL_DATA_GLOB="${EVAL_DATA_GLOB:-$HOME/datasets/my_corpus/test.jsonl}"
+
+OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"
 SEQ_LEN="${SEQ_LEN:-4096}"
 LR="${LR:-2e-5}"
 GAS="${GAS:-64}"
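As the hyperparameter comment says, every variable can be overridden at launch time without editing the script. A hypothetical invocation with overridden values (paths illustrative):

MODEL_NAME_OR_PATH=/models/Qwen3-8B \
DATA_GLOB="$HOME/datasets/my_corpus/train.jsonl" \
EVAL_DATA_GLOB="$HOME/datasets/my_corpus/test.jsonl" \
SEQ_LEN=2048 LR=1e-5 ./run_ds.sh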
@@ -26,33 +28,50 @@ LOG_STEPS="${LOG_STEPS:-10}"
 SAVE_STEPS="${SAVE_STEPS:-500}"
 MAX_STEPS="${MAX_STEPS:-10000}"
 
-# Lightweight checks (launching node only; remote ranks also mkdir in their own scripts)
+# Lightweight checks (launching node only; each rank also mkdirs in the script)
 [ -d "$MODEL_NAME_OR_PATH" ] || { echo "ERR: model not found at $MODEL_NAME_OR_PATH"; exit 1; }
 [ -f "$DS_CONFIG" ] || { echo "ERR: deepspeed config not found at $DS_CONFIG"; exit 1; }
-# The data path is a glob; this is a one-time "at least 1 file matched" check (launching node only); each node must ensure the same path exists itself
+
+# Check that the training glob matches at least one file
 shopt -s nullglob
-matches=( $DATA_GLOB )
-if [ ${#matches[@]} -eq 0 ]; then
+train_matches=( $DATA_GLOB )
+if [ ${#train_matches[@]} -eq 0 ]; then
   echo "WARN: no files matched by DATA_GLOB=$DATA_GLOB on this node (make sure every machine has data under this path)"
 fi
 
+# Check the eval set (may be empty; if empty, --eval_data_glob is not passed)
+eval_matches=()
+if [ -n "${EVAL_DATA_GLOB:-}" ]; then
+  eval_matches=( $EVAL_DATA_GLOB )
+  if [ ${#eval_matches[@]} -eq 0 ]; then
+    echo "WARN: no files matched by EVAL_DATA_GLOB=$EVAL_DATA_GLOB on this node (evaluation will be skipped)"
+  fi
+fi
 shopt -u nullglob
 
 mkdir -p "${OUTDIR}"
 
-# ==== Multi-node DeepSpeed ====
-deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" \
-  --model_name_or_path "${MODEL_NAME_OR_PATH}" \
-  --data_glob "${DATA_GLOB}" \
-  --output_dir "${OUTDIR}" \
-  --seq_len "${SEQ_LEN}" \
-  --learning_rate "${LR}" \
-  --gradient_accumulation_steps "${GAS}" \
-  --per_device_train_batch_size 1 \
-  --warmup_ratio 0.02 \
-  --weight_decay 0.1 \
-  --max_steps "${MAX_STEPS}" \
-  --log_interval "${LOG_STEPS}" \
-  --save_steps "${SAVE_STEPS}" \
-  --deepspeed "${DS_CONFIG}" \
-  --gradient_checkpointing \
+# Assemble arguments (eval flags are appended only when the eval glob matched)
+args=(
+  --model_name_or_path "${MODEL_NAME_OR_PATH}"
+  --data_glob "${DATA_GLOB}"
+  --output_dir "${OUTDIR}"
+  --seq_len "${SEQ_LEN}"
+  --learning_rate "${LR}"
+  --gradient_accumulation_steps "${GAS}"
+  --per_device_train_batch_size 1
+  --warmup_ratio 0.02
+  --weight_decay 0.1
+  --max_steps "${MAX_STEPS}"
+  --log_interval "${LOG_STEPS}"
+  --save_steps "${SAVE_STEPS}"
+  --deepspeed "${DS_CONFIG}"
+  --gradient_checkpointing
+  --bf16
+)
+if [ ${#eval_matches[@]} -gt 0 ]; then
+  args+=( --eval_data_glob "${EVAL_DATA_GLOB}" )
+fi
+
+# ==== Multi-node DeepSpeed ====
+deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" "${args[@]}"
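The launch relies on `$SCRIPT_DIR/hostfile`, which is not part of this commit. A hypothetical two-node example (hostnames and slot counts are placeholders); DeepSpeed's hostfile format is one `<host> slots=<num_gpus>` entry per node, and each host must be reachable over passwordless SSH from the launching node:

cat > "$SCRIPT_DIR/hostfile" <<'EOF'
node1 slots=8
node2 slots=8
EOF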
train_sft_ds.py
@@ -6,7 +6,7 @@ import argparse
 from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, Dataset
 
 from datasets import load_dataset
 from transformers import (
@@ -203,7 +203,7 @@ def parse_args():
     ap.add_argument("--max_steps", type=int, default=-1)
     ap.add_argument("--log_interval", type=int, default=10)
     ap.add_argument("--save_steps", type=int, default=500)
-    ap.add_argument("--eval_ratio", type=float, default=0.0)  # for eval, prepare data in the same messages/tools format
+    ap.add_argument("--eval_ratio", type=float, default=0.0)  # fallback: sampled evaluation
     ap.add_argument("--seed", type=int, default=1337)
     ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
     ap.add_argument("--gradient_checkpointing", action="store_true")
@@ -214,6 +214,8 @@ def parse_args():
     ap.add_argument("--report_to", type=str, default="tensorboard",
                     choices=["none","tensorboard","wandb"])
     ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
+    ap.add_argument("--eval_data_glob", type=str, default=None,
+                    help="(optional) jsonl glob for the test/validation set; takes precedence when provided")
     return ap.parse_args()
 
 
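With the new flag, the trainer can also be pointed at an eval set directly, bypassing run_ds.sh. A hypothetical single-node invocation (paths illustrative):

deepspeed train_sft_ds.py \
  --model_name_or_path /home/test/Qwen3-8B \
  --data_glob "$HOME/datasets/my_corpus/train.jsonl" \
  --eval_data_glob "$HOME/datasets/my_corpus/test.jsonl" \
  --output_dir "$HOME/checkpoints/run-qwen3-8b" \
  --deepspeed ds_config_zero3.json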
@@ -283,16 +285,44 @@ def main():
                 yield ex
     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    # Optional eval: if you have eval data in the same messages/template format, a separate glob is recommended; this keeps the "head sampling" close to the original logic
-    eval_dataset = None
-    if args.eval_ratio and args.eval_ratio > 0:
-        # Simply draw a few samples for eval (note: under streaming this is only a rough estimate)
+    # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
+    eval_dataset: Optional[Dataset] = None
+
+    class ListDataset(Dataset):
+        def __init__(self, items): self.items = items
+        def __len__(self): return len(self.items)
+        def __getitem__(self, idx): return self.items[idx]
+
+    if args.eval_data_glob:
+        eval_files = sorted(glob.glob(args.eval_data_glob))
+        if len(eval_files) == 0:
+            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
+        if is_main_process():
+            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
+
+        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
+        def ex_iter_eval():
+            for ex in ds_eval_stream:
+                yield ex
+
+        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_items: List[Dict[str, torch.Tensor]] = []
+        for sample in eval_iterable:
+            eval_items.append(sample)
+
+        if len(eval_items) == 0:
+            raise RuntimeError("[eval] eval_data_glob produced 0 valid samples; check the messages structure.")
+
+        eval_dataset = ListDataset(eval_items)
+
+    elif args.eval_ratio and args.eval_ratio > 0:
+        # Simple head sampling (only a rough evaluation under streaming)
         desired_eval_batches = 200
         tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter_eval():
+        def ex_iter_eval2():
             for ex in tmp_stream:
                 yield ex
-        eval_stream = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
         eval_samples = []
         it = iter(eval_stream)
         for _ in range(desired_eval_batches):
@@ -300,23 +330,14 @@ def main():
                 eval_samples.append(next(it))
             except StopIteration:
                 break
-        class ListDataset(torch.utils.data.Dataset):
-            def __init__(self, items): self.items = items
-            def __len__(self): return len(self.items)
-            def __getitem__(self, idx): return self.items[idx]
-        eval_dataset = ListDataset(eval_samples)
-
-        # Rebuild the training stream
-        ds_stream3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter3():
-            for ex in ds_stream3:
-                yield ex
-        train_stream = QwenChatSFTDataset(ex_iter3(), tokenizer, seq_len=args.seq_len)
+        if len(eval_samples) > 0:
+            eval_dataset = ListDataset(eval_samples)
 
     data_collator = SFTDataCollator(tokenizer)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
     os.makedirs(logging_dir, exist_ok=True)
 
     training_args = TrainingArguments(
         output_dir=args.output_dir,
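Since the new eval path raises when the glob yields zero valid samples, sanity-checking the eval file before launch can save a failed run. A quick check, assuming `jq` is installed (path illustrative): every line should be a JSON object with a non-empty "messages" array:

head -n 5 "$HOME/datasets/my_corpus/test.jsonl" | jq -e '.messages | length > 0'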