This commit is contained in:
hailin 2025-09-03 16:27:26 +08:00
parent a76284e116
commit b054b9e805
1 changed file with 4 additions and 58 deletions


@@ -11,8 +11,6 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, Dataset
# from contextlib import nullcontext
from datasets import load_dataset
from transformers import (
@@ -68,7 +66,7 @@ except Exception:
# 3) Unify the JIT cache directory (optional, but more stable; your logs currently show ~/.cache/torch_extensions in use)
os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
os.environ.setdefault("MAX_JOBS", "8")
os.environ.setdefault("MAX_JOBS", "12")
import shutil
if shutil.which("ninja") is None:
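For context, a minimal pre-flight sketch (not part of this commit) that mirrors the environment setup in the hunk above: pick a per-user extensions cache, cap ninja parallelism, and fail fast if ninja is missing before DeepSpeed JIT-builds its fused ops. The function name is hypothetical.

```python
# Hypothetical helper mirroring the env setup above; not from this commit.
import os
import shutil
import tempfile

def check_jit_build_env(max_jobs: int = 12) -> str:
    ext_dir = os.environ.setdefault(
        "TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER', 'user')}/torch_ext"
    )
    os.environ.setdefault("MAX_JOBS", str(max_jobs))
    os.makedirs(ext_dir, exist_ok=True)
    # Shared /tmp mounts can be read-only or full; verify writability early.
    with tempfile.NamedTemporaryFile(dir=ext_dir):
        pass
    if shutil.which("ninja") is None:
        raise RuntimeError("ninja not found on PATH; JIT builds of fused ops will fail")
    return ext_dir
```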
@@ -124,7 +122,6 @@ class DebugTrainer(Trainer):
self._dbg_printed = True
return super().training_step(model, inputs, num_items_in_batch)
# ----------------- Logging callback -----------------
class CsvLossLogger(TrainerCallback):
def __init__(self, csv_path: str):
@@ -189,9 +186,6 @@ class CsvLossLogger(TrainerCallback):
f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
)
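The hunks above show only fragments of CsvLossLogger; here is a minimal sketch of the same shape, assuming the standard transformers TrainerCallback API and the column layout visible in the write above (the class name is a hypothetical stand-in):

```python
import os
from transformers import TrainerCallback

class MinimalCsvLossLogger(TrainerCallback):  # hypothetical stand-in
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        with open(csv_path, "w") as f:
            f.write("step,loss,learning_rate,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Trainer calls on_log at each logging step; skip non-loss events.
        if not logs or "loss" not in logs:
            return
        with open(self.csv_path, "a") as f:
            f.write(
                f"{state.global_step},{logs.get('loss','')},"
                f"{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
            )
```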
# ----------------- Dataset supervising only assistant turns -----------------
def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
"""
@@ -353,8 +347,6 @@ class QwenChatSFTDataset(IterableDataset):
"labels": torch.tensor(labels, dtype=torch.long),
}
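A minimal sketch of the assistant-only supervision idea, assuming _assistant_char_spans (declared above) returns character ranges of assistant replies inside the rendered chat text; tokens outside those ranges get label -100 so the loss ignores them. The helper below is illustrative, not the dataset's actual code:

```python
from typing import List, Tuple

# Illustrative only: mask labels for everything outside assistant spans.
def mask_non_assistant(rendered: str, tokenizer, spans: List[Tuple[int, int]],
                       max_len: int) -> dict:
    enc = tokenizer(rendered, truncation=True, max_length=max_len,
                    return_offsets_mapping=True)  # needs a *Fast* tokenizer
    labels = list(enc["input_ids"])
    for i, (s, e) in enumerate(enc["offset_mapping"]):
        # Keep a token only if its char range overlaps an assistant span;
        # special tokens have (0, 0) offsets and are always masked.
        inside = any(s < span_end and e > span_start for span_start, span_end in spans)
        if s == e or not inside:
            labels[i] = -100  # ignored by cross-entropy
    enc.pop("offset_mapping")
    enc["labels"] = labels
    return dict(enc)
```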
# ----------------- Dedicated collator (pad inputs, pad labels=-100) -----------------
class SFTDataCollator:
def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
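A minimal sketch of what a collator with this contract does, assuming left padding (matching the tokenizer setup later in this diff): input_ids are padded with pad_token_id, labels with -100, and the attention mask zeroes out the padding. Illustrative only, not the class body elided here:

```python
import torch

def collate_left_pad(features: list, pad_token_id: int) -> dict:
    # Illustrative left-padding collator: pad inputs, pad labels with -100.
    max_len = max(f["input_ids"].size(0) for f in features)
    batch = {"input_ids": [], "labels": [], "attention_mask": []}
    for f in features:
        pad = max_len - f["input_ids"].size(0)
        batch["input_ids"].append(torch.cat(
            [torch.full((pad,), pad_token_id, dtype=torch.long), f["input_ids"]]))
        batch["labels"].append(torch.cat(
            [torch.full((pad,), -100, dtype=torch.long), f["labels"]]))
        batch["attention_mask"].append(torch.cat(
            [torch.zeros(pad, dtype=torch.long),
             torch.ones(f["input_ids"].size(0), dtype=torch.long)]))
    return {k: torch.stack(v) for k, v in batch.items()}
```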
@@ -481,11 +473,6 @@ def main():
if args.report_to == "wandb":
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
# -------- Debug print helper (prints on every rank) --------
# host = socket.gethostname()
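The dbg helper used throughout main() is not shown in this diff; a plausible minimal sketch, assuming it prints on every rank with a rank/host prefix as the comment above suggests (the implementation is a guess):

```python
import os
import socket
import sys

def dbg(msg: str) -> None:
    # Guessed implementation: rank- and host-prefixed, flushed immediately
    # so logs from all ranks interleave promptly under torchrun/deepspeed.
    rank = os.environ.get("RANK", "0")
    print(f"[rank{rank}@{socket.gethostname()}] {msg}", file=sys.stderr, flush=True)
```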
# Versions & launch args & key environment variables
import transformers as hf
try:
@@ -493,6 +480,7 @@ def main():
ds_ver = ds.__version__
except Exception:
ds_ver = "n/a"
dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
dbg(f"args={args}")
dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
@@ -506,7 +494,6 @@ def main():
dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
# ---- Initialize distributed (used by the consistency probe) ----
world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
@@ -541,8 +528,6 @@ def main():
tokenizer.pad_token = tokenizer.eos_token
# Pad on the left to match the Dataset's left-padding policy
try:
if getattr(tokenizer, "padding_side", None) != "left":
@@ -556,8 +541,6 @@ def main():
raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
tokenizer.model_max_length = args.seq_len
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
@@ -576,34 +559,6 @@ def main():
dtype = (torch.bfloat16 if use_bf16 else
(torch.float16 if torch.cuda.is_available() else torch.float32))
# dschf = None
# if args.deepspeed and os.path.isfile(args.deepspeed):
# dschf = HfDeepSpeedConfig(args.deepspeed) # ← key: enable the plugin early
# dbg("HfDeepSpeedConfig loaded")
# try:
# import deepspeed
# zero_init_ctx = deepspeed.zero.Init(
# remote_device="cpu", # 参数最终托管在 CPU可结合 offload
# device="cpu", # ← 关键:不要用 meta
# pin_memory=True,
# dtype=dtype,
# config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
# )
# except Exception:
# zero_init_ctx = nullcontext() # still runs single-node without DeepSpeed installed
# with zero_init_ctx:
# model = AutoModelForCausalLM.from_pretrained(
# args.model_name_or_path,
# torch_dtype=dtype,
# low_cpu_mem_usage=False,
# trust_remote_code=True,
# attn_implementation="sdpa"
# )
# Let the plugin handle ZeRO-Init / sharded loading
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
@@ -613,15 +568,6 @@ def main():
attn_implementation="sdpa",
)
# model = AutoModelForCausalLM.from_pretrained(
# args.model_name_or_path,
# torch_dtype=dtype,
# low_cpu_mem_usage=True,
# trust_remote_code=True,
# attn_implementation="sdpa"
# )
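The print below only reports gradient-checkpointing status; for reference, a minimal sketch of how it is typically switched on for HF causal LMs (not necessarily what this script does elsewhere):

```python
# Sketch only: enable gradient checkpointing and disable the KV cache,
# which is incompatible with checkpointed training in transformers.
model.gradient_checkpointing_enable()
model.config.use_cache = False
```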
print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
f"use_cache={getattr(model.config,'use_cache',None)} "