hailin 2025-08-25 20:30:38 +08:00
parent 4e59c138ea
commit ab55cd17e6
1 changed file with 30 additions and 17 deletions


@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import os
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 import glob
 import socket
 import argparse
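Aside: `TOKENIZERS_PARALLELISM=false` suppresses the Hugging Face tokenizers warning ("The current process just got forked, after parallelism has already been used...") that fires when DataLoader workers fork after the Rust tokenizer's thread pool has started. The variable is read by the tokenizers library at runtime, so the usual convention, followed here, is to set it right after `import os`, before anything can touch a tokenizer.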
@@ -158,35 +159,36 @@ class QwenChatSFTDataset(IterableDataset):
 # ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
-    def __init__(self, tokenizer: AutoTokenizer):
+    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
         self.tok = tokenizer
-        assert self.tok.pad_token_id is not None, "tokenizer.pad_token must not be None; main() already falls back to eos_token"
+        self.pad_to_length = pad_to_length
+        assert self.tok.pad_token_id is not None

-    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        # Align variable-length samples to the longest sequence in the batch; pad labels with -100
-        def _to_list(x):
-            return x.tolist() if isinstance(x, torch.Tensor) else list(x)
-        input_ids = [ _to_list(f["input_ids"]) for f in features ]
-        attn_masks = [ _to_list(f["attention_mask"]) for f in features ]
-        labels_list = [ _to_list(f["labels"]) for f in features ]
+    def __call__(self, features):
+        def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
+        input_ids = [_to_list(f["input_ids"]) for f in features]
+        attn_masks = [_to_list(f["attention_mask"]) for f in features]
+        labels_list = [_to_list(f["labels"]) for f in features]

-        max_len = max(len(x) for x in input_ids)
+        max_len_in_batch = max(len(x) for x in input_ids)
+        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
         pad_id = self.tok.pad_token_id
         batch_inp, batch_attn, batch_lab = [], [], []
         for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
-            pad_len = max_len - len(inp)
+            pad_len = target_len - len(inp)
+            if pad_len < 0:
+                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
+                pad_len = 0
             batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
             "labels": torch.stack(batch_lab, dim=0),
         }

 # ----------------- Arguments -----------------
 def parse_args():
     ap = argparse.ArgumentParser()
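The net effect of this hunk: with `pad_to_length` set, every batch comes out at exactly that width, overlong samples are truncated, and label padding stays masked with -100; with `pad_to_length=None` the collator keeps the old dynamic behavior of padding to the longest sample in the batch. A minimal sketch of the new behavior (not part of the commit; `DummyTok` is a hypothetical stand-in for the real AutoTokenizer, since only `pad_token_id` is consulted):

    import torch

    class DummyTok:  # hypothetical stand-in; the script uses AutoTokenizer
        pad_token_id = 0

    collator = SFTDataCollator(DummyTok(), pad_to_length=8)
    batch = collator([
        # short sample: padded from 3 up to 8 tokens
        {"input_ids": torch.tensor([5, 6, 7]),
         "attention_mask": torch.tensor([1, 1, 1]),
         "labels": torch.tensor([-100, 6, 7])},
        # long sample: truncated from 10 down to 8 tokens
        {"input_ids": torch.arange(1, 11),
         "attention_mask": torch.ones(10, dtype=torch.long),
         "labels": torch.arange(1, 11)},
    ])
    assert batch["input_ids"].shape == (2, 8)     # fixed width, regardless of contents
    assert batch["labels"][0, 3:].eq(-100).all()  # padded label positions are masked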
@@ -227,18 +229,27 @@ def main():
     args = parse_args()
     set_seed(args.seed)

-    # Tokenizer/Model
+    # 1) First give the tokenizer a pad token
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token  # used for padding
+    # Optional: reduces warnings
+    tokenizer.model_max_length = args.seq_len

+    # 2) Then load the model
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
         torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )

+    # 3) Finally align the model's pad_token_id with the tokenizer
+    model.config.pad_token_id = tokenizer.pad_token_id
     model.config.use_cache = False  # disable the KV cache during training
     if args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
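Note: the 1-2-3 ordering matters because `model.config.pad_token_id = tokenizer.pad_token_id` can only pick up a valid id after the tokenizer's eos fallback has run. Keeping `use_cache=False` is also the setting transformers expects here: the KV cache is incompatible with gradient checkpointing, and transformers otherwise warns and disables the cache itself.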
@@ -336,7 +347,7 @@ def main():
     if len(eval_samples) > 0:
         eval_dataset = ListDataset(eval_samples)

-    data_collator = SFTDataCollator(tokenizer)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)

     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
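Passing `pad_to_length=args.seq_len` trades some wasted compute on short samples for constant tensor shapes across steps, which makes memory usage predictable and avoids shape-dependent recompilation on backends that specialize kernels per shape (e.g. torch.compile).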
@@ -376,6 +387,8 @@ def main():
         deepspeed=args.deepspeed,
         dataloader_drop_last=True,
         dataloader_num_workers=0,
+        dataloader_prefetch_factor=None,
+        dataloader_pin_memory=False,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
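Note: PyTorch's DataLoader only accepts a `prefetch_factor` when `num_workers > 0` and raises a ValueError otherwise, so with `dataloader_num_workers=0` the explicit `None` is the only safe value. Disabling `dataloader_pin_memory` is presumably meant to sidestep host-memory pinning overhead in this single-process loading setup.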
@@ -392,7 +405,7 @@ def main():
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         data_collator=data_collator
     )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
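Note: newer transformers releases (v4.46 and later) deprecate the `tokenizer=` argument to `Trainer` in favor of `processing_class=`, which accepts tokenizers as well as processors; passing the tokenizer through the new keyword behaves the same while silencing the deprecation warning.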