jd_train/train_sft_ds.py

#!/usr/bin/env python3
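"""Multi-node SFT training for a causal LM (e.g. Qwen3-8B) with HF Trainer + DeepSpeed ZeRO-3.

Illustrative launch command (paths and GPU count are placeholders; the repo's
launch_ds.sh, referenced below, presumably wraps something similar):

    deepspeed --num_gpus 8 jd_train/train_sft_ds.py \
        --model_name_or_path /home/test/Qwen3-8B \
        --data_glob '/data/sft/*.jsonl' \
        --output_dir /data/out/qwen3-sft \
        --deepspeed ds_config_zero3.json \
        --bf16 --gradient_checkpointing
"""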
import os
import glob
import socket
import argparse
from typing import Dict, List, Iterable, Iterator
import torch
from torch.utils.data import IterableDataset
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed
)
from transformers.trainer_callback import TrainerCallback


def is_main_process():
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs):
    if is_main_process():
        print(*args, **kwargs, flush=True)


class ConstantLengthDataset(IterableDataset):
    """Packs a stream of raw texts into fixed-length token blocks for causal LM training."""

    def __init__(self,
                 texts_iter: Iterable[str],
                 tokenizer: AutoTokenizer,
                 seq_len: int = 4096,
                 buffer_size: int = 1024 * 1024):
        self.texts_iter = texts_iter
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        buffer_texts: List[str] = []
        token_buffer: List[int] = []
        for txt in self.texts_iter:
            if not txt:
                continue
            buffer_texts.append(txt)
            # Tokenize in batches of 1024 texts; each document is terminated with EOS.
            if len(buffer_texts) >= 1024:
                enc = self.tokenizer(buffer_texts, add_special_tokens=False)['input_ids']
                for ids in enc:
                    token_buffer.extend(ids + [self.tokenizer.eos_token_id])
                buffer_texts.clear()
                # Emit fixed-length chunks; labels equal input_ids for causal LM.
                while len(token_buffer) >= self.seq_len:
                    chunk = token_buffer[:self.seq_len]
                    del token_buffer[:self.seq_len]
                    yield {
                        "input_ids": torch.tensor(chunk, dtype=torch.long),
                        "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
                        "labels": torch.tensor(chunk, dtype=torch.long)
                    }
        # Flush whatever is left in the text buffer at the end of the stream.
        if buffer_texts:
            enc = self.tokenizer(buffer_texts, add_special_tokens=False)['input_ids']
            for ids in enc:
                token_buffer.extend(ids + [self.tokenizer.eos_token_id])
            while len(token_buffer) >= self.seq_len:
                chunk = token_buffer[:self.seq_len]
                del token_buffer[:self.seq_len]
                yield {
                    "input_ids": torch.tensor(chunk, dtype=torch.long),
                    "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
                    "labels": torch.tensor(chunk, dtype=torch.long)
                }
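
# Illustrative usage of the packing dataset (toy values; the actual run below uses
# seq_len=args.seq_len over the streamed jsonl corpus). Assuming the packed stream
# holds at least `seq_len` tokens, one fixed-size block comes out:
#
#     toy = ConstantLengthDataset(iter(["doc one", "doc two", "doc three"]), tokenizer, seq_len=8)
#     block = next(iter(toy))   # dict of length-8 input_ids / attention_mask / labels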


class CsvLossLogger(TrainerCallback):
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not is_main_process() or logs is None:
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
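
# Example loss.csv contents (numbers are illustrative only; fields missing from a
# given log event are left blank):
#
#     step,loss,lr,total_flos
#     10,2.4312,1.98e-05,
#     20,2.2875,1.95e-05,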


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="Local weights directory or HF model name (e.g. /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="Glob for local jsonl files (every machine must have the data at the same local path)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="Local output directory (each node writes to its own local disk)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
                    help="bf16 can be enabled on 3090/A100/H100 and similar GPUs; enable it in the DeepSpeed config as well")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard",
                    choices=["none", "tensorboard", "wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    return ap.parse_args()
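
# Expected training data: jsonl files matched by --data_glob, one JSON object per
# line with a 'text' field holding the full rendered sample. Illustrative line
# (content is a placeholder, not taken from the real corpus):
#
#     {"text": "<one SFT sample rendered as a single string>"}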


def main():
    args = parse_args()
    set_seed(args.seed)

    # Tokenizer / Model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    model.config.use_cache = False
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # ===== Data sanity check (runs independently on every node) =====
    host = socket.gethostname()
    rank = int(os.environ.get("RANK", "0"))
    files = sorted(glob.glob(args.data_glob))
    if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "Every machine must have the data at the same local path; "
            "override with DATA_GLOB=<your_glob> ./launch_ds.sh."
        )
    if is_main_process():
        print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)

    # Stream the jsonl files line by line; the text field is named 'text'.
    dataset_iter = load_dataset(
        "json",
        data_files={"train": files},
        split="train",
        streaming=True
    )

    def text_iter():
        for ex in dataset_iter:
            txt = ex.get("text", None)
            if isinstance(txt, str) and len(txt.strip()) > 0:
                yield txt

    # Build the stream once and use it as a "non-empty probe".
    train_stream_probe = ConstantLengthDataset(texts_iter=text_iter(), tokenizer=tokenizer, seq_len=args.seq_len)
    _probe = iter(train_stream_probe)
    try:
        _ = next(_probe)  # pull one chunk to make sure the pipeline really yields training samples
    except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] Data files were matched, but no trainable samples were produced.\n"
            "Common causes: the jsonl is missing the 'text' field, all lines are empty/blank, or --seq_len is too large.\n"
            "Inspect a few sample lines, or retry with a smaller --seq_len."
        )

    # The probe consumed part of the stream, so rebuild a clean training stream.
    dataset_iter2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

    def text_iter2():
        for ex in dataset_iter2:
            txt = ex.get("text", None)
            if isinstance(txt, str) and len(txt.strip()) > 0:
                yield txt

    train_stream = ConstantLengthDataset(texts_iter=text_iter2(), tokenizer=tokenizer, seq_len=args.seq_len)

    # Optional eval set: take fixed-length chunks from the head of the stream.
    eval_dataset = None
    if args.eval_ratio and args.eval_ratio > 0:
        desired_eval_batches = 200
        gen = iter(train_stream)
        eval_samples = []
        for _ in range(desired_eval_batches):
            try:
                eval_samples.append(next(gen))
            except StopIteration:
                break

        class ListDataset(torch.utils.data.Dataset):
            def __init__(self, items): self.items = items
            def __len__(self): return len(self.items)
            def __getitem__(self, idx): return self.items[idx]

        eval_dataset = ListDataset(eval_samples)

        # Rebuild the training stream after sampling so the eval chunks do not eat its head.
        dataset_iter3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

        def text_iter3():
            for ex in dataset_iter3:
                txt = ex.get("text", None)
                if isinstance(txt, str) and len(txt.strip()) > 0:
                    yield txt

        train_stream = ConstantLengthDataset(texts_iter=text_iter3(), tokenizer=tokenizer, seq_len=args.seq_len)

    # mlm=False -> causal LM; the packed dataset already supplies labels identical to input_ids.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")

    # No shared disk: each rank writes its own shards under its local output_dir.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        do_train=True,
        do_eval=(eval_dataset is not None),
        evaluation_strategy=("steps" if eval_dataset is not None else "no"),
        eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        deepspeed=args.deepspeed,
        dataloader_drop_last=True,
        dataloader_num_workers=2,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        bf16=args.bf16,
        fp16=(not args.bf16),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        torch_compile=False,
        save_on_each_node=False
    )
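
    # Rough effective batch size (illustrative arithmetic, not computed by this script):
    # sequences per optimizer step = per_device_train_batch_size * gradient_accumulation_steps * world_size.
    # e.g. with the defaults (1 * 64, seq_len=4096) on a hypothetical 16-GPU cluster:
    #     1 * 64 * 16 = 1024 packed sequences, about 4.2M tokens per step.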

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # No shared disk: check whether this node's local output_dir already holds a checkpoint-*.
    ckpt_exists = (os.path.isdir(args.output_dir)
                   and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
    resume_flag = True if ckpt_exists else None
    print_once(f"[host={host}] Resume = {resume_flag is True}")

    print_once("***** Starting training *****")
    train_result = trainer.train(resume_from_checkpoint=resume_flag)
    # With stage3_gather_16bit_weights_on_model_save=true in the DS config, the full
    # model is gathered and saved only on global rank 0.
    trainer.save_model()
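
    # A minimal ZeRO-3 config sketch consistent with the flags used here; the repo's
    # actual ds_config_zero3.json may differ (values left as "auto" so the HF Trainer
    # fills them in from TrainingArguments):
    #
    #     {
    #       "bf16": { "enabled": "auto" },
    #       "fp16": { "enabled": "auto" },
    #       "zero_optimization": {
    #         "stage": 3,
    #         "overlap_comm": true,
    #         "contiguous_gradients": true,
    #         "stage3_gather_16bit_weights_on_model_save": true
    #       },
    #       "gradient_accumulation_steps": "auto",
    #       "gradient_clipping": "auto",
    #       "train_micro_batch_size_per_gpu": "auto"
    #     }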

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if eval_dataset is not None:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    print_once("Done.")


if __name__ == "__main__":
    main()