parent 4e59c138ea
commit ab55cd17e6
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import os
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 import glob
 import socket
 import argparse
@@ -158,35 +159,36 @@ class QwenChatSFTDataset(IterableDataset):
 
 # ----------------- Dedicated collator: pad inputs, pad labels with -100 -----------------
 class SFTDataCollator:
-    def __init__(self, tokenizer: AutoTokenizer):
+    def __init__(self, tokenizer: AutoTokenizer, pad_to_length: Optional[int] = None):
         self.tok = tokenizer
-        assert self.tok.pad_token_id is not None, "tokenizer.pad_token must not be None; main() already falls back to eos_token"
+        self.pad_to_length = pad_to_length
+        assert self.tok.pad_token_id is not None
 
-    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        # Pad variable-length samples to the longest sequence in the batch; pad labels with -100
-        def _to_list(x):
-            return x.tolist() if isinstance(x, torch.Tensor) else list(x)
-        input_ids = [ _to_list(f["input_ids"]) for f in features ]
-        attn_masks = [ _to_list(f["attention_mask"]) for f in features ]
-        labels_list = [ _to_list(f["labels"]) for f in features ]
+    def __call__(self, features):
+        def _to_list(x): return x.tolist() if isinstance(x, torch.Tensor) else list(x)
+        input_ids = [_to_list(f["input_ids"]) for f in features]
+        attn_masks = [_to_list(f["attention_mask"]) for f in features]
+        labels_list = [_to_list(f["labels"]) for f in features]
 
-        max_len = max(len(x) for x in input_ids)
+        max_len_in_batch = max(len(x) for x in input_ids)
+        target_len = self.pad_to_length if self.pad_to_length is not None else max_len_in_batch
+
         pad_id = self.tok.pad_token_id
 
         batch_inp, batch_attn, batch_lab = [], [], []
         for inp, msk, lab in zip(input_ids, attn_masks, labels_list):
-            pad_len = max_len - len(inp)
+            pad_len = target_len - len(inp)
+            if pad_len < 0:
+                inp, msk, lab = inp[:target_len], msk[:target_len], lab[:target_len]
+                pad_len = 0
             batch_inp.append(torch.tensor(inp + [pad_id]*pad_len, dtype=torch.long))
             batch_attn.append(torch.tensor(msk + [0]*pad_len, dtype=torch.long))
             batch_lab.append(torch.tensor(lab + [-100]*pad_len, dtype=torch.long))
 
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
             "labels": torch.stack(batch_lab, dim=0),
         }
 
 
 # ----------------- Arguments -----------------
 def parse_args():
     ap = argparse.ArgumentParser()
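Note on the hunk above: with pad_to_length set, every batch is padded (or truncated) to a fixed length instead of the per-batch maximum. A minimal smoke-test sketch, assuming the script's SFTDataCollator and a tokenizer whose pad token is already set; the feature values are made up for illustration:

# Hypothetical check, not part of the commit.
import torch

features = [
    {"input_ids": torch.tensor([1, 2, 3]),
     "attention_mask": torch.tensor([1, 1, 1]),
     "labels": torch.tensor([-100, 2, 3])},
    {"input_ids": torch.tensor([4, 5]),
     "attention_mask": torch.tensor([1, 1]),
     "labels": torch.tensor([4, 5])},
]

collator = SFTDataCollator(tokenizer, pad_to_length=4)   # pad every sample to length 4
batch = collator(features)
print(batch["input_ids"].shape)   # torch.Size([2, 4])
print(batch["labels"][1])         # tensor([4, 5, -100, -100]): labels padded with -100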
@@ -227,18 +229,27 @@ def main():
     args = parse_args()
     set_seed(args.seed)
 
-    # Tokenizer/Model
+    # 1) First make sure the tokenizer has a pad token
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
+
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token  # used for padding
+
+    # Optional: keeps warnings quieter
+    tokenizer.model_max_length = args.seq_len
+
+    # 2) Then load the model
     model = AutoModelForCausalLM.from_pretrained(
         args.model_name_or_path,
         torch_dtype=(torch.bfloat16 if args.bf16 else torch.float16),
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )
+
+    # 3) Finally align the model's pad_token_id with the tokenizer
     model.config.pad_token_id = tokenizer.pad_token_id
     model.config.use_cache = False  # disable cache during training
 
     if args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
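Note on the hunk above: the reordering guarantees the pad token exists before the model config is touched. A short sanity check one could run after step 3, assuming main() has executed that far (illustrative only):

# Hypothetical post-setup assertions, not part of the commit.
assert tokenizer.pad_token_id is not None                     # step 1: pad falls back to eos
assert model.config.pad_token_id == tokenizer.pad_token_id    # step 3: ids are aligned
assert model.config.use_cache is False                        # cache disabled for training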
@@ -336,7 +347,7 @@ def main():
     if len(eval_samples) > 0:
         eval_dataset = ListDataset(eval_samples)
 
-    data_collator = SFTDataCollator(tokenizer)
+    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
     logging_dir = os.path.join(args.output_dir, "logs")
@@ -376,6 +387,8 @@ def main():
         deepspeed=args.deepspeed,
         dataloader_drop_last=True,
         dataloader_num_workers=0,
+        dataloader_prefetch_factor=None,
+        dataloader_pin_memory=False,
         report_to=([] if args.report_to == "none" else [args.report_to]),
         bf16=args.bf16,
         fp16=(not args.bf16),
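Note on the hunk above: PyTorch's DataLoader only accepts a prefetch_factor when worker processes are used, so with dataloader_num_workers=0 the value has to stay None. A minimal sketch of that constraint with a toy dataset (illustrative, not part of the commit):

# prefetch_factor is tied to multiprocessing workers in torch.utils.data.DataLoader.
from torch.utils.data import DataLoader

data = list(range(8))
DataLoader(data, batch_size=2, num_workers=0)                       # OK: no prefetch_factor needed
DataLoader(data, batch_size=2, num_workers=2, prefetch_factor=2)    # OK: workers present
# DataLoader(data, batch_size=2, num_workers=0, prefetch_factor=2)  # would raise ValueError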
@@ -392,7 +405,7 @@ def main():
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         data_collator=data_collator
     )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
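Note on the hunk above: newer transformers releases deprecate Trainer's tokenizer= argument in favor of processing_class=. If the script needs to run across versions, a hedged fallback could look like this (illustrative, not part of the commit):

# Pass the tokenizer under whichever keyword the installed transformers version supports.
import inspect
from transformers import Trainer

trainer_kwargs = dict(model=model, args=training_args,
                      train_dataset=train_stream, eval_dataset=eval_dataset,
                      data_collator=data_collator)
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
    trainer_kwargs["processing_class"] = tokenizer   # newer releases
else:
    trainer_kwargs["tokenizer"] = tokenizer          # older releases
trainer = Trainer(**trainer_kwargs)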