249 lines
10 KiB
Python
249 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
import os
|
||
import glob
|
||
import argparse
|
||
from typing import Dict, List, Iterable, Iterator
|
||
|
||
import torch
|
||
from torch.utils.data import IterableDataset
|
||
|
||
from datasets import load_dataset
|
||
from transformers import (
|
||
AutoTokenizer,
|
||
AutoModelForCausalLM,
|
||
TrainingArguments,
|
||
Trainer,
|
||
DataCollatorForLanguageModeling,
|
||
set_seed
|
||
)
|
||
from transformers.trainer_callback import TrainerCallback
|
||
|
||
def is_main_process():
    """Return True when this process is global rank 0.

    The rank is taken from the ``RANK`` environment variable; if it is not set,
    the process is treated as rank 0.
    """
    rank = os.environ.get("RANK", "0")
    return int(rank) == 0
|
||
|
||
def print_once(*args, **kwargs):
    """Print only on the global main process (see ``is_main_process``).

    Takes the same arguments as the builtin ``print``. Output is flushed by
    default so it is not held back in the stdout buffer; a caller-supplied
    ``flush=`` is honoured instead of colliding with the default (which used
    to raise ``TypeError: got multiple values for keyword argument 'flush'``).
    """
    if not is_main_process():
        return
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)
|
||
|
||
class ConstantLengthDataset(IterableDataset):
    """Pack a stream of texts into fixed-length token blocks for causal-LM training.

    Texts are tokenized in batches of ``tokenize_batch_size`` (without special
    tokens). Every document is followed by ``tokenizer.eos_token_id``, and the
    resulting token stream is cut into consecutive, non-overlapping blocks of
    exactly ``seq_len`` tokens. Trailing tokens that do not fill a whole block
    are dropped.

    Each yielded example is a dict of 1-D ``torch.long`` tensors of length
    ``seq_len``: ``input_ids``, an all-ones ``attention_mask`` and ``labels``
    (equal to ``input_ids``).

    ``texts_iter`` is stored and iterated as-is. If it is a one-shot iterator
    (e.g. a generator), only the first iteration over this dataset within a
    given process yields data.
    """

    def __init__(self,
                 texts_iter: Iterable[str],
                 tokenizer: "AutoTokenizer",
                 seq_len: int = 4096,
                 buffer_size: int = 1024 * 1024,
                 tokenize_batch_size: int = 1024):
        """
        Args:
            texts_iter: Source of raw texts; falsy entries (``""``, ``None``) are skipped.
            tokenizer: HF-style tokenizer. When called on a list of texts with
                ``add_special_tokens=False``, it must return a mapping whose
                ``"input_ids"`` holds one list of ids per text. It must also
                expose ``eos_token_id``.
            seq_len: Number of tokens per yielded block.
            buffer_size: Unused; kept only for backward compatibility of the signature.
            tokenize_batch_size: Number of texts per tokenizer call (previously
                hard-coded to 1024). It only affects throughput, not the output.

        Raises:
            ValueError: If ``seq_len`` or ``tokenize_batch_size`` is not positive,
                or if the tokenizer has no ``eos_token_id``.
        """
        if seq_len <= 0:
            # seq_len == 0 used to make the packing loop spin forever.
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        if tokenize_batch_size <= 0:
            raise ValueError(f"tokenize_batch_size must be positive, got {tokenize_batch_size}")
        if getattr(tokenizer, "eos_token_id", None) is None:
            raise ValueError("tokenizer.eos_token_id is None; an EOS id is needed to separate documents")
        self.texts_iter = texts_iter
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.buffer_size = buffer_size
        self.tokenize_batch_size = tokenize_batch_size

    def _encode(self, texts: List[str]) -> List[int]:
        """Tokenize ``texts`` and concatenate their ids, with one EOS after each text."""
        eos_id = self.tokenizer.eos_token_id
        flat: List[int] = []
        for ids in self.tokenizer(texts, add_special_tokens=False)["input_ids"]:
            flat.extend(ids)
            flat.append(eos_id)
        return flat

    def _make_example(self, chunk: List[int]) -> Dict[str, torch.Tensor]:
        """Wrap one ``seq_len``-token chunk into the model input dict."""
        return {
            "input_ids": torch.tensor(chunk, dtype=torch.long),
            "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
            "labels": torch.tensor(chunk, dtype=torch.long),
        }

    def _emit_full_blocks(self, tokens: List[int]) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield every complete block in ``tokens``; the generator returns the leftover tail.

        Slicing by offset and keeping the tail once avoids the O(n^2) cost of
        repeatedly deleting the head of a large list.
        """
        n = self.seq_len
        stop = len(tokens) - len(tokens) % n
        for start in range(0, stop, n):
            yield self._make_example(tokens[start:start + n])
        return tokens[stop:]

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        pending: List[str] = []
        tokens: List[int] = []
        for txt in self.texts_iter:
            if not txt:
                continue
            pending.append(txt)
            if len(pending) >= self.tokenize_batch_size:
                tokens.extend(self._encode(pending))
                pending.clear()
                # `yield from` evaluates to the helper's return value: the unfinished tail.
                tokens = yield from self._emit_full_blocks(tokens)
        if pending:
            tokens.extend(self._encode(pending))
            yield from self._emit_full_blocks(tokens)
        # Any remaining tokens (< seq_len) are intentionally dropped.
|
||
|
||
class CsvLossLogger(TrainerCallback):
    """Trainer callback that writes logged training scalars to a CSV file.

    Only the global main process (``is_main_process()``) touches the file.
    Each qualifying ``on_log`` event becomes one row ``step,loss,lr,total_flos``;
    values missing from a log event are written as empty cells.
    """

    # Log keys written to the CSV, in column order after `step`.
    _FIELDS = ("loss", "learning_rate", "total_flos")

    def __init__(self, csv_path: str):
        """Create (or truncate) ``csv_path`` and write the header row.

        Missing parent directories are created. A bare filename with no
        directory part is written to the current working directory.
        """
        self.csv_path = csv_path
        if is_main_process():
            parent = os.path.dirname(csv_path)
            if parent:  # os.makedirs("") raises FileNotFoundError
                os.makedirs(parent, exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append one row per log event that carries at least one tracked key.

        Events without any tracked key (e.g. evaluation logs, which use
        ``eval_*`` names) are skipped instead of producing an all-empty row.
        """
        if not is_main_process() or logs is None:
            return
        if not any(key in logs for key in self._FIELDS):
            return
        row = [state.global_step] + [logs.get(key, "") for key in self._FIELDS]
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(",".join(str(v) for v in row) + "\n")
|
||
|
||
def parse_args():
    """Parse the command line for the training script.

    Returns:
        argparse.Namespace holding the model/data/output locations, the
        optimisation hyper-parameters, the DeepSpeed config path, precision
        and checkpointing switches, and the logging-backend settings.
    """
    parser = argparse.ArgumentParser()

    # Required locations (help texts are user-facing and kept verbatim).
    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="本地权重目录或 HF 名称(如 /home/test/Qwen3-8B)")
    parser.add_argument("--data_glob", type=str, required=True,
                        help="本地 jsonl 通配符(每台机器都需有同路径数据)")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="本地输出目录(各节点各自本地写)")

    # Plain typed options, registered in CLI/help order.
    # Note: --eval_ratio is only used as an on/off switch by main().
    typed_options = (
        ("--seq_len", int, 4096),
        ("--learning_rate", float, 2e-5),
        ("--weight_decay", float, 0.1),
        ("--warmup_ratio", float, 0.02),
        ("--num_train_epochs", float, 1.0),
        ("--max_steps", int, -1),
        ("--log_interval", int, 10),
        ("--save_steps", int, 500),
        ("--eval_ratio", float, 0.0),
        ("--seed", int, 1337),
        ("--deepspeed", str, "ds_config_zero3.json"),
    )
    for flag, value_type, default in typed_options:
        parser.add_argument(flag, type=value_type, default=default)

    # Boolean switches.
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--bf16", action="store_true",
                        help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开")

    # Batch sizing.
    for flag, default in (("--per_device_train_batch_size", 1),
                          ("--gradient_accumulation_steps", 64)):
        parser.add_argument(flag, type=int, default=default)

    # Logging backend.
    parser.add_argument("--report_to", type=str, default="tensorboard",
                        choices=["none", "tensorboard", "wandb"])
    parser.add_argument("--wandb_project", type=str, default="ds-qwen3")

    return parser.parse_args()
|
||
|
||
class _ListDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory list of already-built examples."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def _nonempty_texts(records) -> Iterator[str]:
    """Yield the ``text`` field of each record when it is a non-blank string."""
    for ex in records:
        txt = ex.get("text", None)
        if isinstance(txt, str) and txt.strip():
            yield txt


def _build_train_stream(files: List[str], tokenizer, seq_len: int) -> "ConstantLengthDataset":
    """Open ``files`` as a streaming JSON dataset and wrap it in a packing dataset.

    ``load_dataset`` is called here, eagerly. Iterating the records is lazy
    and only starts when the returned dataset is iterated.
    """
    records = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    return ConstantLengthDataset(texts_iter=_nonempty_texts(records), tokenizer=tokenizer, seq_len=seq_len)


def main():
    """Fine-tune a causal LM on local JSONL text with HF Trainer + DeepSpeed.

    Steps:
      1. Parse CLI args and validate them.
      2. Load the tokenizer and model.
      3. Build a packed streaming training set, plus an optional small eval set.
      4. Configure the Trainer and resume from the newest local ``checkpoint-<step>``
         directory if one exists.
      5. Train, save the final model, metrics and state, and optionally evaluate.

    Raises:
        ValueError: If ``--max_steps`` is not positive. The training set is a
            length-less IterableDataset, so Trainer cannot derive a step count.
        FileNotFoundError: If ``--data_glob`` matches no files.
    """
    import dataclasses
    import inspect
    from itertools import islice

    from transformers import default_data_collator
    from transformers.trainer_utils import get_last_checkpoint

    args = parse_args()
    # Trainer rejects max_steps <= 0 for a dataloader without a length. Fail here,
    # before spending time loading model weights.
    if args.max_steps <= 0:
        raise ValueError(
            f"--max_steps must be > 0 (got {args.max_steps}): the training data is a "
            "streaming IterableDataset without a length, so the number of steps "
            "cannot be derived from --num_train_epochs."
        )
    set_seed(args.seed)

    if args.report_to == "wandb":
        # The W&B integration reads the project name from this environment
        # variable; without it --wandb_project had no effect.
        os.environ["WANDB_PROJECT"] = args.wandb_project

    # Local data: every machine reads the same path from its own disk.
    files = sorted(glob.glob(args.data_glob))
    if not files:
        raise FileNotFoundError(f"No files match: {args.data_glob}")

    # Tokenizer / model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model.config.use_cache = False  # no generation KV cache during training
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    train_stream = _build_train_stream(files, tokenizer, args.seq_len)

    eval_dataset = None
    if args.eval_ratio > 0:
        # Simple eval set: --eval_ratio only acts as an on/off switch. The eval
        # set is the first `eval_num_samples` packed blocks of the data stream.
        eval_num_samples = 200
        eval_samples = list(islice(iter(train_stream), eval_num_samples))
        if eval_samples:
            eval_dataset = _ListDataset(eval_samples)
        else:
            print_once("Warning: not enough data to build an eval set; evaluation disabled.")
        # The first stream is partially consumed, so rebuild the training stream from the start.
        # NOTE(review): the eval blocks are therefore also seen during training.
        train_stream = _build_train_stream(files, tokenizer, args.seq_len)

    # Every block is exactly seq_len tokens long and already carries `labels`, so
    # the tensors only need stacking. DataCollatorForLanguageModeling(mlm=False)
    # would rebuild labels and set every pad_token_id position to -100. When
    # pad_token falls back to eos_token (above), that removes all
    # document-separator EOS tokens from the loss.
    data_collator = default_data_collator

    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")

    has_eval = eval_dataset is not None
    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases, and the old name was later removed. Use whichever this version has.
    ta_fields = {f.name for f in dataclasses.fields(TrainingArguments)}
    eval_strategy_key = "eval_strategy" if "eval_strategy" in ta_fields else "evaluation_strategy"

    # No shared storage: each node writes checkpoints into its own local output_dir.
    # Resuming requires the same world_size, so each rank finds its own shards.
    # save_on_each_node=True makes the local main process of *every* node also write
    # trainer_state.json and prune old checkpoints. With the default (False), only
    # global rank 0 does this, so the other nodes cannot resume at the right step.
    # NOTE(review): DeepSpeed locates shards through its `latest` tag file. With
    # node-local storage, the DS config presumably needs
    # `checkpoint.use_node_local_storage: true` so that file exists on every node;
    # verify this against the DS config.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        do_train=True,
        do_eval=has_eval,
        eval_steps=max(50, args.save_steps // 5) if has_eval else None,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,  # ignored: a positive max_steps takes precedence
        max_steps=args.max_steps,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        deepspeed=args.deepspeed,
        dataloader_drop_last=True,
        dataloader_num_workers=2,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        bf16=args.bf16,
        fp16=(not args.bf16),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        torch_compile=False,
        save_on_each_node=True,
        **{eval_strategy_key: ("steps" if has_eval else "no")},
    )

    # Newer transformers releases accept the tokenizer as `processing_class`
    # instead of `tokenizer`.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        **{tokenizer_key: tokenizer},
    )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # No shared storage: each node looks for checkpoint-<step> dirs in its own output_dir.
    # NOTE(review): all nodes must hold the same latest checkpoint, otherwise the
    # ranks would resume from different steps.
    last_checkpoint = get_last_checkpoint(args.output_dir)
    print_once(f"Resume = {last_checkpoint is not None}"
               + (f" ({last_checkpoint})" if last_checkpoint else ""))
    print_once("***** Starting training *****")
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    # If the DS config enables gathering of 16-bit weights on save, the full model
    # is gathered and written by the processes for which Trainer decides to save
    # (with save_on_each_node=True, presumably each node's local main process).
    trainer.save_model()

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if eval_dataset is not None:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    print_once("Done.")
|
||
|
||
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|