python - <<'PY'
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import torch

model_id = "sshleifer/tiny-gpt2"  # tiny model, keeps the smoke test fast
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token = tok.eos_token  # GPT-2 has no pad token; reuse EOS

ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")

def tok_fn(ex):
    return tok(ex["text"], truncation=True, padding="max_length", max_length=64)

ds = ds.map(tok_fn, batched=True, remove_columns=["text"])

mdl = AutoModelForCausalLM.from_pretrained(model_id)
collator = DataCollatorForLanguageModeling(tok, mlm=False)  # causal LM, not masked LM

args = TrainingArguments(
    output_dir="out-mini",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    fp16=torch.cuda.is_available(),  # enable fp16 only when a GPU is present
    logging_steps=2,
    save_steps=10,
    report_to="none",
)

trainer = Trainer(model=mdl, args=args, train_dataset=ds, data_collator=collator)
trainer.train()
print("✅ training pipeline OK")
PY