diff --git a/check_train.sh b/check_train.sh
index 96db7ba..3cf983c 100644
--- a/check_train.sh
+++ b/check_train.sh
@@ -1,11 +1,19 @@
 python - <<'PY'
+import warnings, torch
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
-import torch
+from transformers import (
+    AutoTokenizer, AutoModelForCausalLM,
+    Trainer, TrainingArguments, DataCollatorForLanguageModeling,
+)
 
-model_id = "sshleifer/tiny-gpt2"  # tiny model
-tok = AutoTokenizer.from_pretrained(model_id)
+# silence known warnings
+warnings.filterwarnings("ignore", category=FutureWarning, message=".*clean_up_tokenization_spaces.*")
+warnings.filterwarnings("ignore", category=UserWarning, message="Was asked to gather.*")
+
+model_id = "sshleifer/tiny-gpt2"
+tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 tok.pad_token = tok.eos_token
+tok.clean_up_tokenization_spaces = True  # set explicitly to suppress the FutureWarning
 
 ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
 def tok_fn(ex): return tok(ex["text"], truncation=True, padding="max_length", max_length=64)
@@ -13,6 +21,7 @@ ds = ds.map(tok_fn, batched=True, remove_columns=["text"])
 
 mdl = AutoModelForCausalLM.from_pretrained(model_id)
 collator = DataCollatorForLanguageModeling(tok, mlm=False)
+
 args = TrainingArguments(
     output_dir="out-mini",
     per_device_train_batch_size=2,
@@ -21,10 +30,10 @@ args = TrainingArguments(
     logging_steps=2,
     save_steps=10,
     report_to="none",
+    max_grad_norm=1.0,  # optional: clip the gradient norm while we're at it
 )
 
 trainer = Trainer(model=mdl, args=args, train_dataset=ds, data_collator=collator)
 trainer.train()
 print("✅ training pipeline OK")
 PY
-