#!/usr/bin/env python3
import os
import glob
import socket
import argparse
from typing import Iterable, List

import torch
from torch.utils.data import IterableDataset
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback


def is_main_process() -> bool:
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs):
    if is_main_process():
        print(*args, **kwargs, flush=True)


class ConstantLengthDataset(IterableDataset):
    """Packs a stream of texts into fixed-length token chunks for causal LM training."""

    def __init__(self, texts_iter: Iterable[str], tokenizer: AutoTokenizer,
                 seq_len: int = 4096, texts_per_batch: int = 1024):
        self.texts_iter = texts_iter
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # How many raw texts to accumulate before one batched tokenizer call.
        self.texts_per_batch = texts_per_batch

    def _flush(self, buffer_texts: List[str], token_buffer: List[int]) -> None:
        # Batch-tokenize pending texts; separate documents with an EOS token.
        enc = self.tokenizer(buffer_texts, add_special_tokens=False)["input_ids"]
        for ids in enc:
            token_buffer.extend(ids + [self.tokenizer.eos_token_id])
        buffer_texts.clear()

    def _emit(self, token_buffer: List[int]):
        # Yield as many full seq_len chunks as the token buffer currently holds.
        while len(token_buffer) >= self.seq_len:
            chunk = token_buffer[:self.seq_len]
            del token_buffer[:self.seq_len]
            yield {
                "input_ids": torch.tensor(chunk, dtype=torch.long),
                "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
                "labels": torch.tensor(chunk, dtype=torch.long),
            }

    def __iter__(self):
        buffer_texts: List[str] = []
        token_buffer: List[int] = []
        for txt in self.texts_iter:
            if not txt:
                continue
            buffer_texts.append(txt)
            if len(buffer_texts) >= self.texts_per_batch:
                self._flush(buffer_texts, token_buffer)
                yield from self._emit(token_buffer)
        # Flush whatever is left at end of stream.
        if buffer_texts:
            self._flush(buffer_texts, token_buffer)
            yield from self._emit(token_buffer)


class CsvLossLogger(TrainerCallback):
    """Appends one CSV row per logging event, on the main process only."""

    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not is_main_process() or logs is None:
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(
                f"{state.global_step},{logs.get('loss', '')},"
                f"{logs.get('learning_rate', '')},{logs.get('total_flos', '')}\n"
            )


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="Local weights directory or HF model name (e.g. /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="Glob over local jsonl files (every node needs the data at the same local path)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="Local output directory (each node writes to its own disk)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0,
                    help="If > 0, hold out a fixed number of chunks from the head of the stream for eval")
    ap.add_argument("--seed", type=int, default=1337)
    ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
help="3090/A100/H100 等可开 bf16;同时在 DS 配置里也要开") ap.add_argument("--per_device_train_batch_size", type=int, default=1) ap.add_argument("--gradient_accumulation_steps", type=int, default=64) ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none","tensorboard","wandb"]) ap.add_argument("--wandb_project", type=str, default="ds-qwen3") return ap.parse_args() def main(): args = parse_args() set_seed(args.seed) # Tokenizer/Model tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16), low_cpu_mem_usage=True, trust_remote_code=True ) model.config.use_cache = False if args.gradient_checkpointing: model.gradient_checkpointing_enable() # ===== 数据鲁棒性检查(多机各自执行)===== host = socket.gethostname() rank = int(os.environ.get("RANK", "0")) files = sorted(glob.glob(args.data_glob)) if len(files) == 0: raise FileNotFoundError( f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n" "每台机器都必须在相同本地路径下放置数据;" "可通过 DATA_GLOB= ./launch_ds.sh 覆写。" ) if is_main_process(): print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True) # streaming 逐行读取,字段名为 'text' dataset_iter = load_dataset( "json", data_files={"train": files}, split="train", streaming=True ) def text_iter(): for ex in dataset_iter: txt = ex.get("text", None) if isinstance(txt, str) and len(txt.strip()) > 0: yield txt # 先构造一次流,做“非空探针” train_stream_probe = ConstantLengthDataset(texts_iter=text_iter(), tokenizer=tokenizer, seq_len=args.seq_len) _probe = iter(train_stream_probe) try: _ = next(_probe) # 拉一个 chunk,确保真的能产出训练样本 except StopIteration: raise RuntimeError( f"[host={host} rank={rank}] 数据文件匹配到了,但没有产生任何可训练样本。\n" "常见原因:jsonl 缺少 'text' 字段、内容全为空/空白行、或 --seq_len 过大。\n" "请检查样例行,或将 --seq_len 调小后再试。" ) # 探针消耗了流,重新构造一次“干净”的训练流 dataset_iter2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def text_iter2(): for ex in dataset_iter2: txt = ex.get("text", None) if isinstance(txt, str) and len(txt.strip()) > 0: yield txt train_stream = ConstantLengthDataset(texts_iter=text_iter2(), tokenizer=tokenizer, seq_len=args.seq_len) # 可选 eval(从头部抽样) eval_dataset = None if args.eval_ratio and args.eval_ratio > 0: desired_eval_batches = 200 gen = iter(train_stream) eval_samples = [] for _ in range(desired_eval_batches): try: eval_samples.append(next(gen)) except StopIteration: break class ListDataset(torch.utils.data.Dataset): def __init__(self, items): self.items = items def __len__(self): return len(self.items) def __getitem__(self, idx): return self.items[idx] eval_dataset = ListDataset(eval_samples) # 抽样后再重建训练流,防止“吃掉”头部 dataset_iter3 = load_dataset("json", data_files={"train": files}, split="train", streaming=True) def text_iter3(): for ex in dataset_iter3: txt = ex.get("text", None) if isinstance(txt, str) and len(txt.strip()) > 0: yield txt train_stream = ConstantLengthDataset(texts_iter=text_iter3(), tokenizer=tokenizer, seq_len=args.seq_len) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) os.makedirs(args.output_dir, exist_ok=True) logging_dir = os.path.join(args.output_dir, "logs") # 无共享盘:各 rank 在各自本地 output_dir 下写入自己的分片 training_args = TrainingArguments( output_dir=args.output_dir, logging_dir=logging_dir, do_train=True, do_eval=(eval_dataset is not None), 
evaluation_strategy=("steps" if eval_dataset is not None else "no"), eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, weight_decay=args.weight_decay, warmup_ratio=args.warmup_ratio, num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0, max_steps=args.max_steps if args.max_steps > 0 else -1, lr_scheduler_type="cosine", logging_steps=args.log_interval, save_steps=args.save_steps, save_total_limit=2, deepspeed=args.deepspeed, dataloader_drop_last=True, dataloader_num_workers=2, report_to=([] if args.report_to == "none" else [args.report_to]), bf16=args.bf16, fp16=(not args.bf16), gradient_checkpointing=args.gradient_checkpointing, remove_unused_columns=False, torch_compile=False, save_on_each_node=False ) trainer = Trainer( model=model, args=training_args, train_dataset=train_stream, eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=data_collator ) trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv"))) # 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-* ckpt_exists = (os.path.isdir(args.output_dir) and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) resume_flag = True if ckpt_exists else None print_once(f"[host={host}] Resume = {resume_flag is True}") print_once("***** Starting training *****") train_result = trainer.train(resume_from_checkpoint=resume_flag) trainer.save_model() # 配合 DS 配置 stage3_gather_16bit_weights_on_model_save=true,仅在全局 rank0 聚合保存整模型 metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() if eval_dataset is not None: print_once("***** Running eval *****") eval_metrics = trainer.evaluate() trainer.log_metrics("eval", eval_metrics) trainer.save_metrics("eval", eval_metrics) print_once("Done.") if __name__ == "__main__": main()