#!/usr/bin/env python3
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import glob
import socket
import argparse
import inspect
import sys
from typing import Dict, List, Iterable, Iterator, Tuple, Optional

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
# from contextlib import nullcontext

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint


# ----------------- Process utilities -----------------
def is_main_process():
    return int(os.environ.get("RANK", "0")) == 0


def print_once(*args, **kwargs):
    if is_main_process():
        print(*args, **kwargs, flush=True)


class DebugTrainer(Trainer):
    def training_step(self, model, inputs, num_items_in_batch=None):
        if not hasattr(self, "_dbg_printed"):
            rank = int(os.environ.get("RANK", "0"))
            host = socket.gethostname()
            ids = inputs["input_ids"]
            msk = inputs["attention_mask"]
            labs = inputs["labels"]
            print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
                  f"supervised={(labs != -100).sum().item()}", flush=True)
            print(
                f"[step0][host={host} RANK={rank}] "
                f"input_ids.shape={tuple(ids.shape)} "
                f"attention_mask.shape={tuple(msk.shape)} "
                f"labels.shape={tuple(labs.shape)} "
                f"num_items_in_batch={num_items_in_batch}",
                flush=True
            )
            self._dbg_printed = True
        return super().training_step(model, inputs, num_items_in_batch)


# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        if is_main_process():
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            with open(self.csv_path, "w", encoding="utf-8") as f:
                f.write("step,loss,lr,total_flos\n")

    # def on_log(self, args, state, control, logs=None, **kwargs):
    #     if not is_main_process() or logs is None:
    #         return
    #     with open(self.csv_path, "a", encoding="utf-8") as f:
    #         f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")

    def on_train_begin(self, args, state, control, **kwargs):
        tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
        tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
        rank = os.environ.get("RANK", "?")
        host = socket.gethostname()
        print(f"[{host} rank={rank}] total_steps={tot}", flush=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return

        # ---- Console output: every rank prints current step / total steps ----
        cur = int(getattr(state, "global_step", 0) or 0)
        # if getattr(args, "logging_steps", None) and cur % args.logging_steps != 0:
        #     return
        tmp = (getattr(state, "max_steps", 0) or getattr(args, "max_steps", 0) or 0)
        tot = tmp if isinstance(tmp, int) and tmp > 0 else 0
        pct = (f"{(cur / tot * 100):.1f}%" if tot else "n/a")

        # Once tot becomes available, announce the total step count again (printed only once)
        if tot and not hasattr(self, "_tot_announced"):
            print(f"[{socket.gethostname()} rank={os.environ.get('RANK','?')}] total_steps={tot}", flush=True)
            self._tot_announced = True

        # if not is_main_process():
        #     return
        rank = os.environ.get("RANK", "?")
        host = socket.gethostname()
        print(
            f"[{host} rank={rank}] step {cur}/{tot} ({pct}) "
            f"loss={logs.get('loss')} lr={logs.get('learning_rate')}",
            flush=True
        )

        # ---- Write the CSV only from the main process to avoid concurrent writes ----
        if not is_main_process():
            return
        with open(self.csv_path, "a", encoding="utf-8") as f:
            f.write(
                f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
            )

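# Offline use of the CSV written above (sketch only, not executed by this script; assumes
# pandas and matplotlib are installed -- neither is a requirement of the training code):
#
#   import pandas as pd
#   import matplotlib.pyplot as plt
#   df = pd.read_csv("<output_dir>/loss.csv")   # columns: step, loss, lr, total_flos
#   df.plot(x="step", y="loss")
#   plt.savefig("loss_curve.png")
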
# ----------------- Dataset that supervises only assistant tokens -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
    In the text rendered by apply_chat_template, return the character spans
    [start, end) of every assistant message's content.
    """
    spans: List[Tuple[int, int]] = []
    open_tag = "<|im_start|>assistant\n"
    close_tag = "<|im_end|>\n"
    pos = 0
    while True:
        a = rendered.find(open_tag, pos)
        if a == -1:
            break
        start = a + len(open_tag)
        b = rendered.find(close_tag, start)
        if b == -1:
            break
        spans.append((start, b))
        pos = b + len(close_tag)
    return spans


class QwenChatSFTDataset(IterableDataset):
    """
    Expects each jsonl line to look like:
      {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
    Optionally with tools:
      {"messages":[...], "tools":[{...}]}

    Workflow:
      - Render with tokenizer.apply_chat_template
      - Compute the loss only on assistant spans (label = -100 for all other tokens)
      - For over-long sequences keep the tail (which usually contains the answer)
    """
    def __init__(self, ex_iter: Iterable[dict], tokenizer: AutoTokenizer, seq_len: int = 4096):
        self.ex_iter = ex_iter
        self.tok = tokenizer
        self.seq_len = seq_len

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # >>> DEBUG BEGIN
        dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
        if not hasattr(self, "_dbg_seen"):
            self._dbg_seen = 0
        dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
        rank = int(os.environ.get("RANK", "0"))
        lrank = int(os.environ.get("LOCAL_RANK", "-1"))
        host = socket.gethostname()
        # >>> DEBUG END

        for ex in self.ex_iter:
            msgs = ex.get("messages", None)
            if not msgs or not isinstance(msgs, list):
                continue

            # Optional filter: drop samples whose assistant turns contain real <think> content
            bad = False
            for m in msgs:
                if m.get("role") == "assistant" and isinstance(m.get("content"), str):
                    c = m["content"]
                    if "<think>" in c and "</think>" in c:
                        inner = c.split("<think>")[-1].split("</think>")[0].strip()
                        if inner:
                            bad = True
                            break
            if bad:
                continue

            tools = ex.get("tools", None)
            # Compatibility with older tokenizer.apply_chat_template versions that do not accept a tools argument
            try:
                rendered: str = self.tok.apply_chat_template(
                    msgs, tools=tools, add_generation_prompt=False, tokenize=False
                )
            except TypeError:
                rendered: str = self.tok.apply_chat_template(
                    msgs, add_generation_prompt=False, tokenize=False
                )
            if not isinstance(rendered, str) or not rendered.strip():
                continue

            spans = _assistant_char_spans(rendered)
            if not spans:
                continue

            enc = self.tok(
                rendered,
                add_special_tokens=False,
                return_offsets_mapping=True
            )
            input_ids: List[int] = enc["input_ids"]
            offsets: List[Tuple[int, int]] = enc["offset_mapping"]
            if not input_ids:
                continue

            labels = [-100] * len(input_ids)

            def in_any_span(lo: int, hi: int) -> bool:
                for s, e in spans:
                    if not (hi <= s or lo >= e):
                        return True
                return False

            for i, (lo, hi) in enumerate(offsets):
                if in_any_span(lo, hi):
                    labels[i] = input_ids[i]

            # -- Fixed-length policy: truncate first, then pad to a fixed seq_len at the Dataset level --
            # 1) Truncate to seq_len (keep the tail)
            if len(input_ids) > self.seq_len:
                input_ids = input_ids[-self.seq_len:]
                labels = labels[-self.seq_len:]

            # 2) Left-pad to seq_len (so every sample has the same length)
            pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
            L = len(input_ids)
            if L < self.seq_len:
                pad = self.seq_len - L
                input_ids = ([pad_id] * pad) + input_ids
                labels = ([-100] * pad) + labels
                attn_mask = [0] * pad + [1] * L
            else:
                # Exactly seq_len
                attn_mask = [1] * self.seq_len

            # Skip samples with no trainable tokens at all (labels are all -100)
            if all(v == -100 for v in labels):
                continue

            assert len(input_ids) == self.seq_len
            assert len(labels) == self.seq_len
            assert len(attn_mask) == self.seq_len

            # >>> DEBUG PRINT (all variables are defined at this point)
            if dbg_on and self._dbg_seen < dbg_limit:
                sup_tok = sum(1 for v in labels if v != -100)
                print(
                    f"[sample][host={host} RANK={rank} LRank={lrank}] "
                    f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
                    f"seq_len={self.seq_len} pad_id={pad_id}",
                    flush=True
                )
                if sup_tok == 0:
                    print(
                        f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
                        flush=True
                    )
                self._dbg_seen += 1
            # <<< DEBUG PRINT

            yield {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
            }

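# Illustration of the span logic above (sketch only; the exact rendering depends on the
# model's chat template -- the markers shown follow the Qwen ChatML convention assumed by
# open_tag/close_tag in _assistant_char_spans):
#
#   msgs = [{"role": "user", "content": "hi"},
#           {"role": "assistant", "content": "hello"}]
#   rendered = tok.apply_chat_template(msgs, add_generation_prompt=False, tokenize=False)
#   # rendered ~ "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\nhello<|im_end|>\n"
#   # _assistant_char_spans(rendered) returns one span covering the characters of "hello",
#   # so only those tokens keep their ids as labels; every other position becomes -100.
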
# ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
class SFTDataCollator:
    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
        self.tok = tokenizer
        self.pad_to_length = pad_to_length
        assert self.tok.pad_token_id is not None

    def __call__(self, features):
        if not features:
            raise RuntimeError(
                f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
                f"Check dataset sharding/streaming."
            )

        def _to_list(x):
            return x.tolist() if isinstance(x, torch.Tensor) else list(x)

        input_ids = [_to_list(f["input_ids"]) for f in features]
        attn_masks = [_to_list(f["attention_mask"]) for f in features]
        labels_list = [_to_list(f["labels"]) for f in features]

        max_len_in_batch = max(len(x) for x in input_ids)
        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
        pad_id = self.tok.pad_token_id

        batch_inp, batch_attn, batch_lab = [], [], []
        for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
            pad_len = target_len - len(inp)
            if pad_len < 0:
                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
                pad_len = 0
            batch_inp.append(torch.tensor(inp + [pad_id] * pad_len, dtype=torch.long))
            batch_attn.append(torch.tensor(msk + [0] * pad_len, dtype=torch.long))
            batch_lab.append(torch.tensor(lab + [-100] * pad_len, dtype=torch.long))

        # >>> DEBUG BEGIN
        dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
        if dbg_on:
            rank = int(os.environ.get("RANK", "0"))
            host = socket.gethostname()
            bs = len(features)
            first_len = len(input_ids[0]) if bs > 0 else None
            print(
                f"[collate][host={host} RANK={rank}] features={bs} "
                f"target_len={target_len} first_len={first_len}",
                flush=True
            )

        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
            "labels": torch.stack(batch_lab, dim=0),
        }

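# Toy behaviour of the collator above (sketch; the token ids are made up). With
# pad_to_length=None it right-pads every field to the longest sample in the batch:
#
#   features = [
#       {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [-100, 6, 7]},
#       {"input_ids": [5],       "attention_mask": [1],       "labels": [-100]},
#   ]
#   batch = SFTDataCollator(tokenizer)(features)
#   # batch["input_ids"][1]      -> [5, pad_id, pad_id]
#   # batch["attention_mask"][1] -> [1, 0, 0]
#   # batch["labels"][1]         -> [-100, -100, -100]
#
# In this script pad_to_length=args.seq_len and the dataset already left-pads every sample
# to seq_len, so the collator normally only stacks tensors.
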
# ----------------- Arguments -----------------
def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="Local weights directory or HF name (e.g. /home/test/Qwen3-8B)")
    ap.add_argument("--data_glob", type=str, required=True,
                    help="Local jsonl glob (every machine must have the data at the same path; "
                         "each line should contain messages and optionally tools)")
    ap.add_argument("--output_dir", type=str, required=True,
                    help="Local output directory (each node writes to its own local disk)")
    ap.add_argument("--seq_len", type=int, default=4096)
    ap.add_argument("--learning_rate", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.1)
    ap.add_argument("--warmup_ratio", type=float, default=0.02)
    ap.add_argument("--num_train_epochs", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--log_interval", type=int, default=10)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--eval_ratio", type=float, default=0.0)
    ap.add_argument("--seed", type=int, default=1337)
    # ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--bf16", action="store_true",
                    help="Enable bf16 on 3090/A100/H100 and similar GPUs; it must also be enabled in the DS config")
    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
    ap.add_argument("--report_to", type=str, default="tensorboard", choices=["none", "tensorboard", "wandb"])
    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
    ap.add_argument("--eval_data_glob", type=str, default=None,
                    help="(Optional) test/validation jsonl glob; used preferentially when provided")
    ap.add_argument("--local_rank", type=int, default=-1,
                    help="for deepspeed/torchrun launcher; ignored by user code")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
    ap.add_argument("--deepspeed", type=str, default=None)
    return ap.parse_args()

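# Example launch (sketch only; the script file name, node/GPU counts and paths below are
# placeholders, and the DeepSpeed config file is optional):
#
#   torchrun --nnodes=2 --nproc_per_node=8 \
#       --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=29500 \
#       train_qwen_sft.py \
#       --model_name_or_path /home/test/Qwen3-8B \
#       --data_glob "/data/sft/*.jsonl" \
#       --output_dir ./out \
#       --bf16 --gradient_checkpointing \
#       --deepspeed ds_config_zero3.json
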
# ----------------- Main -----------------
def main():
    args = parse_args()
    set_seed(args.seed)
    host = socket.gethostname()

    def dbg(msg):
        print(
            f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
            f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
            flush=True
        )

    # Whether DeepSpeed is actually enabled (a config file was passed and exists)
    use_ds = bool(args.deepspeed and os.path.isfile(args.deepspeed))
    dschf = None
    if use_ds:
        try:
            from transformers.integrations.deepspeed import HfDeepSpeedConfig
            src = "transformers.integrations.deepspeed"
        except Exception:
            try:
                # Fallback: some versions expose it directly from transformers
                from transformers import HfDeepSpeedConfig
                src = "transformers"
            except Exception as e:
                raise RuntimeError(
                    "This transformers version does not provide HfDeepSpeedConfig; "
                    "please upgrade or downgrade transformers") from e
        dschf = HfDeepSpeedConfig(args.deepspeed)
        dbg(f"HfDeepSpeedConfig loaded from {src}")

    if args.report_to == "wandb":
        os.environ.setdefault("WANDB_PROJECT", args.wandb_project)

    # -------- Debug printing (every rank prints) --------
    # host = socket.gethostname()

    # Versions & launch args & key environment variables
    import transformers as hf
    try:
        import deepspeed as ds
        ds_ver = ds.__version__
    except Exception:
        ds_ver = "n/a"
    dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
    dbg(f"args={args}")
    dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
        os.environ.get("WORLD_SIZE"), os.environ.get("RANK"),
        os.environ.get("LOCAL_RANK", str(args.local_rank)),
        os.environ.get("MASTER_ADDR"), os.environ.get("MASTER_PORT"),
        os.environ.get("CUDA_VISIBLE_DEVICES"),
    ))
    dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")

    # ---- Initialize distributed (used by the consistency probes) ----
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
    dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")

    if torch.cuda.is_available() and local_rank >= 0:
        torch.cuda.set_device(local_rank)
        dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
            f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        dbg("no cuda or invalid local_rank; not calling set_device")

    if world_size > 1 and dist.is_available() and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dbg(f"init_process_group backend={backend} via env://")
        dist.init_process_group(backend=backend, init_method="env://")
    else:
        dbg(f"skip init_process_group: world_size>1? {world_size > 1}, "
            f"dist_available={dist.is_available()}, already_init={dist.is_initialized()}")

    if dist.is_available() and dist.is_initialized():
        try:
            dbg(f"dist.get_backend()={dist.get_backend()} "
                f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
        except Exception as e:
            dbg(f"dist query error: {e}")

    # 1) Fix up the tokenizer's pad token first
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Left padding to match the Dataset's left-pad policy
    try:
        if getattr(tokenizer, "padding_side", None) != "left":
            tokenizer.padding_side = "left"
    except Exception:
        pass

    # Require a fast tokenizer (offset_mapping depends on it)
    from transformers import PreTrainedTokenizerFast
    if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
        raise RuntimeError("A *fast* tokenizer is required to obtain offset_mapping; "
                           "install tokenizers>=0.14 and use the corresponding Fast tokenizer.")

    tokenizer.model_max_length = args.seq_len
    dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
        f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")

    # 2) Determine the dtype before loading the model
    def _bf16_supported():
        if not torch.cuda.is_available():
            return False
        # Compatibility across torch versions: prefer the API, fall back to compute capability
        if hasattr(torch.cuda, "is_bf16_supported"):
            return torch.cuda.is_bf16_supported()
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 0)  # Ampere or newer

    use_bf16 = bool(args.bf16 and _bf16_supported())
    dtype = (torch.bfloat16 if use_bf16
             else (torch.float16 if torch.cuda.is_available() else torch.float32))

    # dschf = None
    # if args.deepspeed and os.path.isfile(args.deepspeed):
    #     dschf = HfDeepSpeedConfig(args.deepspeed)  # <- key: enable the plugin early
    #     dbg("HfDeepSpeedConfig loaded")

    # try:
    #     import deepspeed
    #     zero_init_ctx = deepspeed.zero.Init(
    #         remote_device="cpu",   # parameters ultimately hosted on CPU (can be combined with offload)
    #         device="cpu",          # <- key: do not use meta
    #         pin_memory=True,
    #         dtype=dtype,
    #         config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
    #     )
    # except Exception:
    #     zero_init_ctx = nullcontext()  # also runs single-node without DS installed
    # with zero_init_ctx:
    #     model = AutoModelForCausalLM.from_pretrained(
    #         args.model_name_or_path,
    #         torch_dtype=dtype,
    #         low_cpu_mem_usage=False,
    #         trust_remote_code=True,
    #         attn_implementation="sdpa"
    #     )

    # Let the DeepSpeed plugin handle ZeRO-Init / sharded loading
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )

    # model = AutoModelForCausalLM.from_pretrained(
    #     args.model_name_or_path,
    #     torch_dtype=dtype,
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    #     attn_implementation="sdpa"
    # )

    print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)

    dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
        f"use_cache={getattr(model.config,'use_cache',None)} "
        f"pad_token_id={getattr(model.config,'pad_token_id',None)}")

    # 3) pad/alibi and related config
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    try:
        # torch.backends.cuda.enable_flash_sdp(False)
        # torch.backends.cuda.enable_mem_efficient_sdp(False)
        # torch.backends.cuda.enable_math_sdp(True)
        # Let PyTorch choose on its own, or explicitly enable the efficient implementations (either works):
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    except Exception:
        pass

    # ===== Data robustness checks (each machine runs its own) =====
    # host = socket.gethostname()
    files = sorted(glob.glob(args.data_glob))
    if len(files) == 0:
        raise FileNotFoundError(
            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
            "Every machine must place the data at the same local path; "
            "it can be overridden via DATA_GLOB= ./run_ds.sh."
        )
    if is_main_process():
        print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)

    # ====== Small probe: sample structure ======
    ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

    def ex_iter_probe():
        for ex in ds_stream_probe:
            yield ex

    train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
    try:
        _ = next(iter(train_stream_probe))
    except StopIteration:
        raise RuntimeError(
            f"[host={host} rank={rank}] Data files were matched but produced no trainable samples.\n"
            "Make sure every JSON line contains at least a 'messages' field (a list with user/assistant turns); "
            "if <think> content is present, make sure it contains no real reasoning text, or remove it.\n"
            "Also check whether seq_len is too small, causing everything to be cut away."
        )

    # # ====== Actual training stream ======
    # ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    # if world_size > 1 and len(files) >= world_size:
    #     # Multiple files: shard contiguously by file
    #     ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
    #     train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
    # else:
    #     # A single file, or too few files: round-robin by sample index
    #     def ex_iter2():
    #         for i, ex in enumerate(ds_stream2):
    #             if i % max(world_size, 1) == rank:
    #                 yield ex
    #     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

    # ====== Actual training stream (no manual sharding; leave it to Accelerate/Trainer) ======
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)

    # # ====== Consistency probe (same logic as above) ======
    # ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    # if world_size > 1 and len(files) >= world_size:
    #     ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
    #     probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
    # else:
    #     def ex_iter2_probe():
    #         for i, ex in enumerate(ds_stream_probe2):
    #             if i % max(world_size, 1) == rank:
    #                 yield ex
    #     probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)

    # ====== Consistency probe (no sharding) ======
    ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)

    def has_at_least(stream, n: int):
        it = iter(stream)
        for _ in range(n):
            try:
                next(it)
            except StopIteration:
                return 0
        return 1

    need = max(1, args.gradient_accumulation_steps)
    local_ok = has_at_least(probe_stream, need)

    if dist.is_available() and dist.is_initialized():
        t = torch.tensor(local_ok, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        if t.item() == 0:
            if is_main_process():
                print(
                    f"[FATAL] At least one rank cannot supply {need} micro-batches within one optimizer step (GA={need}). "
                    f"Reduce GA or enlarge/clean the data; this run will not start.",
                    flush=True
                )
            dist.barrier()
            sys.exit(2)
    else:
        if local_ok == 0:
            print(
                f"[FATAL] This machine cannot supply {need} micro-batches within one optimizer step (GA={need}). "
                f"Reduce GA or enlarge/clean the data; this run will not start.",
                flush=True
            )
            sys.exit(2)

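    # Arithmetic behind the probe above (illustrative numbers only, not values enforced here):
    # with per_device_train_batch_size=1, one optimizer step needs gradient_accumulation_steps
    # micro-batches per rank, so e.g. GA=64 on 16 ranks consumes 16 * 1 * 64 = 1024 samples
    # per step. has_at_least() only verifies that the stream can yield at least GA samples
    # before training is allowed to start.
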
    # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling otherwise ----
    eval_dataset: Optional[Dataset] = None

    class ListDataset(Dataset):
        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    if args.eval_data_glob:
        eval_files = sorted(glob.glob(args.eval_data_glob))
        if len(eval_files) == 0:
            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
        if is_main_process():
            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)

        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)

        def ex_iter_eval():
            for ex in ds_eval_stream:
                yield ex

        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
        eval_items: List[Dict[str, torch.Tensor]] = []
        for sample in eval_iterable:
            eval_items.append(sample)
        if len(eval_items) == 0:
            raise RuntimeError("[eval] eval_data_glob produced 0 valid samples; please check the messages structure.")
        eval_dataset = ListDataset(eval_items)

    elif args.eval_ratio and args.eval_ratio > 0:
        desired_eval_batches = 200
        tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)

        def ex_iter_eval2():
            for ex in tmp_stream:
                yield ex

        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
        eval_samples = []
        it = iter(eval_stream)
        for _ in range(desired_eval_batches):
            try:
                eval_samples.append(next(it))
            except StopIteration:
                break
        if len(eval_samples) > 0:
            eval_dataset = ListDataset(eval_samples)

    # ---- Pad the eval set uniformly (so empty batches cannot occur) ----
    if eval_dataset is not None:
        ws = max(world_size, 1)
        be = max(1, args.per_device_eval_batch_size)
        global_bs = ws * be
        r = len(eval_dataset) % global_bs
        if r != 0:
            pad_need = global_bs - r
            eval_dataset.items += eval_dataset.items[:pad_need]
            if is_main_process():
                print(f"[eval] padded eval set to {len(eval_dataset)} "
                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})", flush=True)
        # Sanity check after padding
        assert len(eval_dataset) % global_bs == 0, \
            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"

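    # Padding arithmetic for the block above (illustrative numbers only): with 16 ranks and
    # per_device_eval_batch_size=1, global_bs = 16. An eval set of 200 samples leaves a
    # remainder of 200 % 16 = 8, so 8 samples are duplicated from the front to reach 208.
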
    # Sturdier during bring-up: no forced padding to 4096 (note: pad_to_length is in fact set to args.seq_len here)
    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)

    os.makedirs(args.output_dir, exist_ok=True)
    logging_dir = os.path.join(args.output_dir, "logs")
    os.makedirs(logging_dir, exist_ok=True)

    # ---- Compatibility with 4.51 (eval_strategy) and older versions (evaluation_strategy) ----
    ta_kwargs = {}
    sig = inspect.signature(TrainingArguments.__init__).parameters
    if eval_dataset is not None:
        if "eval_strategy" in sig:
            ta_kwargs["eval_strategy"] = "steps"
        elif "evaluation_strategy" in sig:
            ta_kwargs["evaluation_strategy"] = "steps"
    else:
        if "eval_strategy" in sig:
            ta_kwargs["eval_strategy"] = "no"
        elif "evaluation_strategy" in sig:
            ta_kwargs["evaluation_strategy"] = "no"

    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
    ta_kwargs2 = dict(
        output_dir=args.output_dir,
        logging_dir=logging_dir,
        do_train=True,
        do_eval=(eval_dataset is not None),
        eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        lr_scheduler_type="cosine",
        logging_steps=args.log_interval,
        save_steps=args.save_steps,
        save_total_limit=2,
        # deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
        deepspeed=(args.deepspeed if use_ds else None),
        dataloader_drop_last=False,
        dataloader_num_workers=0,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        report_to=([] if args.report_to == "none" else [args.report_to]),
        bf16=args.bf16,
        fp16=(not args.bf16),
        gradient_checkpointing=args.gradient_checkpointing,
        remove_unused_columns=False,
        save_on_each_node=True,
        logging_first_step=True,
        **ta_kwargs,  # the eval_strategy compatibility kwargs built above
    )
    # if "dataloader_prefetch_factor" in ta_sig:
    #     ta_kwargs2["dataloader_prefetch_factor"] = None
    if "dataloader_pin_memory" in ta_sig:
        ta_kwargs2["dataloader_pin_memory"] = False
    if "torch_compile" in ta_sig:
        ta_kwargs2["torch_compile"] = False

    # Before constructing TrainingArguments, reuse the use_bf16 decision made above
    ta_kwargs2.update({
        "bf16": use_bf16,
        "fp16": (torch.cuda.is_available() and not use_bf16),
    })

    training_args = TrainingArguments(**ta_kwargs2)

    trainer_kwargs = {}
    if "processing_class" in inspect.signature(Trainer.__init__).parameters:
        trainer_kwargs["processing_class"] = tokenizer
    else:
        trainer_kwargs["tokenizer"] = tokenizer

    trainer = DebugTrainer(
        model=model,
        args=training_args,
        train_dataset=train_stream,
        eval_dataset=eval_dataset,
        # tokenizer=tokenizer,
        # processing_class=tokenizer,
        data_collator=data_collator,
        **trainer_kwargs,
    )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

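    # Resume logic below, illustrated (hypothetical step numbers): checkpoints live under
    # output_dir/checkpoint-<step> on each node's local disk. If one rank has checkpoint-1000
    # but another only has checkpoint-500, the common minimum step 500 is chosen so every rank
    # can load the same checkpoint; if any rank has no checkpoint at all, resume is disabled.
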
    # ==== Resume decision (safe for non-shared disks) ====
    def last_step(path: str) -> int:
        ck = get_last_checkpoint(path)
        if ck is None:
            return -1
        base = os.path.basename(ck)
        try:
            return int(base.split("-")[-1])
        except Exception:
            return -1

    local_last = last_step(args.output_dir)  # -1 means this machine has no checkpoint at all
    device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")

    resume_flag = None
    if dist.is_available() and dist.is_initialized():
        # If any rank has no ckpt -> do not resume
        has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
        dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
        if has_local.item() == 1:
            # Everyone has one: gather each rank's last step and take the common minimum step k
            # (guaranteed to exist on every machine)
            ts = torch.tensor(local_last, device=device)
            world = dist.get_world_size()
            buf = [torch.zeros_like(ts) for _ in range(world)]
            dist.all_gather(buf, ts)
            steps = [b.item() for b in buf]
            k = min(steps)
            if k >= 0:
                resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
            if is_main_process():
                print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
    else:
        # Single machine, or distributed not initialized: resume from the local last step if present
        if local_last >= 0:
            resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")

    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")

    # -- Global consistency check: if any rank is missing this ckpt, disable resume --
    if dist.is_available() and dist.is_initialized():
        device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
        present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
        dist.all_reduce(present, op=dist.ReduceOp.MIN)
        if present.item() == 0:
            if is_main_process():
                print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
            resume_flag = None
        dist.barrier()
    else:
        if resume_flag is not None and not os.path.isdir(resume_flag):
            # Single machine: if it is missing, simply disable resume
            print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
            resume_flag = None

    print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")

    print_once("***** Starting training *****")
    dbg(f"allocated={torch.cuda.memory_allocated()/1024**2:.1f} MB, "
        f"reserved={torch.cuda.memory_reserved()/1024**2:.1f} MB")
    train_result = trainer.train(resume_from_checkpoint=resume_flag)

    # With DeepSpeed and stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank 0
    trainer.save_model()

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if eval_dataset is not None:
        print_once("***** Running eval *****")
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    print_once("Done.")


if __name__ == "__main__":
    main()
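
# Example DeepSpeed config for --deepspeed (sketch only; the file name ds_config_zero3.json
# and all values below are illustrative and must be tuned to the cluster, they are not
# shipped with this script):
#
#   {
#     "train_micro_batch_size_per_gpu": "auto",
#     "gradient_accumulation_steps": "auto",
#     "bf16": { "enabled": "auto" },
#     "zero_optimization": {
#       "stage": 3,
#       "overlap_comm": true,
#       "stage3_gather_16bit_weights_on_model_save": true
#     }
#   }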