From b054b9e8053daaaee94a15e36f329d2d978df436 Mon Sep 17 00:00:00 2001
From: hailin
Date: Wed, 3 Sep 2025 16:27:26 +0800
Subject: [PATCH] .

---
 train_sft_ds.py | 62 ++++----------------------------------------------------------
 1 file changed, 4 insertions(+), 58 deletions(-)

diff --git a/train_sft_ds.py b/train_sft_ds.py
index 905abc4..d62c649 100644
--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -11,8 +11,6 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset, Dataset
-# from contextlib import nullcontext
-
 
 from datasets import load_dataset
 from transformers import (
@@ -68,7 +66,7 @@ except Exception:
 
 # 3) unify the JIT cache dir (optional, but more stable; your logs currently use ~/.cache/torch_extensions)
 os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
-os.environ.setdefault("MAX_JOBS", "8")
+os.environ.setdefault("MAX_JOBS", "12")
 
 import shutil
 if shutil.which("ninja") is None:
@@ -124,7 +122,6 @@ class DebugTrainer(Trainer):
         self._dbg_printed = True
         return super().training_step(model, inputs, num_items_in_batch)
 
-
 # ----------------- logging callback -----------------
 class CsvLossLogger(TrainerCallback):
     def __init__(self, csv_path: str):
@@ -189,9 +186,6 @@ class CsvLossLogger(TrainerCallback):
                 f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
             )
 
-
-
-
 # ----------------- dataset that supervises only assistant tokens -----------------
 def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
     """
@@ -353,8 +347,6 @@ class QwenChatSFTDataset(IterableDataset):
             "labels": torch.tensor(labels, dtype=torch.long),
         }
 
-
-
 # ----------------- dedicated collator: pad inputs, pad labels=-100 -----------------
 class SFTDataCollator:
     def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
@@ -481,11 +473,6 @@ def main():
     if args.report_to == "wandb":
         os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
 
-
-    # -------- debug print helper (runs on every rank) --------
-    # host = socket.gethostname()
-
-
     # versions & launch args & key environment variables
     import transformers as hf
     try:
@@ -493,6 +480,7 @@ def main():
         ds_ver = ds.__version__
     except Exception:
         ds_ver = "n/a"
+
     dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
     dbg(f"args={args}")
     dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
@@ -506,7 +494,6 @@ def main():
 
     dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
 
-
     # ---- init distributed (used by the consistency probes) ----
     world_size = int(os.environ.get("WORLD_SIZE", "1"))
     rank = int(os.environ.get("RANK", "0"))
@@ -541,8 +528,6 @@ def main():
 
         tokenizer.pad_token = tokenizer.eos_token
 
-
-
     # left-pad to match the Dataset's left-padding policy
    try:
        if getattr(tokenizer, "padding_side", None) != "left":
@@ -556,8 +541,6 @@ def main():
 
         raise RuntimeError("A *Fast* tokenizer is required to obtain offset_mapping; install tokenizers>=0.14 and use the matching Fast tokenizer.")
 
-
-
     tokenizer.model_max_length = args.seq_len
     dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
         f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
@@ -576,34 +559,6 @@ def main():
 
     dtype = (torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32))
 
-    # dschf = None
-    # if args.deepspeed and os.path.isfile(args.deepspeed):
-    #     dschf = HfDeepSpeedConfig(args.deepspeed)  # ← key: enable the plugin early
-    #     dbg("HfDeepSpeedConfig loaded")
-
-
-    # try:
-    #     import deepspeed
-    #     zero_init_ctx = deepspeed.zero.Init(
-    #         remote_device="cpu",  # parameters are ultimately hosted on CPU (can be combined with offload)
-    #         device="cpu",  # ← key: do not use meta
-    #         pin_memory=True,
-    #         dtype=dtype,
-    #         config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
-    #     )
-    # except Exception:
-    #     zero_init_ctx = nullcontext()  # can still run single-node without DeepSpeed installed
-
-    # with zero_init_ctx:
-    #     model = AutoModelForCausalLM.from_pretrained(
-    #         args.model_name_or_path,
-    #         torch_dtype=dtype,
-    #         low_cpu_mem_usage=False,
-    #         trust_remote_code=True,
-    #         attn_implementation="sdpa"
-    #     )
-
-
     # hand ZeRO-Init / sharded loading over to the plugin
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
@@ -613,19 +568,10 @@ def main():
         attn_implementation="sdpa",
     )
 
-
-    # model = AutoModelForCausalLM.from_pretrained(
-    #     args.model_name_or_path,
-    #     torch_dtype=dtype,
-    #     low_cpu_mem_usage=True,
-    #     trust_remote_code=True,
-    #     attn_implementation="sdpa"
-    # )
-
     print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
     dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
-    f"use_cache={getattr(model.config,'use_cache',None)} "
-    f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
+        f"use_cache={getattr(model.config,'use_cache',None)} "
+        f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
 
     # 3) pad/alibi and related config
     model.config.pad_token_id = tokenizer.pad_token_id
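
Note: the zero.Init block deleted above is superseded by the plugin-managed load path ("hand ZeRO-Init / sharded loading over to the plugin"). Below is a minimal sketch of that path, assuming DeepSpeed is installed, a ZeRO stage-3 JSON at the placeholder path ds_config.json, and a hypothetical checkpoint name. Keeping the HfDeepSpeedConfig object alive before from_pretrained() is what lets transformers shard parameters at load time instead of materializing a full copy on every rank (ZeRO-3 loading is incompatible with low_cpu_mem_usage=True or a device_map, so neither is passed here):

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.integrations import HfDeepSpeedConfig

    # Construct (and keep referenced) BEFORE from_pretrained(): transformers
    # checks for a live ZeRO-3 config to decide whether to wrap weight
    # loading in deepspeed.zero.Init().
    dschf = HfDeepSpeedConfig("ds_config.json")  # placeholder path (assumption)

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B",  # hypothetical checkpoint, not from this patch
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )

In the committed code the same effect is expected without an explicit HfDeepSpeedConfig once TrainingArguments(deepspeed=...) is constructed before model loading, since, per the HF DeepSpeed integration docs, that also installs the live config the loader checks.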