commit b054b9e805 (parent a76284e116)
@@ -11,8 +11,6 @@ from typing import Dict, List, Iterable, Iterator, Tuple, Optional
 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset, Dataset
-# from contextlib import nullcontext
-
 
 from datasets import load_dataset
 from transformers import (
@@ -68,7 +66,7 @@ except Exception:
 
 # 3) Unify the JIT cache directory (optional, but more robust; the logs currently show ~/.cache/torch_extensions)
 os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER','user')}/torch_ext")
-os.environ.setdefault("MAX_JOBS", "8")
+os.environ.setdefault("MAX_JOBS", "12")
 
 import shutil
 if shutil.which("ninja") is None:
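Only the MAX_JOBS value changes above. For context, a minimal sketch of how these JIT-build knobs are usually applied: the environment variables must be exported before the first import that triggers a torch-extension build, and the warning branch for a missing ninja is an assumption, not the script's actual code.

    import os
    import shutil

    # Assumed sketch: keep the JIT extension cache on node-local disk and cap build parallelism.
    # Both variables are read when torch/deepspeed first builds a fused op, so set them early.
    os.environ.setdefault("TORCH_EXTENSIONS_DIR", f"/tmp/{os.environ.get('USER', 'user')}/torch_ext")
    os.environ.setdefault("MAX_JOBS", "12")

    if shutil.which("ninja") is None:
        # Hypothetical fallback: torch's extension loader relies on ninja, so surface the
        # problem loudly instead of letting the first JIT build fail with a cryptic error.
        print("[warn] ninja not found on PATH; JIT builds of fused ops will fail or be very slow", flush=True)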
@@ -124,7 +122,6 @@ class DebugTrainer(Trainer):
         self._dbg_printed = True
         return super().training_step(model, inputs, num_items_in_batch)
 
-
 # ----------------- Logging callback -----------------
 class CsvLossLogger(TrainerCallback):
     def __init__(self, csv_path: str):
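For readers without the full file: a minimal sketch of what a CSV loss logger of this shape typically looks like. Only the class name and the csv_path argument appear in the diff; the column set and the per-step write are assumptions.

    import os
    from transformers import TrainerCallback

    class CsvLossLogger(TrainerCallback):
        """Assumed sketch: append per-log-step training metrics to a CSV file."""

        def __init__(self, csv_path: str):
            self.csv_path = csv_path
            if not os.path.exists(csv_path):
                with open(csv_path, "w") as f:
                    f.write("step,loss,learning_rate,total_flos\n")

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs or "loss" not in logs:
                return
            with open(self.csv_path, "a") as f:
                f.write(f"{state.global_step},{logs.get('loss','')},"
                        f"{logs.get('learning_rate','')},{logs.get('total_flos','')}\n")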
@@ -189,9 +186,6 @@ class CsvLossLogger(TrainerCallback):
                 f"{cur},{logs.get('loss','')},{logs.get('learning_rate','')},{logs.get('total_flos','')}\n"
             )
 
-
-
-
 # ----------------- Dataset that supervises only assistant turns -----------------
 def _assistant_char_spans(rendered: str) -> List[Tuple[int, int]]:
     """
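The body of _assistant_char_spans and the token-level masking are outside this diff. Below is a minimal sketch of the usual technique, assuming the helper returns character ranges of assistant replies in the rendered chat template and that a Fast tokenizer supplies offset_mapping; the function name build_masked_example is hypothetical.

    def build_masked_example(rendered: str, tokenizer, max_len: int) -> dict:
        # Assumed sketch: supervise only tokens whose character span lies inside an
        # assistant turn; every other position is masked with -100.
        enc = tokenizer(rendered, truncation=True, max_length=max_len,
                        return_offsets_mapping=True)
        spans = _assistant_char_spans(rendered)
        labels = []
        for tok_id, (start, end) in zip(enc["input_ids"], enc["offset_mapping"]):
            inside = any(s <= start and end <= e for s, e in spans)
            labels.append(tok_id if inside and end > start else -100)
        enc["labels"] = labels
        enc.pop("offset_mapping")
        return enc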
@@ -353,8 +347,6 @@ class QwenChatSFTDataset(IterableDataset):
             "labels": torch.tensor(labels, dtype=torch.long),
         }
 
-
-
 # ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
     def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
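Only the collator's name and constructor signature are visible in this hunk. A minimal sketch matching the comment (pad input_ids with pad_token_id, pad labels with -100, left-padded as the tokenizer setup later in the file suggests); the exact field handling is an assumption.

    import torch

    class SFTDataCollator:
        """Assumed sketch: left-pad input_ids/attention_mask, pad labels with -100."""

        def __init__(self, tokenizer, pad_to_length=None):
            self.tokenizer = tokenizer
            self.pad_to_length = pad_to_length

        def __call__(self, features):
            target = self.pad_to_length or max(len(f["input_ids"]) for f in features)
            batch = {"input_ids": [], "attention_mask": [], "labels": []}
            for f in features:
                ids, lab = f["input_ids"], f["labels"]
                pad = target - len(ids)
                # Left padding keeps the real tokens flush with the right edge of the sequence.
                batch["input_ids"].append(torch.cat([torch.full((pad,), self.tokenizer.pad_token_id, dtype=ids.dtype), ids]))
                batch["attention_mask"].append(torch.cat([torch.zeros(pad, dtype=torch.long), torch.ones(len(ids), dtype=torch.long)]))
                batch["labels"].append(torch.cat([torch.full((pad,), -100, dtype=lab.dtype), lab]))
            return {k: torch.stack(v) for k, v in batch.items()}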
@@ -481,11 +473,6 @@ def main():
     if args.report_to == "wandb":
         os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
 
-
-    # -------- Debug print helper (prints on every rank) --------
-    # host = socket.gethostname()
-
-
     # Versions & launch args & key environment variables
     import transformers as hf
     try:
@@ -493,6 +480,7 @@ def main():
         ds_ver = ds.__version__
     except Exception:
         ds_ver = "n/a"
+
     dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
     dbg(f"args={args}")
     dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
@@ -506,7 +494,6 @@ def main():
     dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
 
 
-
     # ---- Initialize distributed (used by the consistency probes) ----
     world_size = int(os.environ.get("WORLD_SIZE", "1"))
     rank = int(os.environ.get("RANK", "0"))
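The initialization call itself sits outside this hunk. A minimal sketch of env-var-driven setup as typically used with torchrun; the backend choice and the guard are assumptions.

    import os
    import torch
    import torch.distributed as dist

    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))

    if world_size > 1 and not dist.is_initialized():
        # torchrun exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so env:// just works.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")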
@@ -541,8 +528,6 @@ def main():
         tokenizer.pad_token = tokenizer.eos_token
 
 
-
-
     # Pad on the left to match the Dataset's left-padding policy
     try:
         if getattr(tokenizer, "padding_side", None) != "left":
@@ -556,8 +541,6 @@ def main():
         raise RuntimeError("A *Fast* tokenizer is required to obtain offset_mapping; install tokenizers>=0.14 and use the matching Fast tokenizer.")
 
 
-
-
     tokenizer.model_max_length = args.seq_len
     dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
         f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
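Why the RuntimeError insists on a Fast tokenizer: only the Rust-backed tokenizers can return per-token character offsets, which the assistant-span masking depends on. A short illustration; the model id is a placeholder, not necessarily the one used here.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", use_fast=True)  # placeholder id
    assert tok.is_fast, "offset_mapping requires a Fast (Rust) tokenizer"

    enc = tok("Hello world", return_offsets_mapping=True)
    print(enc["offset_mapping"])  # per-token (start, end) character spans, e.g. [(0, 5), (5, 11)]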
@@ -576,34 +559,6 @@ def main():
     dtype = (torch.bfloat16 if use_bf16 else
              (torch.float16 if torch.cuda.is_available() else torch.float32))
 
-    # dschf = None
-    # if args.deepspeed and os.path.isfile(args.deepspeed):
-    #     dschf = HfDeepSpeedConfig(args.deepspeed)  # ← key: enable the plugin ahead of time
-    #     dbg("HfDeepSpeedConfig loaded")
-
-
-    # try:
-    #     import deepspeed
-    #     zero_init_ctx = deepspeed.zero.Init(
-    #         remote_device="cpu",   # parameters end up hosted on the CPU (can be combined with offload)
-    #         device="cpu",          # ← key: do not use meta
-    #         pin_memory=True,
-    #         dtype=dtype,
-    #         config_dict_or_path=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
-    #     )
-    # except Exception:
-    #     zero_init_ctx = nullcontext()  # still runs single-node when DeepSpeed is not installed
-
-    # with zero_init_ctx:
-    #     model = AutoModelForCausalLM.from_pretrained(
-    #         args.model_name_or_path,
-    #         torch_dtype=dtype,
-    #         low_cpu_mem_usage=False,
-    #         trust_remote_code=True,
-    #         attn_implementation="sdpa"
-    #     )
-
-
     # Let the DeepSpeed plugin handle ZeRO-Init / sharded loading
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
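The "let the plugin handle ZeRO-Init" comment relies on the transformers/DeepSpeed integration: when a ZeRO-3 config is passed through TrainingArguments(deepspeed=...), from_pretrained shards parameters at load time instead of materializing the full model on every rank. A sketch of the relevant config fragment, with illustrative values rather than the repo's actual ds_config:

    ds_config = {
        "zero_optimization": {
            "stage": 3,
            "offload_param": {"device": "cpu", "pin_memory": True},
        },
        "bf16": {"enabled": "auto"},
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
    }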
@@ -613,15 +568,6 @@ def main():
         attn_implementation="sdpa",
     )
 
-
-    # model = AutoModelForCausalLM.from_pretrained(
-    #     args.model_name_or_path,
-    #     torch_dtype=dtype,
-    #     low_cpu_mem_usage=True,
-    #     trust_remote_code=True,
-    #     attn_implementation="sdpa"
-    # )
-
     print(f"GC enabled? {getattr(model, 'is_gradient_checkpointing', False)}", flush=True)
     dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
         f"use_cache={getattr(model.config,'use_cache',None)} "
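The GC print above only reports a flag; if gradient checkpointing is meant to be on, the usual enabling code (assumed, not shown in this diff) is:

    model.gradient_checkpointing_enable()
    model.config.use_cache = False  # cached KV states are incompatible with checkpointed forward passes
    print(f"GC enabled? {model.is_gradient_checkpointing}", flush=True)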