python - <<'PY'
# Minimal end-to-end smoke test: fine-tune a tiny GPT-2 on 1% of WikiText-2
# to confirm the data -> tokenizer -> Trainer pipeline works.
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import torch

model_id = "sshleifer/tiny-gpt2"  # tiny model, fast to download and train
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token = tok.eos_token  # GPT-2 has no pad token; reuse EOS

ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")

def tok_fn(ex):
    return tok(ex["text"], truncation=True, padding="max_length", max_length=64)

ds = ds.map(tok_fn, batched=True, remove_columns=["text"])

mdl = AutoModelForCausalLM.from_pretrained(model_id)
collator = DataCollatorForLanguageModeling(tok, mlm=False)  # causal LM, not masked LM
args = TrainingArguments(
    output_dir="out-mini",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    fp16=torch.cuda.is_available(),
    logging_steps=2,
    save_steps=10,
    report_to="none",
)

trainer = Trainer(model=mdl, args=args, train_dataset=ds, data_collator=collator)
trainer.train()
print("✅ Training pipeline OK")
PY
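To confirm the saved weights are actually usable, here is a minimal follow-up sketch. It assumes the run above wrote at least one checkpoint under out-mini/ (with save_steps=10 it should); the prompt text is arbitrary, and the generated tokens from such a tiny model will be nonsense, which is fine: the point is only that the checkpoint loads and generates.

python - <<'PY'
# Sketch: load the most recent checkpoint written by the run above and generate a few tokens.
# The tokenizer is reloaded from the original model id, since the Trainer above did not save it.
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pick the checkpoint with the highest step number (assumes out-mini/checkpoint-* exists).
ckpts = sorted(Path("out-mini").glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]))
ckpt = str(ckpts[-1])

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
mdl = AutoModelForCausalLM.from_pretrained(ckpt)

inputs = tok("The quick brown", return_tensors="pt")
out = mdl.generate(**inputs, max_new_tokens=10, pad_token_id=tok.eos_token_id)
print(tok.decode(out[0]))
PY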