python - <<'PY'
"""Smoke-test the causal-LM fine-tuning pipeline end to end.

Fine-tunes the tiny GPT-2 checkpoint on the first 1% of the WikiText-2 (raw)
training split for a single epoch and prints a success marker when
``Trainer.train()`` returns without raising.
"""
import warnings

import torch

# Silence known-noisy warnings. The filters are installed before the heavy
# third-party imports so that anything emitted at import time is covered too.
warnings.filterwarnings(
    "ignore", category=UserWarning, message=".*TypedStorage is deprecated.*"
)
warnings.filterwarnings(
    "ignore", category=UserWarning, message="Was asked to gather.*"
)
warnings.filterwarnings(
    "ignore", category=FutureWarning, message=".*clean_up_tokenization_spaces.*"
)

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "sshleifer/tiny-gpt2"
max_length = 64

tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# The tokenizer is given the EOS token as its padding token so that
# padding="max_length" below has something to pad with.
tok.pad_token = tok.eos_token
# Set explicitly to avoid the FutureWarning about the changing default.
tok.clean_up_tokenization_spaces = True

ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
# Drop blank rows: they would tokenize to pure padding, and the collator
# masks padding positions in the labels, so such examples carry no training
# signal and an all-blank batch has no unmasked label to average over.
ds = ds.filter(lambda ex: bool(ex["text"].strip()))


def tok_fn(ex):
    """Tokenize a batch of rows to fixed-length (max_length) sequences."""
    return tok(
        ex["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )


ds = ds.map(tok_fn, batched=True, remove_columns=["text"])

mdl = AutoModelForCausalLM.from_pretrained(model_id)
# mlm=False selects causal (next-token) language modeling rather than masked LM.
collator = DataCollatorForLanguageModeling(tok, mlm=False)

args = TrainingArguments(
    output_dir="out-mini",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    logging_steps=2,
    save_steps=10,
    report_to="none",
    max_grad_norm=1.0,  # optional: clip the gradient norm
)

trainer = Trainer(
    model=mdl,
    args=args,
    train_dataset=ds,
    data_collator=collator,
)
trainer.train()
print("✅ 训练链路 OK")
PY