mysora/scripts/diffusion/gen_smoke_data.py

88 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Generate the synthetic data (fake latents) needed by the smoke test.

Data layout:
    /data/train-input/smoke_test/
        data.csv
        images/sample_{i}.png    # 256x256 RGB placeholder image
        latents/sample_{i}.pt    # video latent [16, 1, 32, 32]
        t5/sample_{i}.pt         # T5 embedding [256, 4096]
        clip/sample_{i}.pt       # CLIP embedding [768]
    /mnt/ddn/sora/tmp_load/
        null_t5.pt               # null T5 [1, 256, 4096]
        null_clip.pt             # null CLIP [1, 768]

Shape notes (256px, single frame, AE_SPATIAL_COMPRESSION=16, patch_size=2):
    VAE latent: [C=16, T=1, H_lat=32, W_lat=32]
        32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16
    T5: [seq=256, dim=4096]
    CLIP pooled: [dim=768]
"""
import os
import csv
import torch
from PIL import Image
# Number of fake samples to generate.
N_SAMPLES = 4

# Output roots: per-sample smoke-test data, and the null-prompt tensors.
BASE_DIR = "/data/train-input/smoke_test"
NULL_DIR = "/mnt/ddn/sora/tmp_load"

# Video latent shape: channels, frames, latent height, latent width.
LAT_C = 16
LAT_T = 1
LAT_H = 32
LAT_W = 32

# T5 embedding: sequence length and feature dimension.
T5_SEQ = 256
T5_DIM = 4096

# Pooled CLIP embedding dimension.
CLIP_DIM = 768
def main():
    """Generate the synthetic smoke-test dataset and the null-prompt tensors.

    Under ``BASE_DIR`` this creates the ``latents/``, ``t5/``, ``clip/`` and
    ``images/`` sub-directories, writes ``N_SAMPLES`` samples (one solid-colour
    256x256 PNG plus three random tensors each) and a ``data.csv`` index with
    one row per sample.  Under ``NULL_DIR`` it writes zero-filled
    ``null_t5.pt`` and ``null_clip.pt``.

    Tensor contents come from ``torch.randn`` without a fixed seed, so they
    differ between runs.  The function takes no arguments and returns
    ``None``; progress is reported with ``print``.
    """
    for subdir in ("latents", "t5", "clip", "images"):
        os.makedirs(f"{BASE_DIR}/{subdir}", exist_ok=True)
    os.makedirs(NULL_DIR, exist_ok=True)

    # Column order of data.csv.  Declared explicitly (instead of being taken
    # from the first row) so that N_SAMPLES == 0 yields a header-only CSV
    # rather than an IndexError.
    fieldnames = [
        "id",
        "path",
        "text",
        "num_frames",
        "height",
        "width",
        "fps",
        "latents_path",
        "text_t5_path",
        "text_clip_path",
    ]

    rows = []
    for i in range(N_SAMPLES):
        img_path = f"{BASE_DIR}/images/sample_{i}.png"
        lat_path = f"{BASE_DIR}/latents/sample_{i}.pt"
        t5_path = f"{BASE_DIR}/t5/sample_{i}.pt"
        clip_path = f"{BASE_DIR}/clip/sample_{i}.pt"

        # A real 256x256 image on disk (per the original note: needed for
        # raw-frame loading when load_original_video=True).  The red channel
        # varies with the sample index so the images are distinguishable.
        img = Image.new("RGB", (256, 256), color=(i * 60 % 256, 128, 200))
        img.save(img_path)

        torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
        torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
        torch.save(torch.randn(CLIP_DIM), clip_path)

        rows.append({
            "id": i,
            "path": img_path,  # real image path (read by the dataset for raw frames)
            "text": f"smoke test sample {i}",
            "num_frames": 1,  # single-frame image
            "height": 256,
            "width": 256,
            "fps": 24.0,
            "latents_path": lat_path,
            "text_t5_path": t5_path,
            "text_clip_path": clip_path,
        })

    csv_path = f"{BASE_DIR}/data.csv"
    # newline="" is what the csv module requires; the encoding is pinned so
    # the output does not depend on the machine's locale.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote {N_SAMPLES} samples to {csv_path}")

    # Null vectors: zero-filled stand-ins for the empty-prompt embeddings
    # (per the original note: used for CFG dropout).  Note the leading batch
    # dimension of 1, which the per-sample tensors do not have.
    null_t5 = torch.zeros(1, T5_SEQ, T5_DIM)
    null_clip = torch.zeros(1, CLIP_DIM)
    torch.save(null_t5, f"{NULL_DIR}/null_t5.pt")
    torch.save(null_clip, f"{NULL_DIR}/null_clip.pt")
    print(f"Wrote null vectors to {NULL_DIR}/")
    print("Done.")
# Script entry point: generate the data only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()