parent a76284e116
commit b054b9e805
@@ -11,8 +11,6 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset, Dataset
-# from contextlib import nullcontext
-
 from datasets import load_dataset
 from transformers import (
@@ -68,7 +66,7 @@ except Exception:

 # 3) Unify the JIT cache directory (optional, but more robust; per the logs you are currently using ~/.cache/torch_extensions)
 os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
-os.environ.setdefault("MAX_JOBS", "8")
+os.environ.setdefault("MAX_JOBS", "12")

 import shutil
 if shutil.which("ninja") is None:
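For context, a minimal sketch of how these variables are consumed, assuming the standard torch JIT-extension build path: compiled artifacts are cached under TORCH_EXTENSIONS_DIR and ninja's parallelism is capped by MAX_JOBS. The extension name, path, and C++ source below are hypothetical.

import os
os.environ.setdefault("TORCH_EXTENSIONS_DIR", "/tmp/demo/torch_ext")  # hypothetical cache dir
os.environ.setdefault("MAX_JOBS", "12")  # cap ninja build parallelism, as in the diff

from torch.utils.cpp_extension import load_inline

# First call compiles with ninja and caches under TORCH_EXTENSIONS_DIR;
# subsequent calls reuse the cached build.
mod = load_inline(
    name="add_one_demo",
    cpp_sources="torch::Tensor add_one(torch::Tensor x) { return x + 1; }",
    functions=["add_one"],
)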
@@ -124,7 +122,6 @@ class DebugTrainer(Trainer):
         self._dbg_printed = True
         return super().training_step(model, inputs, num_items_in_batch)

-
 # ----------------- logging callback -----------------
 class CsvLossLogger(TrainerCallback):
     def __init__(self, csv_path: str):
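DebugTrainer follows a print-once pattern: override training_step, emit diagnostics on the first step only, then defer to the parent. A minimal sketch of that pattern; the printed fields are illustrative, not the repo's exact output.

from transformers import Trainer

class DebugTrainerSketch(Trainer):
    def training_step(self, model, inputs, num_items_in_batch=None):
        # Print batch shapes exactly once, then behave like a normal Trainer.
        if not getattr(self, "_dbg_printed", False):
            shapes = {k: tuple(v.shape) for k, v in inputs.items() if hasattr(v, "shape")}
            print(f"[debug] first batch shapes: {shapes}", flush=True)
            self._dbg_printed = True
        return super().training_step(model, inputs, num_items_in_batch)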
@@ -189,9 +186,6 @@ class CsvLossLogger(TrainerCallback):
                 f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
             )

-
-
-
 # ----------------- dataset that supervises only the assistant turns -----------------
 def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
     """
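Based on the write shown above, the callback appends one CSV row per log event via TrainerCallback.on_log. A minimal sketch; the header row and file handling are assumptions, while the logged keys mirror the f-string in the diff.

from transformers import TrainerCallback

class CsvLossLoggerSketch(TrainerCallback):
    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        with open(self.csv_path, "w") as f:
            f.write("step,loss,learning_rate,total_flos\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        cur = state.global_step
        with open(self.csv_path, "a") as f:
            f.write(
                f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
            )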
@@ -353,8 +347,6 @@ class QwenChatSFTDataset(IterableDataset):
             "labels": torch.tensor(labels, dtype=torch.long),
         }

-
-
 # ----------------- dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
     def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
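A minimal sketch of what a collator like SFTDataCollator does, per the section comment: left-pad input_ids with the tokenizer's pad token and pad labels with -100 so padded positions never contribute to the loss. The repo's exact implementation may differ; this is a function-style illustration.

import torch

def sft_collate(batch, pad_token_id: int):
    # batch: list of {"input_ids": LongTensor, "labels": LongTensor}
    max_len = max(ex["input_ids"].size(0) for ex in batch)
    rows = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in batch:
        n_pad = max_len - ex["input_ids"].size(0)
        pad_ids = torch.full((n_pad,), pad_token_id, dtype=torch.long)
        pad_lbl = torch.full((n_pad,), -100, dtype=torch.long)  # -100 is ignored by CrossEntropyLoss
        rows["input_ids"].append(torch.cat([pad_ids, ex["input_ids"]]))  # left pad
        rows["labels"].append(torch.cat([pad_lbl, ex["labels"]]))
        rows["attention_mask"].append(
            torch.cat([torch.zeros(n_pad, dtype=torch.long),
                       torch.ones(ex["input_ids"].size(0), dtype=torch.long)])
        )
    return {k: torch.stack(v) for k, v in rows.items()}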
@@ -481,11 +473,6 @@ def main():
     if args.report_to == "wandb":
         os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
-
-
     # -------- debug print helper (prints on every rank) --------
-    # host = socket.gethostname()
-
-
     # versions & launch args & key environment variables
     import transformers as hf
     try:
@@ -493,6 +480,7 @@ def main():
         ds_ver = ds.__version__
     except Exception:
         ds_ver = "n/a"

     dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
     dbg(f"args={args}")
     dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
@@ -506,7 +494,6 @@ def main():
     dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")

-

     # ---- initialize distributed (used by the consistency probe) ----
     world_size = int(os.environ.get("WORLD_SIZE", "1"))
     rank = int(os.environ.get("RANK", "0"))
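These env vars come from the launcher (torchrun / deepspeed). A minimal sketch, assuming NCCL and env:// rendezvous, of initializing the process group that such a consistency probe would use:

import os
import torch
import torch.distributed as dist

world_size = int(os.environ.get("WORLD_SIZE", "1"))
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

if world_size > 1 and not dist.is_initialized():
    torch.cuda.set_device(local_rank)
    # env:// reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment
    dist.init_process_group(backend="nccl", init_method="env://")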
@@ -541,8 +528,6 @@ def main():
         tokenizer.pad_token = tokenizer.eos_token

-
-

     # pad on the left to match the Dataset's left-pad policy
     try:
         if getattr(tokenizer, "padding_side", None) != "left":
@@ -556,8 +541,6 @@ def main():
         raise RuntimeError("A *Fast* tokenizer is required to obtain offset_mapping; please install tokenizers>=0.14 and use the corresponding Fast tokenizer.")

-
-

     tokenizer.model_max_length = args.seq_len
     dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
         f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
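The Fast-tokenizer requirement exists because only Rust-backed tokenizers return offset_mapping, which lets character spans (e.g. from _assistant_char_spans) be projected onto token positions. A minimal sketch of that projection, with the span logic simplified relative to whatever the repo does:

from typing import List, Tuple

def mask_labels(tokenizer, rendered: str, spans: List[Tuple[int, int]]) -> List[int]:
    enc = tokenizer(rendered, return_offsets_mapping=True)  # raises on non-Fast tokenizers
    labels = [-100] * len(enc["input_ids"])
    for i, (ts, te) in enumerate(enc["offset_mapping"]):
        # supervise a token only when it falls entirely inside an assistant span;
        # te > ts skips special tokens, which report (0, 0) offsets
        if te > ts and any(s <= ts and te <= e for s, e in spans):
            labels[i] = enc["input_ids"][i]
    return labels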
@@ -576,34 +559,6 @@ def main():
     dtype = (torch.bfloat16 if use_bf16 else
              (torch.float16 if torch.cuda.is_available() else torch.float32))

-    # dschf = None
-    # if args.deepspeed and os.path.isfile(args.deepspeed):
-    #     dschf = HfDeepSpeedConfig(args.deepspeed)  # ← key: enable the plugin early
-    #     dbg("HfDeepSpeedConfig loaded")
-
-
-    # try:
-    #     import deepspeed
-    #     zero_init_ctx = deepspeed.zero.Init(
-    #         remote_device="cpu",  # parameters are ultimately hosted on CPU (can be combined with offload)
-    #         device="cpu",         # ← key: do not use meta
-    #         pin_memory=True,
-    #         dtype=dtype,
-    #         config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
-    #     )
-    # except Exception:
-    #     zero_init_ctx = nullcontext()  # still runs single-node without DeepSpeed installed
-
-    # with zero_init_ctx:
-    #     model = AutoModelForCausalLM.from_pretrained(
-    #         args.model_name_or_path,
-    #         torch_dtype=dtype,
-    #         low_cpu_mem_usage=False,
-    #         trust_remote_code=True,
-    #         attn_implementation="sdpa"
-    #     )
-
-
     # let the DeepSpeed plugin handle ZeRO-Init / sharded loading
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
@@ -613,15 +568,6 @@ def main():
         attn_implementation="sdpa",
     )

-    # model = AutoModelForCausalLM.from_pretrained(
-    #     args.model_name_or_path,
-    #     torch_dtype=dtype,
-    #     low_cpu_mem_usage=True,
-    #     trust_remote_code=True,
-    #     attn_implementation="sdpa"
-    # )
-
-
     print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
     dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
         f"use_cache={getattr(model.config,'use_cache',None)} "
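The GC print above reads model.is_gradient_checkpointing. For reference, a minimal sketch of turning checkpointing on with the standard transformers API; disabling use_cache alongside it is a common pairing, since the KV cache is incompatible with checkpointed forwards.

model.gradient_checkpointing_enable()  # re-runs forward during backward to save activation memory
model.config.use_cache = False         # KV cache conflicts with gradient checkpointing
print(f"GC enabled? {model.is_gradient_checkpointing}", flush=True)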