jd_train/train_sft_ds.py

#!/usr/bin/env python3
import os
import glob
import argparse
from typing import Dict, List, Iterable, Iterator
import torch
from torch.utils.data import IterableDataset
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback
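

# Rank helpers: torchrun / the DeepSpeed launcher set the RANK environment variable, so these
# gate prints (and the CSV writes below) to the global rank-0 process and keep multi-node logs clean.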
def is_main_process():
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs):
    if is_main_process():
        print(*args, **kwargs, flush=True)
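

# ConstantLengthDataset packs a stream of raw texts into constant-length examples: texts are
# tokenized in batches, joined with an EOS token between documents, and the resulting token
# stream is cut into seq_len-sized chunks, so every example is exactly seq_len tokens with no padding.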
class ConstantLengthDataset(IterableDataset):
    def __init__(self,
                 texts_iter: Iterable[str],
                 tokenizer: AutoTokenizer,
                 seq_len: int = 4096,
                 buffer_size: int = 1024 * 1024):
        self.texts_iter = texts_iter
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # Note: buffer_size is kept for compatibility but is currently unused; texts are flushed every 1024 samples.
        self.buffer_size = buffer_size

    def __iter__(self):
        # No per-worker or per-rank sharding is done here, so keep dataloader_num_workers low to
        # avoid the same stream being yielded more than once.
        buffer_texts: List[str] = []
        token_buffer: List[int] = []
        for txt in self.texts_iter:
            if not txt:
                continue
            buffer_texts.append(txt)
            if len(buffer_texts) >= 1024:
                enc = self.tokenizer(buffer_texts, add_special_tokens=False)['input_ids']
                for ids in enc:
                    token_buffer.extend(ids + [self.tokenizer.eos_token_id])
                buffer_texts.clear()
                while len(token_buffer) >= self.seq_len:
                    chunk = token_buffer[:self.seq_len]
                    del token_buffer[:self.seq_len]
                    yield {
                        "input_ids": torch.tensor(chunk, dtype=torch.long),
                        "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
                        "labels": torch.tensor(chunk, dtype=torch.long)
                    }
        # Flush whatever is left in the text buffer at the end of the stream.
        if buffer_texts:
            enc = self.tokenizer(buffer_texts, add_special_tokens=False)['input_ids']
            for ids in enc:
                token_buffer.extend(ids + [self.tokenizer.eos_token_id])
        while len(token_buffer) >= self.seq_len:
            chunk = token_buffer[:self.seq_len]
            del token_buffer[:self.seq_len]
            yield {
                "input_ids": torch.tensor(chunk, dtype=torch.long),
                "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
                "labels": torch.tensor(chunk, dtype=torch.long)
            }
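

# CsvLossLogger mirrors the Trainer's logged metrics (loss, learning rate, total FLOPs) into a
# plain CSV on rank 0, so the loss curve is recoverable even without TensorBoard or wandb.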
class CsvLossLogger(TrainerCallback):
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not is_main_process() or logs is None:
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="Local weights directory or HF model name (e.g. /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="Glob for local jsonl files (every machine needs the data at the same path)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="Local output directory (each node writes to its own local disk)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
                    help="Enable bf16 on Ampere or newer GPUs (3090/A100/H100); also enable it in the DeepSpeed config")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard",
                    choices=["none", "tensorboard", "wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    return ap.parse_args()
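

# Example multi-node launch (a sketch, assuming the standard DeepSpeed launcher with a hostfile;
# the data/output paths and step count below are placeholders to adjust). Note that with a
# streaming dataset the Trainer has no dataset length, so --max_steps generally has to be set.
#
#   deepspeed --hostfile=hostfile jd_train/train_sft_ds.py \
#       --model_name_or_path /home/test/Qwen3-8B \
#       --data_glob "/data/sft/*.jsonl" \
#       --output_dir /data/out/qwen3-sft \
#       --bf16 --gradient_checkpointing \
#       --deepspeed ds_config_zero3.json \
#       --max_steps 2000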
def main():
    args = parse_args()
    set_seed(args.seed)

    # Tokenizer / Model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    model.config.use_cache = False
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
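
    # Note (an assumption about HF + DeepSpeed ZeRO-3 behavior, not verified here): it is usually
    # recommended to create TrainingArguments (which carries the deepspeed config) before calling
    # from_pretrained so weights can be partitioned at load time; as written, each rank first loads
    # the full model into CPU memory and DeepSpeed shards it when the engine is initialized.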

    # Local data (each machine reads its own copy from the same path)
    files = sorted(glob.glob(args.data_glob))
    if len(files) == 0:
        raise FileNotFoundError(f"No files match: {args.data_glob}")
    dataset_iter = load_dataset(
        "json",
        data_files={"train": files},
        split="train",
        streaming=True
    )

    def text_iter():
        # Yield only non-empty "text" fields from the streamed jsonl records.
        for ex in dataset_iter:
            txt = ex.get("text", None)
            if isinstance(txt, str) and len(txt.strip()) > 0:
                yield txt

    train_stream = ConstantLengthDataset(texts_iter=text_iter(), tokenizer=tokenizer, seq_len=args.seq_len)
    eval_dataset = None
    if args.eval_ratio and args.eval_ratio > 0:
        # Quick eval set: take a few packed batches from the head of the training stream (disable if unwanted).
        # Note that these samples are not removed from the rebuilt training stream below.
        desired_eval_batches = 200
        gen = iter(train_stream)
        eval_samples = []
        for _ in range(desired_eval_batches):
            try:
                eval_samples.append(next(gen))
            except StopIteration:
                break

        class ListDataset(torch.utils.data.Dataset):
            def __init__(self, items):
                self.items = items

            def __len__(self):
                return len(self.items)

            def __getitem__(self, idx):
                return self.items[idx]

        eval_dataset = ListDataset(eval_samples)

        # Rebuild the training stream (the generator above was partially consumed).
        dataset_iter2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

        def text_iter2():
            for ex in dataset_iter2:
                txt = ex.get("text", None)
                if isinstance(txt, str) and len(txt.strip()) > 0:
                    yield txt

        train_stream = ConstantLengthDataset(texts_iter=text_iter2(), tokenizer=tokenizer, seq_len=args.seq_len)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")

    # Key point: without a shared filesystem, the HF/DeepSpeed sharded checkpoints are written under each
    # rank's local output_dir. On resume, each rank reads its own shards from the same local path
    # (just keep the world size unchanged between runs).
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        do_train=True,
        do_eval=(eval_dataset is not None),
        evaluation_strategy=("steps" if eval_dataset is not None else "no"),
        eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        deepspeed=args.deepspeed,
        dataloader_drop_last=True,
        dataloader_num_workers=2,  # the packing dataset does no worker sharding; lower to 0 if samples repeat
        report_to=([] if args.report_to == "none" else [args.report_to]),
        bf16=args.bf16,
        fp16=(not args.bf16),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        torch_compile=False,
        save_on_each_node=False  # keep the default; DeepSpeed's sharded checkpoints are already written locally by each rank
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # No shared filesystem: check whether this node's local output_dir already holds a checkpoint-* directory.
    ckpt_exists = (os.path.isdir(args.output_dir)
                   and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
    resume_flag = True if ckpt_exists else None
    print_once(f"Resume = {resume_flag is True}")

    print_once("***** Starting training *****")
    train_result = trainer.train(resume_from_checkpoint=resume_flag)
    # With the gather option enabled in the DS config (stage3_gather_16bit_weights_on_model_save),
    # this saves the full 16-bit model, consolidated on global rank 0.
    trainer.save_model()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if eval_dataset is not None:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)
    print_once("Done.")


if __name__ == "__main__":
    main()