This commit is contained in:
parent
4739fc615d
commit
f273231200
188
train_sft_ds.py
188
train_sft_ds.py
|
|
@ -21,7 +21,7 @@ from transformers import (
|
|||
set_seed
|
||||
)
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
# ----------------- 进程工具 -----------------
|
||||
def is_main_process():
|
||||
|
|
@ -141,9 +141,17 @@ class QwenChatSFTDataset(IterableDataset):
|
|||
|
||||
tools = ex.get("tools", None)
|
||||
|
||||
# 兼容老版本 tokenizer.apply_chat_template 不支持 tools 参数的情况
|
||||
try:
|
||||
rendered: str = self.tok.apply_chat_template(
|
||||
msgs, tools=tools, add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
except TypeError:
|
||||
rendered: str = self.tok.apply_chat_template(
|
||||
msgs, add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
|
||||
|
||||
if not isinstance(rendered, str) or not rendered.strip():
|
||||
continue
|
||||
|
||||
|
|
@ -319,6 +327,9 @@ def main():
|
|||
args = parse_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
|
||||
|
||||
|
||||
# -------- 调试打印工具(每个 rank 都打)--------
|
||||
host = socket.gethostname()
|
||||
|
|
@ -383,14 +394,47 @@ def main():
|
|||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
|
||||
|
||||
# 左侧补齐以匹配 Dataset 的左 pad 策略
|
||||
try:
|
||||
if getattr(tokenizer, "padding_side", None) != "left":
|
||||
tokenizer.padding_side = "left"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 强制要求 fast tokenizer(offset_mapping 依赖 fast)
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
|
||||
raise RuntimeError("需要 *Fast* tokenizer 以获取 offset_mapping;请安装 tokenizers>=0.14 并使用对应 Fast 版分词器。")
|
||||
|
||||
|
||||
|
||||
|
||||
tokenizer.model_max_length = args.seq_len
|
||||
dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
|
||||
f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
|
||||
|
||||
# 2) 再加载模型
|
||||
# 2) 再加载模型 之前,先算 dtype
|
||||
def _bf16_supported():
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
# 兼容不同 torch 版本:优先用 API,退化到算力判断
|
||||
if hasattr(torch.cuda, "is_bf16_supported"):
|
||||
return torch.cuda.is_bf16_supported()
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
return (major, minor) >= (8, 0) # Ampere 及以上
|
||||
|
||||
use_bf16 = bool(args.bf16 and _bf16_supported())
|
||||
dtype = (torch.bfloat16 if use_bf16 else
|
||||
(torch.float16 if torch.cuda.is_available() else torch.float32))
|
||||
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
|
||||
# torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
|
||||
torch_dtype=dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
|
@ -596,7 +640,43 @@ def main():
|
|||
elif "evaluation_strategy" in sig:
|
||||
ta_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
training_args = TrainingArguments(
|
||||
# training_args = TrainingArguments(
|
||||
# output_dir=args.output_dir,
|
||||
# logging_dir=logging_dir,
|
||||
# do_train=True,
|
||||
# do_eval=(eval_dataset is not None),
|
||||
# eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
|
||||
# per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
# gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
# learning_rate=args.learning_rate,
|
||||
# weight_decay=args.weight_decay,
|
||||
# warmup_ratio=args.warmup_ratio,
|
||||
# num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
|
||||
# max_steps=args.max_steps if args.max_steps > 0 else -1,
|
||||
# lr_scheduler_type="cosine",
|
||||
# logging_steps=args.log_interval,
|
||||
# save_steps=args.save_steps,
|
||||
# save_total_limit=2,
|
||||
# deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||
# dataloader_drop_last=False, # 关键:别丢尾,避免空 batch
|
||||
# dataloader_num_workers=0,
|
||||
# dataloader_prefetch_factor=None,
|
||||
# dataloader_pin_memory=False,
|
||||
# per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
# report_to=([] if args.report_to == "none" else [args.report_to]),
|
||||
# bf16=args.bf16,
|
||||
# fp16=(not args.bf16),
|
||||
# gradient_checkpointing=args.gradient_checkpointing,
|
||||
# remove_unused_columns=False,
|
||||
# torch_compile=False,
|
||||
# save_on_each_node=True,
|
||||
# logging_first_step=True,
|
||||
# **ta_kwargs,
|
||||
# )
|
||||
|
||||
|
||||
ta_sig = inspect.signature(TrainingArguments.__init__).parameters
|
||||
ta_kwargs2 = dict(
|
||||
output_dir=args.output_dir,
|
||||
logging_dir=logging_dir,
|
||||
do_train=True,
|
||||
|
|
@ -614,22 +694,38 @@ def main():
|
|||
save_steps=args.save_steps,
|
||||
save_total_limit=2,
|
||||
deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
|
||||
dataloader_drop_last=False, # 关键:别丢尾,避免空 batch
|
||||
dataloader_drop_last=False,
|
||||
dataloader_num_workers=0,
|
||||
dataloader_prefetch_factor=None,
|
||||
dataloader_pin_memory=False,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
report_to=([] if args.report_to == "none" else [args.report_to]),
|
||||
bf16=args.bf16,
|
||||
fp16=(not args.bf16),
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
remove_unused_columns=False,
|
||||
torch_compile=False,
|
||||
save_on_each_node=False,
|
||||
save_on_each_node=True,
|
||||
logging_first_step=True,
|
||||
**ta_kwargs,
|
||||
**ta_kwargs, # 你之前构造的 eval_strategy 兼容项
|
||||
)
|
||||
# if "dataloader_prefetch_factor" in ta_sig:
|
||||
# ta_kwargs2["dataloader_prefetch_factor"] = None
|
||||
if "dataloader_pin_memory" in ta_sig:
|
||||
ta_kwargs2["dataloader_pin_memory"] = False
|
||||
if "torch_compile" in ta_sig:
|
||||
ta_kwargs2["torch_compile"] = False
|
||||
|
||||
# 构造 TrainingArguments 之前,沿用上面的 use_bf16 判定
|
||||
ta_kwargs2.update({
|
||||
"bf16": use_bf16,
|
||||
"fp16": (torch.cuda.is_available() and not use_bf16),
|
||||
})
|
||||
|
||||
training_args = TrainingArguments(**ta_kwargs2)
|
||||
|
||||
trainer_kwargs = {}
|
||||
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
|
||||
trainer_kwargs["processing_class"] = tokenizer
|
||||
else:
|
||||
trainer_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
trainer = DebugTrainer(
|
||||
model=model,
|
||||
|
|
@ -637,25 +733,75 @@ def main():
|
|||
train_dataset=train_stream,
|
||||
eval_dataset=eval_dataset,
|
||||
#tokenizer=tokenizer,
|
||||
processing_class=tokenizer,
|
||||
data_collator=data_collator
|
||||
#processing_class=tokenizer,
|
||||
data_collator=data_collator,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
|
||||
|
||||
# 无共享盘:各节点本地 output_dir 下是否已有 checkpoint-*
|
||||
# ckpt_exists = (os.path.isdir(args.output_dir)
|
||||
# and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
|
||||
# resume_flag = True if ckpt_exists else None
|
||||
|
||||
ckpt_local = 1 if (os.path.isdir(args.output_dir) and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) else 0
|
||||
ckpt_tensor = torch.tensor(ckpt_local, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
|
||||
|
||||
# ==== 断点恢复判定(非共享盘安全写法)====
|
||||
def last_step(path: str) -> int:
|
||||
ck = get_last_checkpoint(path)
|
||||
if ck is None:
|
||||
return -1
|
||||
base = os.path.basename(ck)
|
||||
try:
|
||||
return int(base.split("-")[-1])
|
||||
except Exception:
|
||||
return -1
|
||||
|
||||
local_last = last_step(args.output_dir) # -1 表示本机没有任何 checkpoint
|
||||
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
|
||||
|
||||
resume_flag = None
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.all_reduce(ckpt_tensor, op=dist.ReduceOp.MAX)
|
||||
resume_flag = True if ckpt_tensor.item() > 0 else None
|
||||
# 只要有任意一个 rank 没有 ckpt -> 不恢复
|
||||
has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
|
||||
dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
|
||||
if has_local.item() == 1:
|
||||
# 全员都有:收集每个 rank 的 last step,取公共最小步 k(每台机器都一定存在)
|
||||
ts = torch.tensor(local_last, device=device)
|
||||
world = dist.get_world_size()
|
||||
buf = [torch.zeros_like(ts) for _ in range(world)]
|
||||
dist.all_gather(buf, ts)
|
||||
steps = [b.item() for b in buf]
|
||||
k = min(steps)
|
||||
if k >= 0:
|
||||
resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
|
||||
if is_main_process():
|
||||
print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
|
||||
else:
|
||||
# 单机或未初始化分布式:本地有就按本机最后一步恢复
|
||||
if local_last >= 0:
|
||||
resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
|
||||
|
||||
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
|
||||
print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
|
||||
|
||||
|
||||
|
||||
# —— 全局一致性检测:如果有任意 rank 缺这个 ckpt,就禁用恢复 ——
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
|
||||
present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
|
||||
dist.all_reduce(present, op=dist.ReduceOp.MIN)
|
||||
if present.item() == 0:
|
||||
if is_main_process():
|
||||
print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
|
||||
resume_flag = None
|
||||
dist.barrier()
|
||||
else:
|
||||
if resume_flag is not None and not os.path.isdir(resume_flag):
|
||||
# 单机:缺就直接禁用恢复
|
||||
print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
|
||||
resume_flag = None
|
||||
|
||||
|
||||
print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
|
||||
print_once("***** Starting training *****")
|
||||
|
||||
train_result = trainer.train(resume_from_checkpoint=resume_flag)
|
||||
trainer.save_model() # DeepSpeed stage3_gather_16bit_weights_on_model_save=true 时,在 rank0 聚合整模型
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue