From f3ee923d611561ef6a0fee3948f5c2d9f562062d Mon Sep 17 00:00:00 2001
From: hailin
Date: Tue, 26 Aug 2025 14:49:06 +0800
Subject: [PATCH] .

---
 ds_config_zero3.json                      |   1 -
 train_sft_ds_single_gpu_single_node.py.ok | 657 ++++++++++++++++++++++
 2 files changed, 657 insertions(+), 1 deletion(-)
 create mode 100644 train_sft_ds_single_gpu_single_node.py.ok

diff --git a/ds_config_zero3.json b/ds_config_zero3.json
index a81f81c..5961ae9 100644
--- a/ds_config_zero3.json
+++ b/ds_config_zero3.json
@@ -1,6 +1,5 @@
 {
   "train_micro_batch_size_per_gpu": 1,
-  "gradient_accumulation_steps": 64,
   "steps_per_print": 0,
   "gradient_clipping": 1.0,

diff --git a/train_sft_ds_single_gpu_single_node.py.ok b/train_sft_ds_single_gpu_single_node.py.ok
new file mode 100644
index 0000000..0bcf065
--- /dev/null
+++ b/train_sft_ds_single_gpu_single_node.py.ok
@@ -0,0 +1,657 @@
+#!/usr/bin/env python3
+import os
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+import glob
+import socket
+import argparse
+import inspect
+import sys
+from typing import Dict, List, Iterable, Iterator, Tuple, Optional
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import IterableDataset, Dataset
+
+from datasets import load_dataset
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    Trainer,
+    set_seed
+)
+from transformers.trainer_callback import TrainerCallback
+
+
+# ----------------- Process utilities -----------------
+def is_main_process():
+    return int(os.environ.get("RANK", "0")) == 0
+
+def print_once(*args, **kwargs):
+    if is_main_process():
+        print(*args, **kwargs, flush=True)
+
+class DebugTrainer(Trainer):
+    def training_step(self, model, inputs, num_items_in_batch=None):
+        if not hasattr(self, "_dbg_printed"):
+            rank = int(os.environ.get("RANK", "0"))
+            host = socket.gethostname()
+            ids = inputs["input_ids"]
+            msk = inputs["attention_mask"]
+            labs = inputs["labels"]
+            print(f"[step0] ids={ids.device} mask={msk.device} labs={labs.device} "
+                  f"supervised={(labs!=-100).sum().item()}",
+                  flush=True)
+            print(
+                f"[step0][host={host} RANK={rank}] "
+                f"input_ids.shape={tuple(ids.shape)} "
+                f"attention_mask.shape={tuple(msk.shape)} "
+                f"labels.shape={tuple(labs.shape)} "
+                f"num_items_in_batch={num_items_in_batch}",
+                flush=True
+            )
+            self._dbg_printed = True
+        return super().training_step(model, inputs, num_items_in_batch)
+
+
+# ----------------- Logging callback -----------------
+class CsvLossLogger(TrainerCallback):
+    def __init__(self, csv_path: str):
+        self.csv_path = csv_path
+        if is_main_process():
+            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
+            with open(self.csv_path, "w", encoding="utf-8") as f:
+                f.write("step,loss,lr,total_flos\n")
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if not is_main_process() or logs is None:
+            return
+        with open(self.csv_path, "a", encoding="utf-8") as f:
+            f.write(f"{state.global_step},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")
+
+
+# ----------------- Dataset that supervises only assistant tokens -----------------
+def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
+    """
+    Return the character spans [start, end) of every assistant message in the
+    text rendered by apply_chat_template.
+    """
+    spans: List[Tuple[int, int]] = []
+    open_tag = "<|im_start|>assistant\n"
+    close_tag = "<|im_end|>\n"
+    pos = 0
+    while True:
+        a = rendered.find(open_tag, pos)
+        if a == -1:
+            break
+        start = a + len(open_tag)
+        b = rendered.find(close_tag, start)
+        if b == -1:
+            break
+        spans.append((start, b))
+        pos = b + len(close_tag)
+    return spans
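+
+# A minimal illustration (comment only, not executed) of what the helper above
+# returns, assuming a Qwen ChatML rendering like the one apply_chat_template
+# produces:
+#
+#   rendered = ("<|im_start|>user\nHi<|im_end|>\n"
+#               "<|im_start|>assistant\nHello!<|im_end|>\n")
+#   _assistant_char_spans(rendered)   # -> [(52, 58)]
+#
+# i.e. rendered[52:58] == "Hello!"; only those characters (and hence only the
+# tokens overlapping them) receive labels below.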
+
+
+class QwenChatSFTDataset(IterableDataset):
+    """
+    Expects each jsonl line to look like:
+      {"messages":[{"role":"system","content":"..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
+    Optionally with tools:
+      {"messages":[...], "tools":[{...}]}
+
+    Workflow:
+    - render with tokenizer.apply_chat_template
+    - compute loss only on assistant spans (label = -100 for all other tokens)
+    - for overlong sequences, keep the tail (which usually contains the answer)
+    """
+    def __init__(self,
+                 ex_iter: Iterable[dict],
+                 tokenizer: AutoTokenizer,
+                 seq_len: int = 4096):
+        self.ex_iter = ex_iter
+        self.tok = tokenizer
+        self.seq_len = seq_len
+
+    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+
+        # >>> DEBUG BEGIN
+        dbg_on = os.environ.get("DBG_SAMPLES", "0") == "1"
+        if not hasattr(self, "_dbg_seen"): self._dbg_seen = 0
+        dbg_limit = int(os.environ.get("DBG_SAMPLE_LIMIT", "3"))
+        rank = int(os.environ.get("RANK", "0"))
+        lrank = int(os.environ.get("LOCAL_RANK", "-1"))
+        host = socket.gethostname()
+        # >>> DEBUG END
+
+        for ex in self.ex_iter:
+            msgs = ex.get("messages", None)
+            if not msgs or not isinstance(msgs, list):
+                continue
+
+            # Optional filter: drop samples whose assistant turns carry real <think> content
+            bad = False
+            for m in msgs:
+                if m.get("role") == "assistant" and isinstance(m.get("content"), str):
+                    c = m["content"]
+                    if "<think>" in c and "</think>" in c:
+                        inner = c.split("<think>")[-1].split("</think>")[0].strip()
+                        if inner:
+                            bad = True; break
+            if bad:
+                continue
+
+            tools = ex.get("tools", None)
+
+            rendered: str = self.tok.apply_chat_template(
+                msgs, tools=tools, add_generation_prompt=False, tokenize=False
+            )
+            if not isinstance(rendered, str) or not rendered.strip():
+                continue
+
+            spans = _assistant_char_spans(rendered)
+            if not spans:
+                continue
+
+            enc = self.tok(
+                rendered,
+                add_special_tokens=False,
+                return_offsets_mapping=True
+            )
+            input_ids: List[int] = enc["input_ids"]
+            offsets: List[Tuple[int, int]] = enc["offset_mapping"]
+
+            if not input_ids:
+                continue
+
+            labels = [-100] * len(input_ids)
+
+            def in_any_span(lo: int, hi: int) -> bool:
+                for s, e in spans:
+                    if not (hi <= s or lo >= e):
+                        return True
+                return False
+
+            for i, (lo, hi) in enumerate(offsets):
+                if in_any_span(lo, hi):
+                    labels[i] = input_ids[i]
+
+            # -- Fixed-length policy: truncate first, then pad to the fixed seq_len at the Dataset level --
+            # 1) truncate to seq_len (keep the tail)
+            if len(input_ids) > self.seq_len:
+                input_ids = input_ids[-self.seq_len:]
+                labels = labels[-self.seq_len:]
+
+            # 2) left-pad to seq_len (so every sample has the same length)
+            pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
+            L = len(input_ids)
+            if L < self.seq_len:
+                pad = self.seq_len - L
+                input_ids = ([pad_id] * pad) + input_ids
+                labels = ([-100] * pad) + labels
+                attn_mask = [0] * pad + [1] * L
+            else:
+                # exactly seq_len
+                attn_mask = [1] * self.seq_len
+
+            # skip if there is no trainable token at all (labels are all -100)
+            if all(v == -100 for v in labels):
+                continue
+
+            assert len(input_ids) == self.seq_len
+            assert len(labels) == self.seq_len
+            assert len(attn_mask) == self.seq_len
+
+            # >>> DEBUG PRINT (all variables are defined by this point)
+            if dbg_on and self._dbg_seen < dbg_limit:
+                sup_tok = sum(1 for v in labels if v != -100)
+                print(
+                    f"[sample][host={host} RANK={rank} LRank={lrank}] "
+                    f"rendered_len={len(rendered)} toks={len(input_ids)} sup_toks={sup_tok} "
+                    f"seq_len={self.seq_len} pad_id={pad_id}",
+                    flush=True
+                )
+                if sup_tok == 0:
+                    print(
+                        f"[WARN][host={host} RANK={rank}] sample has 0 supervised tokens -> would be skipped",
+                        flush=True
+                    )
+                self._dbg_seen += 1
+            # <<< DEBUG PRINT
+
+            yield {
+                "input_ids": torch.tensor(input_ids, dtype=torch.long),
+                "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
+                "labels": torch.tensor(labels, dtype=torch.long),
+            }
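+
+# A minimal usage sketch of the dataset above (comment only; the model path
+# and the toy example are placeholders, not part of this script):
+#
+#   tok = AutoTokenizer.from_pretrained("/path/to/qwen", trust_remote_code=True)
+#   examples = [{"messages": [
+#       {"role": "user", "content": "2+2?"},
+#       {"role": "assistant", "content": "4"},
+#   ]}]
+#   ds = QwenChatSFTDataset(examples, tok, seq_len=64)
+#   sample = next(iter(ds))
+#   # sample["labels"] is -100 everywhere except on the assistant tokens, and
+#   # all three tensors are left-padded to exactly seq_len=64.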
+
+
+# ----------------- Dedicated collator: pad inputs with pad_id, labels with -100 -----------------
+class SFTDataCollator:
+    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
+        self.tok = tokenizer
+        self.pad_to_length = pad_to_length
+        assert self.tok.pad_token_id is not None
+
+    def __call__(self, features):
+        # if not features:
+        #     raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
+        #                        f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
+
+        def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
+        input_ids = [_to_list(f["input_ids"]) for f in features]
+        attn_masks = [_to_list(f["attention_mask"]) for f in features]
+        labels_list = [_to_list(f["labels"]) for f in features]
+
+        max_len_in_batch = max(len(x) for x in input_ids)
+        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
+
+        pad_id = self.tok.pad_token_id
+        batch_inp, batch_attn, batch_lab = [], [], []
+        for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
+            pad_len = target_len - len(inp)
+            if pad_len < 0:
+                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
+                pad_len = 0
+            batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
+            batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
+            batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
+
+        # >>> DEBUG BEGIN
+        dbg_on = os.environ.get("DBG_COLLATE", "0") == "1"
+        if dbg_on:
+            rank = int(os.environ.get("RANK", "0"))
+            host = socket.gethostname()
+            bs = len(features)
+            first_len = len(input_ids[0]) if bs > 0 else None
+            print(
+                f"[collate][host={host} RANK={rank}] features={bs} "
+                f"target_len={target_len} first_len={first_len}",
+                flush=True
+            )
+            # Extra strict check: keep an empty batch from travelling any further
+            if not features:
+                raise RuntimeError(
+                    f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
+                    f"Check dataset sharding/streaming."
+                )
+        # >>> DEBUG END
+
+        return {
+            "input_ids": torch.stack(batch_inp, dim=0),
+            "attention_mask": torch.stack(batch_attn, dim=0),
+            "labels": torch.stack(batch_lab, dim=0),
+        }
+
+# ----------------- Arguments -----------------
+def parse_args():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--model_name_or_path", type=str, required=True,
+                    help="local weights directory or HF name (e.g. /home/test/Qwen3-8B)")
+    ap.add_argument("--data_glob", type=str, required=True,
+                    help="local jsonl glob (every machine needs the data at the same path; each line should carry messages / optional tools)")
+    ap.add_argument("--output_dir", type=str, required=True,
+                    help="local output directory (each node writes to its own)")
+    ap.add_argument("--seq_len", type=int, default=4096)
+    ap.add_argument("--learning_rate", type=float, default=2e-5)
+    ap.add_argument("--weight_decay", type=float, default=0.1)
+    ap.add_argument("--warmup_ratio", type=float, default=0.02)
+    ap.add_argument("--num_train_epochs", type=float, default=1.0)
+    ap.add_argument("--max_steps", type=int, default=-1)
+    ap.add_argument("--log_interval", type=int, default=10)
+    ap.add_argument("--save_steps", type=int, default=500)
+    ap.add_argument("--eval_ratio", type=float, default=0.0)
+    ap.add_argument("--seed", type=int, default=1337)
+    #ap.add_argument("--deepspeed", type=str, default="ds_config_zero3.json")
+    ap.add_argument("--gradient_checkpointing", action="store_true")
+    ap.add_argument("--bf16", action="store_true",
+                    help="enable bf16 on 3090/A100/H100-class GPUs; enable it in the DS config as well")
+    ap.add_argument("--per_device_train_batch_size", type=int, default=1)
+    ap.add_argument("--gradient_accumulation_steps", type=int, default=64)
+    ap.add_argument("--report_to", type=str, default="tensorboard",
+                    choices=["none","tensorboard","wandb"])
+    ap.add_argument("--wandb_project", type=str, default="ds-qwen3")
+    ap.add_argument("--eval_data_glob", type=str, default=None,
+                    help="(optional) test/validation jsonl glob; takes priority when provided")
+    ap.add_argument("--local_rank", type=int, default=-1,
+                    help="for deepspeed/torchrun launcher; ignored by user code")
+    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
+    ap.add_argument("--deepspeed", type=str, default=None)
+    return ap.parse_args()
+
+
+# ----------------- Main -----------------
+def main():
+    args = parse_args()
+    set_seed(args.seed)
+
+
+    # -------- debug printing helper (every rank prints) --------
+    host = socket.gethostname()
+    def dbg(msg):
+        print(
+            f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
+            f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
+            flush=True
+        )
+
+    # versions, launch args, and key environment variables
+    import transformers as hf
+    try:
+        import deepspeed as ds
+        ds_ver = ds.__version__
+    except Exception:
+        ds_ver = "n/a"
+    dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
+    dbg(f"args={args}")
+    dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
+        os.environ.get("WORLD_SIZE"),
+        os.environ.get("RANK"),
+        os.environ.get("LOCAL_RANK", str(args.local_rank)),
+        os.environ.get("MASTER_ADDR"),
+        os.environ.get("MASTER_PORT"),
+        os.environ.get("CUDA_VISIBLE_DEVICES"),
+    ))
+    dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
+
+
+
+    # ---- initialize distributed (used by the consistency probes) ----
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    rank = int(os.environ.get("RANK", "0"))
+    local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
+    dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
+
+    if torch.cuda.is_available() and local_rank >= 0:
+        torch.cuda.set_device(local_rank)
+        dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
+            f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
+    else:
+        dbg("no cuda or invalid local_rank; not calling set_device")
+
+    if world_size > 1 and dist.is_available() and not dist.is_initialized():
+        backend = "nccl" if torch.cuda.is_available() else "gloo"
+        dbg(f"init_process_group backend={backend} via env://")
+        dist.init_process_group(backend=backend, init_method="env://")
+    else:
+        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
+
+
+    if dist.is_available() and dist.is_initialized():
+        try:
+            dbg(f"dist.get_backend()={dist.get_backend()} "
+                f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
+        except Exception as e:
+            dbg(f"dist query error: {e}")
+
+    # 1) fix up the tokenizer's pad token first
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    tokenizer.model_max_length = args.seq_len
+    dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
+        f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
+
+    # 2) then load the model
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name_or_path,
+        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
+        low_cpu_mem_usage=True,
+        trust_remote_code=True
+    )
+
+    dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
+        f"use_cache={getattr(model.config,'use_cache',None)} "
+        f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
+
+    # 3) pad/alibi and related configuration
+    model.config.pad_token_id = tokenizer.pad_token_id
+    model.config.use_cache = False
+    if args.gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+    try:
+        torch.backends.cuda.enable_flash_sdp(False)
+        torch.backends.cuda.enable_mem_efficient_sdp(False)
+        torch.backends.cuda.enable_math_sdp(True)
+    except Exception:
+        pass
+
+    # ===== data robustness checks (run on every machine) =====
+    host = socket.gethostname()
+
+    files = sorted(glob.glob(args.data_glob))
+    if len(files) == 0:
+        raise FileNotFoundError(
+            f"[host={host} rank={rank}] No files matched DATA_GLOB={args.data_glob}\n"
+            "Every machine must have the data at the same local path; "
+            "override via DATA_GLOB= ./run_ds.sh."
+        )
+    if is_main_process():
+        print(f"[data] matched {len(files)} files on host={host}, example[0]={files[0]}", flush=True)
+
+    # ====== small probe: sample structure ======
+    ds_stream_probe = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    def ex_iter_probe():
+        for ex in ds_stream_probe:
+            yield ex
+    train_stream_probe = QwenChatSFTDataset(ex_iter_probe(), tokenizer, seq_len=args.seq_len)
+    try:
+        _ = next(iter(train_stream_probe))
+    except StopIteration:
+        raise RuntimeError(
+            f"[host={host} rank={rank}] Data files matched, but no trainable sample was produced.\n"
+            "Make sure every JSON line has at least a 'messages' field (a list with user/assistant turns); "
+            "if it contains <think>, make sure it holds no real reasoning text, or remove it.\n"
+            "Also check whether seq_len is so small that everything gets truncated away."
+        )
+
+    # ====== real training stream + modulo sharding (sample count need not divide world_size) ======
+    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    def ex_iter2():
+        for i, ex in enumerate(ds_stream2):
+            if i % max(world_size, 1) == rank:
+                yield ex
+    train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+
+    # ====== consistency probe: if any rank has no sample -> everyone exits ======
+    def has_one_sample(stream):
+        it = iter(stream)
+        try:
+            next(it); return 1
+        except StopIteration:
+            return 0
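+
+    # Worked example for the modulo sharding above (comment only): with
+    # world_size=2, ex_iter2 splits a 5-example stream by index mod world_size:
+    #   rank 0 sees examples 0, 2, 4
+    #   rank 1 sees examples 1, 3
+    # so ranks can end up with different sample counts. The probe below
+    # all-reduces a MIN flag across ranks so training aborts cleanly if any
+    # rank received nothing at all.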
+
+    ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    def ex_iter2_probe():
+        for i, ex in enumerate(ds_stream_probe2):
+            if i % max(world_size, 1) == rank:
+                yield ex
+    probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+    local_ok = has_one_sample(probe_stream)
+
+    if dist.is_available() and dist.is_initialized():
+        # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
+        t = torch.tensor(
+            local_ok,
+            device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
+        )
+        dist.all_reduce(t, op=dist.ReduceOp.MIN)
+        if t.item() == 0:
+            if is_main_process():
+                print("[FATAL] At least one rank has no samples. Reduce WORLD_SIZE or fix the sharding; training will not start.", flush=True)
+            dist.barrier()
+            sys.exit(2)
+    else:
+        if local_ok == 0:
+            print("[FATAL] No samples on this machine, exiting.", flush=True); sys.exit(2)
+
+    # ---- eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
+    eval_dataset: Optional[Dataset] = None
+
+    class ListDataset(Dataset):
+        def __init__(self, items): self.items = items
+        def __len__(self): return len(self.items)
+        def __getitem__(self, idx): return self.items[idx]
+
+    if args.eval_data_glob:
+        eval_files = sorted(glob.glob(args.eval_data_glob))
+        if len(eval_files) == 0:
+            raise FileNotFoundError(f"[host={host} rank={rank}] No eval files matched EVAL_DATA_GLOB={args.eval_data_glob}")
+        if is_main_process():
+            print(f"[eval] matched {len(eval_files)} files, example[0]={eval_files[0]}", flush=True)
+
+        ds_eval_stream = load_dataset("json", data_files={"eval": eval_files}, split="eval", streaming=True)
+        def ex_iter_eval():
+            for ex in ds_eval_stream:
+                yield ex
+
+        eval_iterable = QwenChatSFTDataset(ex_iter_eval(), tokenizer, seq_len=args.seq_len)
+        eval_items: List[Dict[str, torch.Tensor]] = []
+        for sample in eval_iterable:
+            eval_items.append(sample)
+
+        if len(eval_items) == 0:
+            raise RuntimeError("[eval] eval_data_glob yielded 0 valid samples; check the messages structure.")
+
+        eval_dataset = ListDataset(eval_items)
+
+    elif args.eval_ratio and args.eval_ratio > 0:
+        desired_eval_batches = 200
+        tmp_stream = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+        def ex_iter_eval2():
+            for ex in tmp_stream:
+                yield ex
+        eval_stream = QwenChatSFTDataset(ex_iter_eval2(), tokenizer, seq_len=args.seq_len)
+        eval_samples = []
+        it = iter(eval_stream)
+        for _ in range(desired_eval_batches):
+            try:
+                eval_samples.append(next(it))
+            except StopIteration:
+                break
+        if len(eval_samples) > 0:
+            eval_dataset = ListDataset(eval_samples)
+
+    # ---- pad the eval set uniformly (guarantees no empty batches) ----
+    if eval_dataset is not None:
+        ws = max(world_size, 1)
+        be = max(1, args.per_device_eval_batch_size)
+        global_bs = ws * be
+
+        r = len(eval_dataset) % global_bs
+        if r != 0:
+            need = global_bs - r
+            # eval_dataset is the custom ListDataset defined above, so it exposes .items
+            eval_dataset.items += eval_dataset.items[:need]
+            if is_main_process():
+                print(f"[eval] padded eval set to {len(eval_dataset)} "
+                      f"(world_size={ws}, per_device_eval_batch_size={be}, global_bs={global_bs})",
+                      flush=True)
+
+        # sanity check after padding
+        assert len(eval_dataset) % global_bs == 0, \
+            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
+
+    # Safer during bring-up: don't force padding to 4096
+    # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    logging_dir = os.path.join(args.output_dir, "logs")
+    os.makedirs(logging_dir, exist_ok=True)
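+
+    # Worked example for the eval padding above (comment only): with
+    # world_size=2 and per_device_eval_batch_size=1, global_bs=2; an eval set
+    # of 201 items leaves r=1, so need=1 item is duplicated from the front and
+    # the padded size becomes 202, which splits evenly across ranks with no
+    # empty batch.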
+
+    # ---- compatible with 4.51 (eval_strategy) and older releases (evaluation_strategy) ----
+    ta_kwargs = {}
+    sig = inspect.signature(TrainingArguments.__init__).parameters
+    if eval_dataset is not None:
+        if "eval_strategy" in sig:
+            ta_kwargs["eval_strategy"] = "steps"
+        elif "evaluation_strategy" in sig:
+            ta_kwargs["evaluation_strategy"] = "steps"
+    else:
+        if "eval_strategy" in sig:
+            ta_kwargs["eval_strategy"] = "no"
+        elif "evaluation_strategy" in sig:
+            ta_kwargs["evaluation_strategy"] = "no"
+
+    training_args = TrainingArguments(
+        output_dir=args.output_dir,
+        logging_dir=logging_dir,
+        do_train=True,
+        do_eval=(eval_dataset is not None),
+        eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
+        warmup_ratio=args.warmup_ratio,
+        num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
+        max_steps=args.max_steps if args.max_steps > 0 else -1,
+        lr_scheduler_type="cosine",
+        logging_steps=args.log_interval,
+        save_steps=args.save_steps,
+        save_total_limit=2,
+        deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
+        dataloader_drop_last=False,  # crucial: keep the tail, avoid empty batches
+        dataloader_num_workers=0,
+        dataloader_prefetch_factor=None,
+        dataloader_pin_memory=False,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
+        report_to=([] if args.report_to == "none" else [args.report_to]),
+        bf16=args.bf16,
+        fp16=(not args.bf16),
+        gradient_checkpointing=args.gradient_checkpointing,
+        remove_unused_columns=False,
+        torch_compile=False,
+        save_on_each_node=False,
+        logging_first_step=True,
+        **ta_kwargs,
+    )
+
+
+    trainer = DebugTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_stream,
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        # processing_class=tokenizer,
+        data_collator=data_collator
+    )
+
+    # trainer = Trainer(
+    #     model=model,
+    #     args=training_args,
+    #     train_dataset=train_stream,
+    #     eval_dataset=eval_dataset,
+    #     processing_class=tokenizer,
+    #     data_collator=data_collator
+    # )
+    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
+
+    # no shared disk: does this node's local output_dir already hold a checkpoint-*?
+    ckpt_exists = (os.path.isdir(args.output_dir)
+                   and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
+    resume_flag = True if ckpt_exists else None
+
+    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
+    print_once("***** Starting training *****")
+    train_result = trainer.train(resume_from_checkpoint=resume_flag)
+    trainer.save_model()  # with DeepSpeed stage3_gather_16bit_weights_on_model_save=true, rank0 gathers the full model
+
+    metrics = train_result.metrics
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
+
+    if eval_dataset is not None:
+        print_once("***** Running eval *****")
+        eval_metrics = trainer.evaluate()
+        trainer.log_metrics("eval", eval_metrics)
+        trainer.save_metrics("eval", eval_metrics)
+
+    print_once("Done.")
+
+
+if __name__ == "__main__":
+    main()
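+
+
+# One possible launch (comment only; the paths are placeholders, and torchrun
+# runs the file through Python so the .py.ok extension is not a problem):
+#
+#   torchrun --nproc_per_node=1 train_sft_ds_single_gpu_single_node.py.ok \
+#       --model_name_or_path /home/test/Qwen3-8B \
+#       --data_glob './data/*.jsonl' \
+#       --output_dir ./out \
+#       --bf16 --gradient_checkpointing \
+#       --deepspeed ds_config_zero3.json
+#
+# gradient_accumulation_steps now comes only from the CLI (default 64); it was
+# removed from ds_config_zero3.json above so the two settings cannot conflict.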