""" 生成 smoke test 所需的合成数据(fake latents)。 数据布局: /data/train-input/smoke_test/ data.csv latents/sample_{i}.pt # video latent [16, 1, 32, 32] t5/sample_{i}.pt # T5 embedding [256, 4096] clip/sample_{i}.pt # CLIP embedding [768] /mnt/ddn/sora/tmp_load/ null_t5.pt # null T5 [1, 256, 4096] null_clip.pt # null CLIP [1, 768] Shape 说明(256px 单帧,AE_SPATIAL_COMPRESSION=16,patch_size=2): VAE latent: [C=16, T=1, H_lat=32, W_lat=32] 32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16 T5: [seq=256, dim=4096] CLIP pooled: [dim=768] """ import os import csv import torch from PIL import Image N_SAMPLES = 4 # 生成 4 条假样本 BASE_DIR = "/data/train-input/smoke_test" NULL_DIR = "/mnt/ddn/sora/tmp_load" LAT_C, LAT_T, LAT_H, LAT_W = 16, 1, 32, 32 # video latent shape T5_SEQ, T5_DIM = 256, 4096 # T5 embedding CLIP_DIM = 768 # CLIP pooled def main(): os.makedirs(f"{BASE_DIR}/latents", exist_ok=True) os.makedirs(f"{BASE_DIR}/t5", exist_ok=True) os.makedirs(f"{BASE_DIR}/clip", exist_ok=True) os.makedirs(f"{BASE_DIR}/images", exist_ok=True) os.makedirs(NULL_DIR, exist_ok=True) rows = [] for i in range(N_SAMPLES): img_path = f"{BASE_DIR}/images/sample_{i}.png" lat_path = f"{BASE_DIR}/latents/sample_{i}.pt" t5_path = f"{BASE_DIR}/t5/sample_{i}.pt" clip_path = f"{BASE_DIR}/clip/sample_{i}.pt" # 真实的 256×256 图片(用于 load_original_video=True 时的原始帧加载) img = Image.new("RGB", (256, 256), color=(i * 60 % 256, 128, 200)) img.save(img_path) torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path) torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path) torch.save(torch.randn(CLIP_DIM), clip_path) rows.append({ "id": i, "path": img_path, # 真实图片路径(dataset 读取原始帧用) "text": f"smoke test sample {i}", "num_frames": 1, # 单帧图像 "height": 256, "width": 256, "fps": 24.0, "latents_path": lat_path, "text_t5_path": t5_path, "text_clip_path": clip_path, }) csv_path = f"{BASE_DIR}/data.csv" with open(csv_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=rows[0].keys()) writer.writeheader() writer.writerows(rows) print(f"Wrote {N_SAMPLES} samples to {csv_path}") # null vectors(空提示的 embedding,用于 cfg dropout) null_t5 = torch.zeros(1, T5_SEQ, T5_DIM) null_clip = torch.zeros(1, CLIP_DIM) torch.save(null_t5, f"{NULL_DIR}/null_t5.pt") torch.save(null_clip, f"{NULL_DIR}/null_clip.pt") print(f"Wrote null vectors to {NULL_DIR}/") print("Done.") if __name__ == "__main__": main()