train_env_prepare/check_train.sh

python - <<'PY'
import warnings, torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    Trainer, TrainingArguments, DataCollatorForLanguageModeling,
)
# Silence known, harmless warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*clean_up_tokenization_spaces.*")
warnings.filterwarnings("ignore", category=UserWarning, message="Was asked to gather.*")
model_id = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tok.pad_token = tok.eos_token
tok.clean_up_tokenization_spaces = True  # set explicitly to silence the FutureWarning
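# Tokenize a 1% slice of wikitext-2 to a fixed length of 64 tokens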
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
def tok_fn(ex): return tok(ex["text"], truncation=True, padding="max_length", max_length=64)
ds = ds.map(tok_fn, batched=True, remove_columns=["text"])
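# Causal-LM collator (mlm=False) builds the labels from input_ids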
mdl = AutoModelForCausalLM.from_pretrained(model_id)
collator = DataCollatorForLanguageModeling(tok, mlm=False)
args = TrainingArguments(
    output_dir="out-mini",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    fp16=torch.cuda.is_available(),
    logging_steps=2,
    save_steps=10,
    report_to="none",
    max_grad_norm=1.0,  # optional: also keeps the gradient norm in check
)
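# One short run is enough to exercise the forward/backward/optimizer path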
trainer = Trainer(model=mdl, args=args, train_dataset=ds, data_collator=collator)
trainer.train()
print("✅ 训练链路 OK")
PY