commit 061e51f975
parent 3b14067454

run_ds.sh | 71
diff --git a/run_ds.sh b/run_ds.sh
--- a/run_ds.sh
+++ b/run_ds.sh
@@ -1,11 +1,9 @@
 #!/usr/bin/env bash
 set -euo pipefail
-
-# Optional: debug output
 # [ "${DEBUG:-0}" = "1" ] && set -x
 
 export NCCL_DEBUG=INFO
-# If traffic goes over IB/RoCE, enable these for your actual NICs:
+# If traffic goes over IB/RoCE, enable these for your actual NICs (example):
 # export NCCL_IB_HCA="mlx5_0,mlx5_1"
 # export NCCL_SOCKET_IFNAME="ib0"
 # Pure Ethernet:
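Note: the "Pure Ethernet" comment above is left without the actual settings. A minimal sketch of that branch, assuming a NIC named eth0 (verify the real interface name with `ip addr`):

```bash
# Sketch only: pure-Ethernet NCCL setup; "eth0" is an assumed interface name.
export NCCL_IB_DISABLE=1           # do not attempt InfiniBand transports
export NCCL_SOCKET_IFNAME="eth0"   # carry NCCL's TCP traffic over this NIC
```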
@@ -15,10 +13,14 @@ export NCCL_DEBUG=INFO
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 DS_CONFIG="${DS_CONFIG:-$SCRIPT_DIR/ds_config_zero3.json}"
 
-# ==== Hyperparameters (local paths; override via VAR=xxx ./launch_ds.sh) ====
+# ==== Hyperparameters (local paths; override via VAR=xxx ./run_ds.sh) ====
 MODEL_NAME_OR_PATH="${MODEL_NAME_OR_PATH:-/home/test/Qwen3-8B}"
-DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/*.jsonl}"   # same path on every machine
-OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"           # each machine writes its own local output
+
+# Explicitly separate train/eval files (change back to globs if needed)
+DATA_GLOB="${DATA_GLOB:-$HOME/datasets/my_corpus/train.jsonl}"
+EVAL_DATA_GLOB="${EVAL_DATA_GLOB:-$HOME/datasets/my_corpus/test.jsonl}"
+
+OUTDIR="${OUTDIR:-$HOME/checkpoints/run-qwen3-8b}"
 SEQ_LEN="${SEQ_LEN:-4096}"
 LR="${LR:-2e-5}"
 GAS="${GAS:-64}"
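Because every setting uses the `"${VAR:-default}"` pattern, a run can be customized from the environment without editing the script; for example (the override values are illustrative):

```bash
# Override a few hyperparameters for a single run:
MAX_STEPS=2000 LR=1e-5 SEQ_LEN=2048 ./run_ds.sh

# Train on a different corpus (placeholder path):
DATA_GLOB="$HOME/datasets/other_corpus/train.jsonl" ./run_ds.sh
```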
@@ -26,33 +28,50 @@ LOG_STEPS="${LOG_STEPS:-10}"
 SAVE_STEPS="${SAVE_STEPS:-500}"
 MAX_STEPS="${MAX_STEPS:-10000}"
 
-# Lightweight validation (launcher node only; remote ranks also mkdir in their own scripts)
+# Lightweight validation (launcher node only; each rank also mkdirs in its script)
 [ -d "$MODEL_NAME_OR_PATH" ] || { echo "ERR: model not found at $MODEL_NAME_OR_PATH"; exit 1; }
 [ -f "$DS_CONFIG" ] || { echo "ERR: deepspeed config not found at $DS_CONFIG"; exit 1; }
-# The data path is a glob; this checks once (launcher node) that at least one file matches; every node must ensure the same path exists itself
+# Check that the training set glob matches something
 shopt -s nullglob
-matches=( $DATA_GLOB )
-if [ ${#matches[@]} -eq 0 ]; then
+train_matches=( $DATA_GLOB )
+if [ ${#train_matches[@]} -eq 0 ]; then
   echo "WARN: no files matched by DATA_GLOB=$DATA_GLOB on this node (make sure the data exists at this path on every machine)"
 fi
+
+# Check the eval set (may be empty; if empty, --eval_data_glob is not passed)
+eval_matches=()
+if [ -n "${EVAL_DATA_GLOB:-}" ]; then
+  eval_matches=( $EVAL_DATA_GLOB )
+  if [ ${#eval_matches[@]} -eq 0 ]; then
+    echo "WARN: no files matched by EVAL_DATA_GLOB=$EVAL_DATA_GLOB on this node (evaluation will be skipped)"
+  fi
+fi
 shopt -u nullglob
 
 mkdir -p "${OUTDIR}"
 
-# ==== Multi-node DeepSpeed ====
-deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" \
-  --model_name_or_path "${MODEL_NAME_OR_PATH}" \
-  --data_glob "${DATA_GLOB}" \
-  --output_dir "${OUTDIR}" \
-  --seq_len "${SEQ_LEN}" \
-  --learning_rate "${LR}" \
-  --gradient_accumulation_steps "${GAS}" \
-  --per_device_train_batch_size 1 \
-  --warmup_ratio 0.02 \
-  --weight_decay 0.1 \
-  --max_steps "${MAX_STEPS}" \
-  --log_interval "${LOG_STEPS}" \
-  --save_steps "${SAVE_STEPS}" \
-  --deepspeed "${DS_CONFIG}" \
-  --gradient_checkpointing \
-  --bf16
+# Assemble the arguments (the eval flag is added only when the eval set is valid)
+args=(
+  --model_name_or_path "${MODEL_NAME_OR_PATH}"
+  --data_glob "${DATA_GLOB}"
+  --output_dir "${OUTDIR}"
+  --seq_len "${SEQ_LEN}"
+  --learning_rate "${LR}"
+  --gradient_accumulation_steps "${GAS}"
+  --per_device_train_batch_size 1
+  --warmup_ratio 0.02
+  --weight_decay 0.1
+  --max_steps "${MAX_STEPS}"
+  --log_interval "${LOG_STEPS}"
+  --save_steps "${SAVE_STEPS}"
+  --deepspeed "${DS_CONFIG}"
+  --gradient_checkpointing
+  --bf16
+)
+if [ ${#eval_matches[@]} -gt 0 ]; then
+  args+=( --eval_data_glob "${EVAL_DATA_GLOB}" )
+fi
+
+# ==== Multi-node DeepSpeed ====
+deepspeed --hostfile "$SCRIPT_DIR/hostfile" "$SCRIPT_DIR/train_sft_ds.py" "${args[@]}"
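The launch line depends on `$SCRIPT_DIR/hostfile`, which this diff does not show. DeepSpeed's hostfile format is one `host slots=N` entry per machine; a sketch for two 8-GPU nodes (hostnames are assumptions):

```bash
# Hypothetical hostfile for two 8-GPU nodes; hostnames are examples only.
cat > hostfile <<'EOF'
node1 slots=8
node2 slots=8
EOF
```

Moving from backslash continuations to a bash array is also what makes the conditional flag clean: `args+=( --eval_data_glob ... )` either appends the whole flag-value pair or nothing at all, so no empty-string argument ever reaches the Python script under `set -u`.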
diff --git a/train_sft_ds.py b/train_sft_ds.py
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -6,7 +6,7 @@ import argparse
 from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 
 import torch
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, Dataset
 
 from datasets import load_dataset
 from transformers import (
@@ -203,7 +203,7 @@ def parse_args():
     ap.add_argument("--max_steps", type=int, default=-1)
     ap.add_argument("--log_interval", type=int, default=10)
     ap.add_argument("--save_steps", type=int, default=500)
-    ap.add_argument("--eval_ratio", type=float, default=0.0)  # for eval, prepare data in the same messages/tools format
+    ap.add_argument("--eval_ratio", type=float, default=0.0)  # fallback sampled evaluation
     ap.add_argument("--seed", type=int, default=1337)
     ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
     ap.add_argument("--gradient_checkpointing", action="store_true")
@@ -214,6 +214,8 @@ def parse_args():
     ap.add_argument("--report_to", type=str, default="tensorboard",
                     choices=["none","tensorboard","wandb"])
     ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
+    ap.add_argument("--eval_data_glob", type=str, default=None,
+                    help="(optional) jsonl glob for the test/validation set; takes precedence when provided")
     return ap.parse_args()
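Since the new `--eval_data_glob` value is expanded inside Python (by `glob.glob`, as a later hunk shows), the pattern must reach the script unexpanded. When overriding it from a shell, double quotes keep the glob literal while still expanding `$HOME` (the path shown is the script's default layout):

```bash
# Quote the pattern so the launching shell does not expand it early.
EVAL_DATA_GLOB="$HOME/datasets/my_corpus/*.jsonl" ./run_ds.sh
```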
@@ -283,16 +285,44 @@ def main():
             yield ex
     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    # Optional eval: if you have eval data in the same messages/template format, a separate glob is recommended; this keeps the "head sampling" close to the original logic
-    eval_dataset = None
-    if args.eval_ratio and args.eval_ratio > 0:
-        # Sample a few examples for eval (note: with streaming this is only a rough evaluation)
+    # ---- Eval construction: prefer --eval_data_glob; fall back to eval_ratio sampling otherwise ----
+    eval_dataset: Optional[Dataset] = None
+    class ListDataset(Dataset):
+        def __init__(self, items): self.items = items
+        def __len__(self): return len(self.items)
+        def __getitem__(self, idx): return self.items[idx]
+
+    if args.eval_data_glob:
+        eval_files = sorted(glob.glob(args.eval_data_glob))
+        if len(eval_files) == 0:
+            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
+        if is_main_process():
+            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
+
+        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
+        def ex_iter_eval():
+            for ex in ds_eval_stream:
+                yield ex
+
+        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_items: List[Dict[str, torch.Tensor]] = []
+        for sample in eval_iterable:
+            eval_items.append(sample)
+
+        if len(eval_items) == 0:
+            raise RuntimeError("[eval] eval_data_glob yielded 0 valid samples; check the messages structure.")
+
+        eval_dataset = ListDataset(eval_items)
+
+    elif args.eval_ratio and args.eval_ratio > 0:
+        # Simple head sampling (only a rough evaluation when streaming)
         desired_eval_batches = 200
         tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter_eval():
+        def ex_iter_eval2():
             for ex in tmp_stream:
                 yield ex
-        eval_stream = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
         eval_samples = []
         it = iter(eval_stream)
         for _ in range(desired_eval_batches):
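Two things are worth noting about the new branch: it materializes the streamed eval set into an in-memory `ListDataset` (so the Trainer gets a sized, indexable dataset, at the cost of holding every eval sample in RAM on each rank), and it raises `FileNotFoundError` on any rank whose node cannot see the eval files. A per-node pre-launch check along these lines can fail fast (sketch; the path is this diff's default):

```bash
# compgen -G succeeds only if the glob matches at least one existing file.
if ! compgen -G "$HOME/datasets/my_corpus/test.jsonl" > /dev/null; then
  echo "eval data missing on $(hostname)" >&2
  exit 1
fi
```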
@@ -300,23 +330,14 @@ def main():
                 eval_samples.append(next(it))
             except StopIteration:
                 break
-        class ListDataset(torch.utils.data.Dataset):
-            def __init__(self, items): self.items = items
-            def __len__(self): return len(self.items)
-            def __getitem__(self, idx): return self.items[idx]
-        eval_dataset = ListDataset(eval_samples)
-
-        # Rebuild the training stream
-        ds_stream3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-        def ex_iter3():
-            for ex in ds_stream3:
-                yield ex
-        train_stream = QwenChatSFTDataset(ex_iter3(), tokenizer, seq_len=args.seq_len)
+        if len(eval_samples) > 0:
+            eval_dataset = ListDataset(eval_samples)
 
     data_collator = SFTDataCollator(tokenizer)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
+    os.makedirs(logging_dir, exist_ok=True)
 
     training_args = TrainingArguments(
         output_dir=args.output_dir,
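With the argparse default `--report_to tensorboard` and the `logs` directory now created up front, training curves land under `$OUTDIR/logs`. A typical way to inspect them, assuming tensorboard is installed (the port is arbitrary):

```bash
# OUTDIR default comes from run_ds.sh.
tensorboard --logdir "$HOME/checkpoints/run-qwen3-8b/logs" --port 6006
```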