hailin 2025-09-03 16:27:26 +08:00
parent a76284e116
commit b054b9e805
1 changed file with 4 additions and 58 deletions


@@ -11,8 +11,6 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
-# from contextlib import nullcontext
from datasets import load_dataset
from transformers import (
@@ -68,7 +66,7 @@ except Exception:
# 3) Use a single JIT cache directory (optional, but more reliable; the logs show ~/.cache/torch_extensions is currently being used)
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
-os.environ.setdefault("MAX_JOBS", "8")
+os.environ.setdefault("MAX_JOBS", "12")
import shutil
if shutil.which("ninja") is None:
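For context on this hunk: TORCH_EXTENSIONS_DIR relocates the torch/DeepSpeed JIT build cache away from ~/.cache/torch_extensions (useful on shared home directories), and MAX_JOBS caps ninja's parallel compile jobs. A minimal sketch of how the two variables take effect, assuming DeepSpeed is installed and using CPUAdamBuilder purely as an example of a JIT-built op:

import os
os.environ.setdefault("TORCH_EXTENSIONS_DIR", "/tmp/demo/torch_ext")  # illustrative path
os.environ.setdefault("MAX_JOBS", "8")                                # limit ninja parallelism

# The first .load() of a JIT op compiles it into TORCH_EXTENSIONS_DIR;
# subsequent runs reuse the cached build instead of recompiling.
from deepspeed.ops.op_builder import CPUAdamBuilder
cpu_adam = CPUAdamBuilder().load()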
@@ -124,7 +122,6 @@ class DebugTrainer(Trainer):
        self._dbg_printed = True
        return super().training_step(model, inputs, num_items_in_batch)
# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
    def __init__(self, csv_path: str):
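The DebugTrainer lines at the top of this hunk override training_step to print debug information exactly once before deferring to the parent implementation. A minimal sketch of that pattern (the printed fields here are illustrative, not the script's actual ones):

from transformers import Trainer

class ShapeDebugTrainer(Trainer):
    def training_step(self, model, inputs, num_items_in_batch=None):
        if not getattr(self, "_dbg_printed", False):
            # Print the tensor shapes of the first batch seen by this process, then stay quiet.
            shapes = {k: tuple(v.shape) for k, v in inputs.items() if hasattr(v, "shape")}
            print(f"[debug] first batch shapes: {shapes}", flush=True)
            self._dbg_printed = True
        return super().training_step(model, inputs, num_items_in_batch)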
@@ -189,9 +186,6 @@ class CsvLossLogger(TrainerCallback):
                f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
            )
# ----------------- Dataset that supervises only the assistant turns -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
    """
@@ -353,8 +347,6 @@ class QwenChatSFTDataset(IterableDataset):
            "labels": torch.tensor(labels, dtype=torch.long),
        }
# ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
class SFTDataCollator:
    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
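SFTDataCollator's job, per the comment above, is to pad variable-length examples into a batch: input_ids with the tokenizer's pad token, labels with -100 so padded positions never contribute to the loss. A minimal functional sketch of the idea (left padding assumed, matching the left-pad policy mentioned later in this diff; not the class's actual implementation):

import torch

def collate(batch, pad_token_id: int):
    max_len = max(ex["input_ids"].size(0) for ex in batch)

    def left_pad(t: torch.Tensor, value: int) -> torch.Tensor:
        pad = torch.full((max_len - t.size(0),), value, dtype=t.dtype)
        return torch.cat([pad, t])

    return {
        "input_ids": torch.stack([left_pad(ex["input_ids"], pad_token_id) for ex in batch]),
        "attention_mask": torch.stack([left_pad(torch.ones_like(ex["input_ids"]), 0) for ex in batch]),
        "labels": torch.stack([left_pad(ex["labels"], -100) for ex in batch]),  # -100 = ignored by the loss
    }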
@@ -481,11 +473,6 @@ def main():
    if args.report_to == "wandb":
        os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
-    # -------- Debug print helper (prints on every rank) --------
-    # host = socket.gethostname()
    # Versions & launch args & key environment variables
    import transformers as hf
    try:
@@ -493,6 +480,7 @@ def main():
        ds_ver = ds.__version__
    except Exception:
        ds_ver = "n/a"
    dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
    dbg(f"args={args}")
    dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
@@ -506,7 +494,6 @@ def main():
    dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
-    # ---- Initialize distributed (for use by the consistency probe) ----
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
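world_size and rank are read from the environment here so the consistency probe can talk to torch.distributed directly. A sketch of what bootstrapping that process group typically looks like under a torchrun/deepspeed launcher (illustrative only; the script's own initialization may differ or be handled by the Trainer):

import os
import torch
import torch.distributed as dist

world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

if world_size > 1 and not dist.is_initialized():
    # "env://" picks up RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set by the launcher.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)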
@@ -541,8 +528,6 @@ def main():
        tokenizer.pad_token = tokenizer.eos_token
-    # Pad on the left to match the Dataset's left-padding policy
    try:
        if getattr(tokenizer, "padding_side", None) != "left":
@@ -556,8 +541,6 @@ def main():
        raise RuntimeError("A *Fast* tokenizer is required to obtain offset_mapping; please install tokenizers>=0.14 and use the corresponding Fast tokenizer.")
    tokenizer.model_max_length = args.seq_len
    dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
        f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
@@ -576,34 +559,6 @@ def main():
    dtype = (torch.bfloat16 if use_bf16 else
             (torch.float16 if torch.cuda.is_available() else torch.float32))
-    # dschf = None
-    # if args.deepspeed and os.path.isfile(args.deepspeed):
-    #     dschf = HfDeepSpeedConfig(args.deepspeed)  # ← key: enable the plugin early
-    #     dbg("HfDeepSpeedConfig loaded")
-    # try:
-    #     import deepspeed
-    #     zero_init_ctx = deepspeed.zero.Init(
-    #         remote_device="cpu",  # parameters end up hosted on CPU; can be combined with offload
-    #         device="cpu",  # ← key: do not use meta
-    #         pin_memory=True,
-    #         dtype=dtype,
-    #         config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
-    #     )
-    # except Exception:
-    #     zero_init_ctx = nullcontext()  # still runs single-node when DeepSpeed is not installed
-    # with zero_init_ctx:
-    #     model = AutoModelForCausalLM.from_pretrained(
-    #         args.model_name_or_path,
-    #         torch_dtype=dtype,
-    #         low_cpu_mem_usage=False,
-    #         trust_remote_code=True,
-    #         attn_implementation="sdpa"
-    #     )
    # Let the plugin handle ZeRO-Init / sharded loading
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
@@ -613,15 +568,6 @@ def main():
        attn_implementation="sdpa",
    )
-    # model = AutoModelForCausalLM.from_pretrained(
-    #     args.model_name_or_path,
-    #     torch_dtype=dtype,
-    #     low_cpu_mem_usage=True,
-    #     trust_remote_code=True,
-    #     attn_implementation="sdpa"
-    # )
    print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
    dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
        f"use_cache={getattr(model.config,'use_cache',None)} "