fix: save null vectors as bfloat16 to match training dtype

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-03-06 03:20:02 -08:00
parent 7077ac7129
commit dc1ff288e8
1 changed file with 3 additions and 2 deletions

View File

@ -75,8 +75,9 @@ def main():
print(f"Wrote {N_SAMPLES} samples to {csv_path}")
# null vectors空提示的 embedding用于 cfg dropout
null_t5 = torch.zeros(1, T5_SEQ, T5_DIM)
null_clip = torch.zeros(1, CLIP_DIM)
# 必须用 bf16与训练 dtype 一致train.py 加载后不做 dtype 转换)
null_t5 = torch.zeros(1, T5_SEQ, T5_DIM, dtype=torch.bfloat16)
null_clip = torch.zeros(1, CLIP_DIM, dtype=torch.bfloat16)
torch.save(null_t5, f"{NULL_DIR}/null_t5.pt")
torch.save(null_clip, f"{NULL_DIR}/null_clip.pt")
print(f"Wrote null vectors to {NULL_DIR}/")