fix: save null vectors as bfloat16 to match training dtype
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7077ac7129
commit
dc1ff288e8
|
|
@ -75,8 +75,9 @@ def main():
|
|||
print(f"Wrote {N_SAMPLES} samples to {csv_path}")
|
||||
|
||||
# null vectors (embeddings of the empty prompt, used for CFG dropout)
|
||||
null_t5 = torch.zeros(1, T5_SEQ, T5_DIM)
|
||||
null_clip = torch.zeros(1, CLIP_DIM)
|
||||
# must be bf16 to match the training dtype (train.py does no dtype conversion after loading)
|
||||
null_t5 = torch.zeros(1, T5_SEQ, T5_DIM, dtype=torch.bfloat16)
|
||||
null_clip = torch.zeros(1, CLIP_DIM, dtype=torch.bfloat16)
|
||||
torch.save(null_t5, f"{NULL_DIR}/null_t5.pt")
|
||||
torch.save(null_clip, f"{NULL_DIR}/null_clip.pt")
|
||||
print(f"Wrote null vectors to {NULL_DIR}/")
|
||||
|
|
|
|||
Loading…
Reference in New Issue