89 lines
3.2 KiB
Python
89 lines
3.2 KiB
Python
"""
生成 smoke test 所需的合成数据(fake latents)。

数据布局:
    /data/train-input/smoke_test/
        data.csv
        images/sample_{i}.png     # 原始帧,256×256 RGB
        latents/sample_{i}.pt     # video latent [16, 1, 32, 32]
        t5/sample_{i}.pt          # T5 embedding [256, 4096]
        clip/sample_{i}.pt        # CLIP embedding [768]

    /mnt/ddn/sora/tmp_load/
        null_t5.pt                # null T5 [1, 256, 4096], bf16
        null_clip.pt              # null CLIP [1, 768], bf16

Shape 说明(256px 单帧,AE_SPATIAL_COMPRESSION=16,patch_size=2):
    VAE latent: [C=16, T=1, H_lat=32, W_lat=32]
        32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16
    T5: [seq=256, dim=4096]
    CLIP pooled: [dim=768]
"""
|
||
|
||
import os
|
||
import csv
|
||
import torch
|
||
from PIL import Image
|
||
|
||
# How many fake samples to generate.
N_SAMPLES = 4

# Output locations: per-sample data, and the shared null (empty-prompt) vectors.
BASE_DIR = "/data/train-input/smoke_test"
NULL_DIR = "/mnt/ddn/sora/tmp_load"

# Video latent shape, laid out as [C, T, H, W].
LAT_C = 16
LAT_T = 1
LAT_H = 32
LAT_W = 32

# T5 embedding shape, laid out as [sequence length, hidden size].
T5_SEQ = 256
T5_DIM = 4096

# Size of the pooled CLIP embedding.
CLIP_DIM = 768
|
||
|
||
# Column order of data.csv. The dict built by _write_sample must use exactly
# these keys; csv.DictWriter raises ValueError on any extra key.
_CSV_FIELDS = (
    "id",
    "path",
    "text",
    "num_frames",
    "height",
    "width",
    "fps",
    "latents_path",
    "text_t5_path",
    "text_clip_path",
)


def _write_sample(i: int, generator: torch.Generator) -> dict:
    """Write one fake sample's image, latent and text embeddings; return its CSV row.

    Args:
        i: Sample index. It is used in file names, in the caption, and to vary
            the image colour.
        generator: Seeded RNG, so the random tensors are identical across runs.

    Returns:
        The metadata row for this sample, keyed by ``_CSV_FIELDS``.
    """
    img_path = f"{BASE_DIR}/images/sample_{i}.png"
    lat_path = f"{BASE_DIR}/latents/sample_{i}.pt"
    t5_path = f"{BASE_DIR}/t5/sample_{i}.pt"
    clip_path = f"{BASE_DIR}/clip/sample_{i}.pt"

    # A real 256x256 image, used when the dataset loads original frames
    # (load_original_video=True).
    Image.new("RGB", (256, 256), color=(i * 60 % 256, 128, 200)).save(img_path)

    # NOTE(review): these tensors are float32, while the null vectors are bf16
    # because train.py does not convert their dtype after loading. Confirm the
    # dataset casts per-sample tensors; otherwise save these as bf16 too.
    torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W, generator=generator), lat_path)
    torch.save(torch.randn(T5_SEQ, T5_DIM, generator=generator), t5_path)
    torch.save(torch.randn(CLIP_DIM, generator=generator), clip_path)

    return {
        "id": i,
        "path": img_path,  # real image path, read by the dataset for raw frames
        "text": f"smoke test sample {i}",
        "num_frames": 1,  # single-frame image
        "height": 256,
        "width": 256,
        "fps": 24.0,
        "latents_path": lat_path,
        "text_t5_path": t5_path,
        "text_clip_path": clip_path,
    }


def _write_csv(rows: list, csv_path: str) -> None:
    """Write the sample metadata rows, with a header, to ``csv_path``.

    Explicit field names mean an empty ``rows`` still yields a valid header-only
    file. Indexing ``rows[0]`` for the header would raise IndexError instead.
    """
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)


def _write_null_vectors() -> None:
    """Save the all-zero empty-prompt embeddings used for cfg dropout."""
    # Must be bf16 to match the training dtype: train.py does not convert the
    # dtype after loading these.
    null_t5 = torch.zeros(1, T5_SEQ, T5_DIM, dtype=torch.bfloat16)
    null_clip = torch.zeros(1, CLIP_DIM, dtype=torch.bfloat16)
    torch.save(null_t5, f"{NULL_DIR}/null_t5.pt")
    torch.save(null_clip, f"{NULL_DIR}/null_clip.pt")


def main(seed: int = 0) -> None:
    """Generate the synthetic smoke-test dataset and the null text embeddings.

    Writes ``N_SAMPLES`` samples (image, latent, T5 and CLIP tensors) plus
    ``data.csv`` under ``BASE_DIR``, and the null T5/CLIP vectors under ``NULL_DIR``.

    Args:
        seed: Seed for the random latents and embeddings, so repeated runs
            produce byte-identical tensors.
    """
    for sub in ("latents", "t5", "clip", "images"):
        os.makedirs(f"{BASE_DIR}/{sub}", exist_ok=True)
    os.makedirs(NULL_DIR, exist_ok=True)

    # A local generator keeps the data reproducible without touching torch's
    # global RNG state.
    generator = torch.Generator().manual_seed(seed)
    rows = [_write_sample(i, generator) for i in range(N_SAMPLES)]

    csv_path = f"{BASE_DIR}/data.csv"
    _write_csv(rows, csv_path)
    print(f"Wrote {len(rows)} samples to {csv_path}")

    _write_null_vectors()
    print(f"Wrote null vectors to {NULL_DIR}/")

    print("Done.")


if __name__ == "__main__":
    main()
|