diff --git a/scripts/diffusion/gen_smoke_data.py b/scripts/diffusion/gen_smoke_data.py index 839f8f2..6a9634e 100644 --- a/scripts/diffusion/gen_smoke_data.py +++ b/scripts/diffusion/gen_smoke_data.py @@ -75,8 +75,9 @@ def main(): print(f"Wrote {N_SAMPLES} samples to {csv_path}") # null vectors(空提示的 embedding,用于 cfg dropout) - null_t5 = torch.zeros(1, T5_SEQ, T5_DIM) - null_clip = torch.zeros(1, CLIP_DIM) + # 必须用 bf16,与训练 dtype 一致(train.py 加载后不做 dtype 转换) + null_t5 = torch.zeros(1, T5_SEQ, T5_DIM, dtype=torch.bfloat16) + null_clip = torch.zeros(1, CLIP_DIM, dtype=torch.bfloat16) torch.save(null_t5, f"{NULL_DIR}/null_t5.pt") torch.save(null_clip, f"{NULL_DIR}/null_clip.pt") print(f"Wrote null vectors to {NULL_DIR}/")