first commit

hailin 2025-08-25 10:59:49 +08:00
commit f82c5831ee
4 changed files with 327 additions and 0 deletions

ds_config_zero3.json (new file, 33 lines)

@@ -0,0 +1,33 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 64,
"steps_per_print": 0,
"gradient_clipping": 1.0,
"fp16": { "enabled": false },
"bf16": { "enabled": true },
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_scatter": true,
"reduce_bucket_size": 50000000,
"stage3_prefetch_bucket_size": 50000000,
"stage3_param_persistence_threshold": 100000,
"stage3_gather_16bit_weights_on_model_save": true
},
"activation_checkpointing": {
"partition_activations": true,
"contiguous_memory_optimization": true,
"cpu_checkpointing": false,
"number_checkpoints": 36,
"profile": false,
"synchronize_checkpoint_boundary": true
},
"wall_clock_breakdown": false
}
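
With this config the real batch size comes from gradient accumulation and data parallelism rather than the per-GPU micro batch, and "stage3_gather_16bit_weights_on_model_save" makes the final save_model() call write a full 16-bit model that from_pretrained can load directly. A rough sketch of the effective batch implied by the numbers above (assuming the 6-node x 4-GPU hostfile in this commit and the 4096 sequence length used in run_ds.sh):

# effective_batch.py (hypothetical standalone sketch, back-of-the-envelope only)
micro_batch = 1      # train_micro_batch_size_per_gpu
grad_accum = 64      # gradient_accumulation_steps
world_size = 6 * 4   # nodes x slots from the hostfile
seq_len = 4096       # SEQ_LEN in run_ds.sh

global_batch = micro_batch * grad_accum * world_size
tokens_per_step = global_batch * seq_len
print(f"{global_batch} sequences (~{tokens_per_step:,} tokens) per optimizer step")
# -> 1536 sequences (~6,291,456 tokens) per optimizer step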

hostfile (new file, 6 lines)

@@ -0,0 +1,6 @@
tn01 slots=4
tn02 slots=4
tn03 slots=4
tn04 slots=4
tn05 slots=4
tn06 slots=4

run_ds.sh (new file, 40 lines)

@@ -0,0 +1,40 @@
#!/usr/bin/env bash
set -euo pipefail
export NCCL_DEBUG=INFO
# If traffic goes over IB/RoCE, enable these according to your actual NICs:
# export NCCL_IB_HCA="mlx5_0,mlx5_1"
# export NCCL_SOCKET_IFNAME="ib0"
# Plain Ethernet:
# export NCCL_SOCKET_IFNAME="eth0"
# ==== Hyperparameters (local paths; override with VAR=xxx ./run_ds.sh) ====
MODEL_NAME_OR_PATH="${MODEL_NAME_OR_PATH:-/home/test/Qwen3-8B}"
DATA_GLOB="${DATA_GLOB:-/data/datasets/my_corpus/*.jsonl}" # the same path must exist on every machine
OUTDIR="${OUTDIR:-/data/checkpoints/run-qwen3-8b}" # each machine writes to its own local output dir
SEQ_LEN="${SEQ_LEN:-4096}"
LR="${LR:-2e-5}"
GAS="${GAS:-64}"
LOG_STEPS="${LOG_STEPS:-10}"
SAVE_STEPS="${SAVE_STEPS:-500}"
MAX_STEPS="${MAX_STEPS:-10000}"
mkdir -p "${OUTDIR}"
# ==== Multi-node DeepSpeed launch ====
deepspeed --hostfile hostfile train_sft_ds.py \
--model_name_or_path "${MODEL_NAME_OR_PATH}" \
--data_glob "${DATA_GLOB}" \
--output_dir "${OUTDIR}" \
--seq_len "${SEQ_LEN}" \
--learning_rate "${LR}" \
--gradient_accumulation_steps "${GAS}" \
--per_device_train_batch_size 1 \
--warmup_ratio 0.02 \
--weight_decay 0.1 \
--max_steps "${MAX_STEPS}" \
--log_interval "${LOG_STEPS}" \
--save_steps "${SAVE_STEPS}" \
--deepspeed ds_config_zero3.json \
--gradient_checkpointing \
--bf16
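
Because every node reads data and writes checkpoints on its own local disk, the model directory and the data glob have to resolve on each machine before the launch. A minimal preflight sketch one could run per node (hypothetical helper, not invoked by run_ds.sh; paths default to the ones above):

# preflight_check.py (hypothetical): confirm local paths before launching deepspeed
import glob
import os
import sys

model_dir = os.environ.get("MODEL_NAME_OR_PATH", "/home/test/Qwen3-8B")
data_glob = os.environ.get("DATA_GLOB", "/data/datasets/my_corpus/*.jsonl")

problems = []
if not os.path.isdir(model_dir):
    problems.append(f"missing model dir: {model_dir}")
matches = glob.glob(data_glob)
if not matches:
    problems.append(f"no files match data glob: {data_glob}")

if problems:
    sys.exit("\n".join(problems))
print(f"ok: model dir present, {len(matches)} data file(s) found on this node")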

train_sft_ds.py (new file, 248 lines)

@@ -0,0 +1,248 @@
#!/usr/bin/env python3
import os
import glob
import argparse
from typing import Dict, List, Iterable, Iterator
import torch
from torch.utils.data import IterableDataset
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
set_seed
)
from transformers.trainer_callback import TrainerCallback
def is_main_process():
return int(os.environ.get("RANK", "0")) == 0
def print_once(*args, **kwargs):
if is_main_process():
print(*args, **kwargs, flush=True)
class ConstantLengthDataset(IterableDataset):
def __init__(self,
texts_iter: Iterable[str],
tokenizer: AutoTokenizer,
seq_len: int = 4096,
buffer_size: int = 1024 * 1024):
self.texts_iter = texts_iter
self.tokenizer = tokenizer
self.seq_len = seq_len
self.buffer_size = buffer_size
def __iter__(self):
buffer_texts: List[str] = []
token_buffer: List[int] = []
for txt in self.texts_iter:
if not txt:
continue
buffer_texts.append(txt)
if len(buffer_texts) >= 1024:
enc = self.tokenizer(buffer_texts, add_special_tokens=False)['input_ids']
for ids in enc:
token_buffer.extend(ids + [self.tokenizer.eos_token_id])
buffer_texts.clear()
while len(token_buffer) >= self.seq_len:
chunk = token_buffer[:self.seq_len]
del token_buffer[:self.seq_len]
yield {
"input_ids": torch.tensor(chunk, dtype=torch.long),
"attention_mask": torch.ones(self.seq_len, dtype=torch.long),
"labels": torch.tensor(chunk, dtype=torch.long)
}
if buffer_texts:
enc = self.tokenizer(buffer_texts, add_special_tokens=False)['input_ids']
for ids in enc:
token_buffer.extend(ids + [self.tokenizer.eos_token_id])
while len(token_buffer) >= self.seq_len:
chunk = token_buffer[:self.seq_len]
del token_buffer[:self.seq_len]
yield {
"input_ids": torch.tensor(chunk, dtype=torch.long),
"attention_mask": torch.ones(self.seq_len, dtype=torch.long),
"labels": torch.tensor(chunk, dtype=torch.long)
}
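# The packing above can be sanity-checked in isolation, e.g. from a REPL on a node
# that has the local Qwen3-8B weights (hypothetical snippet, not used by the trainer):
#
#   from transformers import AutoTokenizer
#   from train_sft_ds import ConstantLengthDataset
#   tok = AutoTokenizer.from_pretrained("/home/test/Qwen3-8B", trust_remote_code=True)
#   ds = ConstantLengthDataset(texts_iter=iter(["hello world"] * 2000), tokenizer=tok, seq_len=4096)
#   sample = next(iter(ds))
#   assert sample["input_ids"].shape[0] == 4096
#   assert (sample["input_ids"] == sample["labels"]).all()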
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
self.csv_path = csv_path
if is_main_process():
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
with open(self.csv_path, "w", encoding="utf-8") as f:
f.write("step,loss,lr,total_flos\n")
def on_log(self, args, state, control, logs=None, **kwargs):
if not is_main_process() or logs is None:
return
with open(self.csv_path, "a", encoding="utf-8") as f:
f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--model_name_or_path", type=str, required=True,
help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B")
ap.add_argument("--data_glob", type=str, required=True,
help="本地 jsonl 通配符(每台机器都需有同路径数据)")
ap.add_argument("--output_dir", type=str, required=True,
help="本地输出目录(各节点各自本地写)")
ap.add_argument("--seq_len", type=int, default=4096)
ap.add_argument("--learning_rate", type=float, default=2e-5)
ap.add_argument("--weight_decay", type=float, default=0.1)
ap.add_argument("--warmup_ratio", type=float, default=0.02)
ap.add_argument("--num_train_epochs", type=float, default=1.0)
ap.add_argument("--max_steps", type=int, default=-1)
ap.add_argument("--log_interval", type=int, default=10)
ap.add_argument("--save_steps", type=int, default=500)
ap.add_argument("--eval_ratio", type=float, default=0.0)
ap.add_argument("--seed", type=int, default=1337)
ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
ap.add_argument("--gradient_checkpointing", action="store_true")
ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16同时在 DS 配置里也要开")
ap.add_argument("--per_device_train_batch_size", type=int, default=1)
ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
ap.add_argument("--report_to", type=str, default="tensorboard",
choices=["none","tensorboard","wandb"])
ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
return ap.parse_args()
def main():
args = parse_args()
set_seed(args.seed)
# Tokenizer/Model
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
low_cpu_mem_usage=True,
trust_remote_code=True
)
model.config.use_cache = False
if args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# Local data (each machine reads its own copy from the same path)
files = sorted(glob.glob(args.data_glob))
if len(files) == 0:
raise FileNotFoundError(f"No files match: {args.data_glob}")
dataset_iter = load_dataset(
"json",
data_files={"train": files},
split="train",
streaming=True
)
def text_iter():
for ex in dataset_iter:
txt = ex.get("text", None)
if isinstance(txt, str) and len(txt.strip()) > 0:
yield txt
train_stream = ConstantLengthDataset(texts_iter=text_iter(), tokenizer=tokenizer, seq_len=args.seq_len)
eval_dataset = None
if args.eval_ratio and args.eval_ratio > 0:
# Simple eval: take a few batches from the head of the training stream (disable if not needed)
desired_eval_batches = 200
gen = iter(train_stream)
eval_samples = []
for _ in range(desired_eval_batches):
try: eval_samples.append(next(gen))
except StopIteration: break
class ListDataset(torch.utils.data.Dataset):
def __init__(self, items): self.items = items
def __len__(self): return len(self.items)
def __getitem__(self, idx): return self.items[idx]
eval_dataset = ListDataset(eval_samples)
# Rebuild the training stream (its head was consumed by the eval samples above)
dataset_iter2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
def text_iter2():
for ex in dataset_iter2:
txt = ex.get("text", None)
if isinstance(txt, str) and len(txt.strip()) > 0:
yield txt
train_stream = ConstantLengthDataset(texts_iter=text_iter2(), tokenizer=tokenizer, seq_len=args.seq_len)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
os.makedirs(args.output_dir, exist_ok=True)
logging_dir = os.path.join(args.output_dir, "logs")
# Key point: without shared storage, the HF/DeepSpeed sharded checkpoints are written under each rank's local output_dir.
# On resume, each rank reads its own shards from the same local path (just keep the same world_size).
training_args = TrainingArguments(
output_dir=args.output_dir,
logging_dir=logging_dir,
do_train=True,
do_eval=(eval_dataset is not None),
evaluation_strategy=("steps" if eval_dataset is not None else "no"),
eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
warmup_ratio=args.warmup_ratio,
num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
max_steps=args.max_steps if args.max_steps > 0 else -1,
lr_scheduler_type="cosine",
logging_steps=args.log_interval,
save_steps=args.save_steps,
save_total_limit=2,
deepspeed=args.deepspeed,
dataloader_drop_last=True,
dataloader_num_workers=2,
report_to=([] if args.report_to == "none" else [args.report_to]),
bf16=args.bf16,
fp16=(not args.bf16),
gradient_checkpointing=args.gradient_checkpointing,
remove_unused_columns=False,
torch_compile=False,
save_on_each_node=False # keep the default; DeepSpeed's sharded checkpoints are written locally by each rank anyway
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_stream,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=data_collator
)
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
# No shared disk: check whether this node's local output_dir already contains a checkpoint-*
ckpt_exists = (os.path.isdir(args.output_dir)
and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
resume_flag = True if ckpt_exists else None
print_once(f"Resume = {resume_flag is True}")
print_once("***** Starting training *****")
train_result = trainer.train(resume_from_checkpoint=resume_flag)
trainer.save_model() # with the gather option in the DS config, the full 16-bit model is gathered and saved on global rank 0
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
if eval_dataset is not None:
print_once("***** Running eval *****")
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
print_once("Done.")
if __name__ == "__main__":
main()
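
Because the DS config enables "stage3_gather_16bit_weights_on_model_save", the trainer.save_model() call leaves a regular 16-bit Hugging Face checkpoint in the output directory on the node holding global rank 0. A minimal loading sketch for a post-training smoke test (assumes the OUTDIR default from run_ds.sh and enough GPU memory for the 8B model in bf16):

# smoke_test.py (hypothetical): load the gathered 16-bit model and generate once
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

out_dir = "/data/checkpoints/run-qwen3-8b"  # OUTDIR default in run_ds.sh
tok = AutoTokenizer.from_pretrained(out_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    out_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # requires accelerate
    trust_remote_code=True,
)

prompt = "Briefly explain what ZeRO stage 3 partitions across GPUs."
inputs = tok(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))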