hailin 2025-08-25 20:30:38 +08:00
parent 4e59c138ea
commit ab55cd17e6
1 changed file with 30 additions and 17 deletions


@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import os
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 import glob
 import socket
 import argparse
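
A note on the new environment line: HuggingFace's `tokenizers` backend emits fork warnings (and silently disables its thread pool) when a process forks after parallel tokenization has already run, so the variable must be set before any tokenizer work happens. A minimal sketch of the safe ordering; the "gpt2" checkpoint is illustrative only:

    import os
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")  # before any tokenizer use

    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained("gpt2")  # loads with parallelism already disabled
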
@@ -158,35 +159,36 @@ class QwenChatSFTDataset(IterableDataset):
 # ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
-    def __init__(self, tokenizer: AutoTokenizer):
+    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
         self.tok = tokenizer
-        assert self.tok.pad_token_id is not None, "tokenizer.pad_token must not be None; main() falls back to eos_token"
+        self.pad_to_length = pad_to_length
+        assert self.tok.pad_token_id is not None
 
-    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        # Pad variable-length samples to the longest sequence in the batch; pad labels with -100
-        def _to_list(x):
-            return x.tolist() if isinstance(x, torch.Tensor) else list(x)
-        input_ids = [ _to_list(f["input_ids"]) for f in features ]
-        attn_masks = [ _to_list(f["attention_mask"]) for f in features ]
-        labels_list = [ _to_list(f["labels"]) for f in features ]
+    def __call__(self, features):
+        def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
+        input_ids = [_to_list(f["input_ids"]) for f in features]
+        attn_masks = [_to_list(f["attention_mask"]) for f in features]
+        labels_list = [_to_list(f["labels"]) for f in features]
 
-        max_len = max(len(x) for x in input_ids)
+        max_len_in_batch = max(len(x) for x in input_ids)
+        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
         pad_id = self.tok.pad_token_id
         batch_inp, batch_attn, batch_lab = [], [], []
         for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
-            pad_len = max_len - len(inp)
+            pad_len = target_len - len(inp)
+            if pad_len < 0:
+                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
+                pad_len = 0
             batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
             "labels": torch.stack(batch_lab, dim=0),
         }
 # ----------------- Arguments -----------------
 def parse_args():
     ap = argparse.ArgumentParser()
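
To pin down the new fixed-length contract, here is a minimal sketch of how the collator behaves; the `_StubTok` stand-in and the toy features are hypothetical, and `SFTDataCollator` is assumed to be imported from this script:

    import torch

    class _StubTok:  # duck-types the single tokenizer attribute the collator reads
        pad_token_id = 0

    collator = SFTDataCollator(_StubTok(), pad_to_length=8)
    features = [
        {"input_ids": torch.tensor([5, 6, 7]),            # short: padded up to 8
         "attention_mask": torch.ones(3, dtype=torch.long),
         "labels": torch.tensor([-100, 6, 7])},
        {"input_ids": torch.arange(1, 11),                # long: truncated down to 8
         "attention_mask": torch.ones(10, dtype=torch.long),
         "labels": torch.arange(1, 11)},
    ]
    batch = collator(features)
    assert batch["input_ids"].shape == (2, 8)             # every batch is target_len wide
    assert batch["labels"][0, 3:].eq(-100).all()          # label padding stays masked

Fixed shapes cost some padding compute but keep tensor sizes identical across steps, which is presumably the motivation for wiring pad_to_length=args.seq_len further down.
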
@@ -227,18 +229,27 @@ def main():
     args = parse_args()
     set_seed(args.seed)
 
-    # Tokenizer/Model
+    # 1) First give the tokenizer a pad token
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token  # used for padding
+    # Optional: quiets sequence-length warnings
+    tokenizer.model_max_length = args.seq_len
+
+    # 2) Then load the model
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
         torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )
+    # 3) Finally align the model's pad_token_id with the tokenizer
+    model.config.pad_token_id = tokenizer.pad_token_id
     model.config.use_cache = False  # disable cache during training
     if args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
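
The 1→2→3 ordering leaves tokenizer and model agreeing on padding before training starts. A quick sanity check one could append to this block (names as in main(); purely illustrative):

    assert tokenizer.pad_token_id is not None                    # guaranteed by step 1
    assert model.config.pad_token_id == tokenizer.pad_token_id   # guaranteed by step 3
    print("pad token:", tokenizer.pad_token, "->", tokenizer.pad_token_id)
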
@@ -336,7 +347,7 @@ def main():
     if len(eval_samples) > 0:
         eval_dataset = ListDataset(eval_samples)
-    data_collator = SFTDataCollator(tokenizer)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
@@ -376,6 +387,8 @@ def main():
         deepspeed=args.deepspeed,
         dataloader_drop_last=True,
         dataloader_num_workers=0,
+        dataloader_prefetch_factor=None,
+        dataloader_pin_memory=False,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
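
The two new dataloader settings track a PyTorch rule: DataLoader rejects a non-None prefetch_factor when num_workers == 0, which is exactly the configuration above. Why pin_memory is also turned off is not stated in the diff; reducing pinned host memory is one plausible reason. A standalone sketch of the prefetch rule:

    from torch.utils.data import DataLoader

    DataLoader(list(range(4)), num_workers=0, prefetch_factor=None)  # accepted
    try:
        DataLoader(list(range(4)), num_workers=0, prefetch_factor=2)
    except ValueError as err:
        print("rejected as expected:", err)    # prefetching requires worker processes
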
@@ -392,7 +405,7 @@ def main():
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         data_collator=data_collator
     )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
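
The tokenizer= keyword on Trainer was deprecated in favor of processing_class= in recent transformers releases, which is what this rename follows. If the script had to support both old and new versions, a compatibility shim could look like this sketch (Trainer kwargs as in main()):

    import inspect
    from transformers import Trainer

    # Pick whichever keyword this transformers version understands.
    kw = ("processing_class"
          if "processing_class" in inspect.signature(Trainer.__init__).parameters
          else "tokenizer")
    trainer = Trainer(model=model, args=training_args,
                      train_dataset=train_stream, eval_dataset=eval_dataset,
                      data_collator=data_collator, **{kw: tokenizer})
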