# mysora/scripts/diffusion/gen_smoke_data.py
#
# (Repository-viewer header: 89 lines, 3.2 KiB, Python. The viewer warned
# that this file contains ambiguous Unicode characters that might be
# confused with other characters.)

"""
生成 smoke test 所需的合成数据fake latents
数据布局:
/data/train-input/smoke_test/
data.csv
latents/sample_{i}.pt # video latent [16, 1, 32, 32]
t5/sample_{i}.pt # T5 embedding [256, 4096]
clip/sample_{i}.pt # CLIP embedding [768]
/mnt/ddn/sora/tmp_load/
null_t5.pt # null T5 [1, 256, 4096]
null_clip.pt # null CLIP [1, 768]
Shape 说明256px 单帧AE_SPATIAL_COMPRESSION=16patch_size=2
VAE latent: [C=16, T=1, H_lat=32, W_lat=32]
32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16
T5: [seq=256, dim=4096]
CLIP pooled: [dim=768]
"""
import os
import csv
import torch
from PIL import Image
# Number of fake samples: 8 GPUs x 2 samples per GPU batch, so that every
# GPU is guaranteed to receive some data.
N_SAMPLES = 16

# Output locations.
BASE_DIR = "/data/train-input/smoke_test"  # dataset root (csv, tensors, images)
NULL_DIR = "/mnt/ddn/sora/tmp_load"  # null (empty-prompt) embeddings

# Video latent shape: channels, frames, latent height, latent width.
LAT_C = 16
LAT_T = 1
LAT_H = 32
LAT_W = 32

# T5 text embedding: sequence length x hidden size.
T5_SEQ = 256
T5_DIM = 4096

# Pooled CLIP text embedding size.
CLIP_DIM = 768
# Column order of data.csv. Spelled out explicitly so the header does not
# depend on there being at least one row.
_CSV_FIELDS = (
    "id",
    "path",
    "text",
    "num_frames",
    "height",
    "width",
    "fps",
    "latents_path",
    "text_t5_path",
    "text_clip_path",
)


def _write_sample(i: int, base_dir: str) -> dict:
    """Write the image, latent and text embeddings for sample ``i``.

    Returns the sample's row for ``data.csv``.
    """
    img_path = f"{base_dir}/images/sample_{i}.png"
    lat_path = f"{base_dir}/latents/sample_{i}.pt"
    t5_path = f"{base_dir}/t5/sample_{i}.pt"
    clip_path = f"{base_dir}/clip/sample_{i}.pt"

    # A real 256x256 image, used when original frames are loaded
    # (load_original_video=True). The red channel varies per sample.
    Image.new("RGB", (256, 256), color=(i * 60 % 256, 128, 200)).save(img_path)

    # Keep this call order (latent, T5, CLIP) so the random stream matches
    # previous runs under the same seed.
    torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
    torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
    torch.save(torch.randn(CLIP_DIM), clip_path)

    return {
        "id": i,
        "path": img_path,  # real image path, read for original frames
        "text": f"smoke test sample {i}",
        "num_frames": 1,  # single-frame image
        "height": 256,
        "width": 256,
        "fps": 24.0,
        "latents_path": lat_path,
        "text_t5_path": t5_path,
        "text_clip_path": clip_path,
    }


def _write_csv(rows: list, csv_path: str) -> None:
    """Write ``rows`` to ``csv_path`` with a header, in ``_CSV_FIELDS`` order."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)


def _write_null_vectors(null_dir: str) -> None:
    """Write all-zero null T5 / CLIP embeddings (empty prompt, used for CFG dropout)."""
    # Must be bf16 to match the training dtype: train.py does not cast these
    # after loading them.
    null_t5 = torch.zeros(1, T5_SEQ, T5_DIM, dtype=torch.bfloat16)
    null_clip = torch.zeros(1, CLIP_DIM, dtype=torch.bfloat16)
    torch.save(null_t5, f"{null_dir}/null_t5.pt")
    torch.save(null_clip, f"{null_dir}/null_clip.pt")


def main(*, n_samples=None, base_dir=None, null_dir=None):
    """Generate the synthetic smoke-test dataset and the null embeddings.

    For each sample this writes a solid-colour 256x256 PNG, a random latent
    tensor, a random T5 embedding and a random CLIP embedding. It then writes
    a ``data.csv`` index pointing at them, and finally all-zero bf16 null
    T5/CLIP embeddings. Existing files at the target paths are overwritten.

    Args:
        n_samples: number of samples to generate; ``None`` means ``N_SAMPLES``.
            Zero produces a header-only ``data.csv``.
        base_dir: dataset root; ``None`` means ``BASE_DIR``. The directories
            ``latents/``, ``t5/``, ``clip/`` and ``images/`` are created
            under it.
        null_dir: directory for ``null_t5.pt`` / ``null_clip.pt``; ``None``
            means ``NULL_DIR``.
    """
    n_samples = N_SAMPLES if n_samples is None else n_samples
    base_dir = BASE_DIR if base_dir is None else base_dir
    null_dir = NULL_DIR if null_dir is None else null_dir

    for sub in ("latents", "t5", "clip", "images"):
        os.makedirs(f"{base_dir}/{sub}", exist_ok=True)
    os.makedirs(null_dir, exist_ok=True)

    rows = [_write_sample(i, base_dir) for i in range(n_samples)]

    csv_path = f"{base_dir}/data.csv"
    _write_csv(rows, csv_path)
    print(f"Wrote {n_samples} samples to {csv_path}")

    _write_null_vectors(null_dir)
    print(f"Wrote null vectors to {null_dir}/")
    print("Done.")
# Script entry point: generate the smoke-test data when run directly.
if __name__ == "__main__":
    main()