commit f273231200 (parent 4739fc615d)

 train_sft_ds.py | 194 lines changed

--- a/train_sft_ds.py
+++ b/train_sft_ds.py
@@ -21,7 +21,7 @@ from transformers import (
     set_seed
 )
 from transformers.trainer_callback import TrainerCallback
-
+from transformers.trainer_utils import get_last_checkpoint
 
 # ----------------- Process utilities -----------------
 def is_main_process():
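Note: the new import is the basis of the rewritten resume logic further down. A minimal sketch of what `get_last_checkpoint` does (the "./out" path is a placeholder):

    from transformers.trainer_utils import get_last_checkpoint

    # Scans a directory for "checkpoint-<step>" subfolders and returns the one
    # with the highest step, or None when no checkpoint exists yet.
    ck = get_last_checkpoint("./out")
    print(ck)  # e.g. "./out/checkpoint-1200", or None on a fresh run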
@@ -141,9 +141,17 @@ class QwenChatSFTDataset(IterableDataset):
 
             tools = ex.get("tools", None)
 
-            rendered: str = self.tok.apply_chat_template(
-                msgs, tools=tools, add_generation_prompt=False, tokenize=False
-            )
+            # Compatibility with older tokenizer.apply_chat_template versions that do not support the `tools` argument
+            try:
+                rendered: str = self.tok.apply_chat_template(
+                    msgs, tools=tools, add_generation_prompt=False, tokenize=False
+                )
+            except TypeError:
+                rendered: str = self.tok.apply_chat_template(
+                    msgs, add_generation_prompt=False, tokenize=False
+                )
 
             if not isinstance(rendered, str) or not rendered.strip():
                 continue
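Note: the try/except here is a version probe. Tokenizers that predate tool-calling chat templates raise `TypeError` on the unknown `tools` keyword, so the call is retried without it. The same pattern as a standalone sketch (`render_chat` and `messages` are illustrative names, not from the patch):

    def render_chat(tok, messages, tools=None):
        try:
            # Newer tokenizers accept `tools` for function-calling chat templates.
            return tok.apply_chat_template(
                messages, tools=tools, add_generation_prompt=False, tokenize=False
            )
        except TypeError:
            # Older versions reject the keyword; retry without it.
            return tok.apply_chat_template(
                messages, add_generation_prompt=False, tokenize=False
            )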
@@ -319,6 +327,9 @@ def main():
     args = parse_args()
     set_seed(args.seed)
 
+    if args.report_to == "wandb":
+        os.environ.setdefault("WANDB_PROJECT", args.wandb_project)
+
     # -------- Debug print helper (every rank prints) --------
     host = socket.gethostname()
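Note: `setdefault` only fills the variable when the launcher has not already exported it, so an externally set `WANDB_PROJECT` still wins. A tiny sketch ("qwen-sft" is a placeholder value):

    import os

    # An env var exported by the launcher takes precedence; the default
    # applies only when the variable is unset.
    os.environ.setdefault("WANDB_PROJECT", "qwen-sft")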
@@ -383,14 +394,47 @@ def main():
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    # Left-pad to match the dataset's left-padding strategy
+    try:
+        if getattr(tokenizer, "padding_side", None) != "left":
+            tokenizer.padding_side = "left"
+    except Exception:
+        pass
+
+    # Require a fast tokenizer (offset_mapping depends on fast)
+    from transformers import PreTrainedTokenizerFast
+    if not isinstance(tokenizer, PreTrainedTokenizerFast) or not getattr(tokenizer, "is_fast", False):
+        raise RuntimeError("A *fast* tokenizer is required to obtain offset_mapping; install tokenizers>=0.14 and use the corresponding Fast tokenizer.")
+
     tokenizer.model_max_length = args.seq_len
     dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
         f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
 
-    # 2) Then load the model
+    # 2) Before loading the model, work out the dtype first
+    def _bf16_supported():
+        if not torch.cuda.is_available():
+            return False
+        # Handle different torch versions: prefer the API, fall back to compute capability
+        if hasattr(torch.cuda, "is_bf16_supported"):
+            return torch.cuda.is_bf16_supported()
+        major, minor = torch.cuda.get_device_capability()
+        return (major, minor) >= (8, 0)  # Ampere and newer
+
+    use_bf16 = bool(args.bf16 and _bf16_supported())
+    dtype = (torch.bfloat16 if use_bf16 else
+             (torch.float16 if torch.cuda.is_available() else torch.float32))
+
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
-        torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
+        # torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
+        torch_dtype=dtype,
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )
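Note: the dtype selection now degrades gracefully: bf16 only when requested and supported, fp16 on older CUDA GPUs, fp32 on CPU. The capability probe in isolation (a sketch, assuming a CUDA build of torch):

    import torch

    def bf16_supported() -> bool:
        if not torch.cuda.is_available():
            return False
        # Prefer the official API where present...
        if hasattr(torch.cuda, "is_bf16_supported"):
            return torch.cuda.is_bf16_supported()
        # ...otherwise fall back to compute capability:
        # bf16 needs sm_80 (Ampere) or newer.
        return torch.cuda.get_device_capability() >= (8, 0)

    dtype = (torch.bfloat16 if bf16_supported()
             else torch.float16 if torch.cuda.is_available()
             else torch.float32)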
@@ -596,7 +640,43 @@ def main():
     elif "evaluation_strategy" in sig:
         ta_kwargs["evaluation_strategy"] = "no"
 
-    training_args = TrainingArguments(
+    # training_args = TrainingArguments(
+    #     output_dir=args.output_dir,
+    #     logging_dir=logging_dir,
+    #     do_train=True,
+    #     do_eval=(eval_dataset is not None),
+    #     eval_steps=max(50, args.save_steps // 5) if eval_dataset is not None else None,
+    #     per_device_train_batch_size=args.per_device_train_batch_size,
+    #     gradient_accumulation_steps=args.gradient_accumulation_steps,
+    #     learning_rate=args.learning_rate,
+    #     weight_decay=args.weight_decay,
+    #     warmup_ratio=args.warmup_ratio,
+    #     num_train_epochs=args.num_train_epochs if args.max_steps < 0 else 1.0,
+    #     max_steps=args.max_steps if args.max_steps > 0 else -1,
+    #     lr_scheduler_type="cosine",
+    #     logging_steps=args.log_interval,
+    #     save_steps=args.save_steps,
+    #     save_total_limit=2,
+    #     deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
+    #     dataloader_drop_last=False,  # Key: don't drop the tail, to avoid empty batches
+    #     dataloader_num_workers=0,
+    #     dataloader_prefetch_factor=None,
+    #     dataloader_pin_memory=False,
+    #     per_device_eval_batch_size=args.per_device_eval_batch_size,
+    #     report_to=([] if args.report_to == "none" else [args.report_to]),
+    #     bf16=args.bf16,
+    #     fp16=(not args.bf16),
+    #     gradient_checkpointing=args.gradient_checkpointing,
+    #     remove_unused_columns=False,
+    #     torch_compile=False,
+    #     save_on_each_node=True,
+    #     logging_first_step=True,
+    #     **ta_kwargs,
+    # )
+
+    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
+    ta_kwargs2 = dict(
         output_dir=args.output_dir,
         logging_dir=logging_dir,
         do_train=True,
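Note: building the arguments as a plain dict first, with `ta_sig` at hand, lets version-dependent keys be added only when the installed `TrainingArguments` accepts them, which the next hunk does. The pattern in miniature (candidate values are illustrative):

    import inspect
    from transformers import TrainingArguments

    ta_sig = inspect.signature(TrainingArguments.__init__).parameters
    kwargs = {"output_dir": "./out"}       # "./out" is a placeholder
    if "torch_compile" in ta_sig:          # absent in older transformers releases
        kwargs["torch_compile"] = False
    if "dataloader_pin_memory" in ta_sig:
        kwargs["dataloader_pin_memory"] = False
    training_args = TrainingArguments(**kwargs)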
@@ -614,22 +694,38 @@ def main():
         save_steps=args.save_steps,
         save_total_limit=2,
         deepspeed=(args.deepspeed if args.deepspeed and os.path.isfile(args.deepspeed) else None),
-        dataloader_drop_last=False,  # Key: don't drop the tail, to avoid empty batches
+        dataloader_drop_last=False,
         dataloader_num_workers=0,
-        dataloader_prefetch_factor=None,
-        dataloader_pin_memory=False,
         per_device_eval_batch_size=args.per_device_eval_batch_size,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
         gradient_checkpointing=args.gradient_checkpointing,
         remove_unused_columns=False,
-        torch_compile=False,
-        save_on_each_node=False,
+        save_on_each_node=True,
         logging_first_step=True,
-        **ta_kwargs,
+        **ta_kwargs,  # the eval_strategy compatibility entries built earlier
     )
+    # if "dataloader_prefetch_factor" in ta_sig:
+    #     ta_kwargs2["dataloader_prefetch_factor"] = None
+    if "dataloader_pin_memory" in ta_sig:
+        ta_kwargs2["dataloader_pin_memory"] = False
+    if "torch_compile" in ta_sig:
+        ta_kwargs2["torch_compile"] = False
+
+    # Before constructing TrainingArguments, reuse the use_bf16 decision from above
+    ta_kwargs2.update({
+        "bf16": use_bf16,
+        "fp16": (torch.cuda.is_available() and not use_bf16),
+    })
+
+    training_args = TrainingArguments(**ta_kwargs2)
+
+    trainer_kwargs = {}
+    if "processing_class" in inspect.signature(Trainer.__init__).parameters:
+        trainer_kwargs["processing_class"] = tokenizer
+    else:
+        trainer_kwargs["tokenizer"] = tokenizer
+
     trainer = DebugTrainer(
         model=model,
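Note: the `trainer_kwargs` probe covers the rename of `Trainer`'s `tokenizer` argument to `processing_class` in newer transformers releases, so one code path works on both sides of the rename. A sketch, assuming `model`, `training_args`, and `tokenizer` are already built elsewhere:

    import inspect
    from transformers import Trainer

    extra = {}
    if "processing_class" in inspect.signature(Trainer.__init__).parameters:
        extra["processing_class"] = tokenizer   # newer API
    else:
        extra["tokenizer"] = tokenizer          # legacy keyword
    trainer = Trainer(model=model, args=training_args, **extra)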
@@ -637,25 +733,75 @@ def main():
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
         #tokenizer=tokenizer,
-        processing_class=tokenizer,
-        data_collator=data_collator
+        #processing_class=tokenizer,
+        data_collator=data_collator,
+        **trainer_kwargs,
     )
 
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
 
-    # No shared disk: check whether each node's local output_dir already has checkpoint-*
-    # ckpt_exists = (os.path.isdir(args.output_dir)
-    #                and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir)))
-    # resume_flag = True if ckpt_exists else None
-
-    ckpt_local = 1 if (os.path.isdir(args.output_dir) and any(n.startswith("checkpoint-") for n in os.listdir(args.output_dir))) else 0
-    ckpt_tensor = torch.tensor(ckpt_local, device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"))
+    # ==== Resume-from-checkpoint decision (safe version for non-shared disks) ====
+    def last_step(path: str) -> int:
+        ck = get_last_checkpoint(path)
+        if ck is None:
+            return -1
+        base = os.path.basename(ck)
+        try:
+            return int(base.split("-")[-1])
+        except Exception:
+            return -1
+
+    local_last = last_step(args.output_dir)  # -1 means this node has no checkpoint at all
+    device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
+
+    resume_flag = None
     if dist.is_available() and dist.is_initialized():
-        dist.all_reduce(ckpt_tensor, op=dist.ReduceOp.MAX)
-        resume_flag = True if ckpt_tensor.item() > 0 else None
+        # If any single rank has no ckpt -> do not resume
+        has_local = torch.tensor(1 if local_last >= 0 else 0, device=device)
+        dist.all_reduce(has_local, op=dist.ReduceOp.MIN)
+        if has_local.item() == 1:
+            # Every rank has one: gather each rank's last step and take the common minimum k (guaranteed to exist on every machine)
+            ts = torch.tensor(local_last, device=device)
+            world = dist.get_world_size()
+            buf = [torch.zeros_like(ts) for _ in range(world)]
+            dist.all_gather(buf, ts)
+            steps = [b.item() for b in buf]
+            k = min(steps)
+            if k >= 0:
+                resume_flag = os.path.join(args.output_dir, f"checkpoint-{k}")
+            if is_main_process():
+                print(f"[resume] steps={steps}, resume={resume_flag}", flush=True)
+    else:
+        # Single machine, or distributed not initialized: resume from this node's last step if one exists
+        if local_last >= 0:
+            resume_flag = os.path.join(args.output_dir, f"checkpoint-{local_last}")
 
-    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is True}")
+    print_once(f"[host={socket.gethostname()}] Resume = {resume_flag is not None}")
 
+    # -- Global consistency check: if any rank is missing this ckpt, disable resume --
+    if dist.is_available() and dist.is_initialized():
+        device = torch.device(f"cuda:{local_rank}" if (torch.cuda.is_available() and local_rank >= 0) else "cpu")
+        present = torch.tensor(1 if (resume_flag is not None and os.path.isdir(resume_flag)) else 0, device=device)
+        dist.all_reduce(present, op=dist.ReduceOp.MIN)
+        if present.item() == 0:
+            if is_main_process():
+                print(f"[resume] {resume_flag} missing on some ranks -> disable resume.", flush=True)
+            resume_flag = None
+        dist.barrier()
+    else:
+        if resume_flag is not None and not os.path.isdir(resume_flag):
+            # Single machine: if it is missing, simply disable resume
+            print(f"[resume] {resume_flag} not found locally -> disable resume.", flush=True)
+            resume_flag = None
+
+    print_once(f"[resume] final = {resume_flag if resume_flag else 'None (fresh start)'}")
     print_once("***** Starting training *****")
 
     train_result = trainer.train(resume_from_checkpoint=resume_flag)
     trainer.save_model()  # With DeepSpeed stage3_gather_16bit_weights_on_model_save=true, the full model is gathered on rank0
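Note: the new resume logic amounts to a small consensus protocol: each rank reports its newest local checkpoint step, a MIN all-reduce vetoes resume if any rank has nothing, and the common minimum step is chosen so the chosen directory exists on every machine. The core reduced to a helper (a sketch; assumes `torch.distributed` is already initialized):

    import os
    from typing import Optional

    import torch
    import torch.distributed as dist

    def agree_on_checkpoint(output_dir: str, local_last: int,
                            device: torch.device) -> Optional[str]:
        # Gather every rank's newest checkpoint step (-1 = none on that node).
        ts = torch.tensor(local_last, device=device)
        buf = [torch.zeros_like(ts) for _ in range(dist.get_world_size())]
        dist.all_gather(buf, ts)
        k = min(b.item() for b in buf)
        # Resume only from a step present on all ranks; any -1 vetoes the resume.
        return os.path.join(output_dir, f"checkpoint-{k}") if k >= 0 else None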