parent df1710b25d
commit d57cc03b60
@@ -1,11 +1,19 @@
python - <<'PY'
import warnings, torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    Trainer, TrainingArguments, DataCollatorForLanguageModeling,
)

model_id = "sshleifer/tiny-gpt2"  # tiny model
tok = AutoTokenizer.from_pretrained(model_id)
# silence warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*clean_up_tokenization_spaces.*")
warnings.filterwarnings("ignore", category=UserWarning, message="Was asked to gather.*")

model_id = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tok.pad_token = tok.eos_token
tok.clean_up_tokenization_spaces = True  # set explicitly to silence the FutureWarning

ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
def tok_fn(ex): return tok(ex["text"], truncation=True, padding="max_length", max_length=64)

@@ -13,6 +21,7 @@ ds = ds.map(tok_fn, batched=True, remove_columns=["text"])

mdl = AutoModelForCausalLM.from_pretrained(model_id)
collator = DataCollatorForLanguageModeling(tok, mlm=False)

args = TrainingArguments(
    output_dir="out-mini",
    per_device_train_batch_size=2,

@@ -21,10 +30,10 @@ args = TrainingArguments(
    logging_steps=2,
    save_steps=10,
    report_to="none",
    max_grad_norm=1.0,  # optional: clip the gradient norm while we're at it
)

trainer = Trainer(model=mdl, args=args, train_dataset=ds, data_collator=collator)
trainer.train()
print("✅ training pipeline OK")
PY
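
Below is a minimal standalone sketch of the tokenizer setup the updated script relies on, useful for a quick check outside the heredoc. It uses the same model and warning filter as above; whether the FutureWarning filter is still needed depends on the installed transformers version, and loading sshleifer/tiny-gpt2 assumes network access.

import warnings

# Register the filter before the tokenizer is created: on affected transformers
# versions, the clean_up_tokenization_spaces FutureWarning is emitted during loading.
warnings.filterwarnings("ignore", category=FutureWarning, message=".*clean_up_tokenization_spaces.*")

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2", use_fast=True)
tok.pad_token = tok.eos_token  # GPT-2 defines no pad token, so EOS is reused for padding
print(tok("hello world", truncation=True, padding="max_length", max_length=8)["input_ids"])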