diff --git a/configs/diffusion/train/smoke_test.py b/configs/diffusion/train/smoke_test.py index 3b8b83b..88ae936 100644 --- a/configs/diffusion/train/smoke_test.py +++ b/configs/diffusion/train/smoke_test.py @@ -10,7 +10,7 @@ dataset = dict( data_path="/data/train-input/smoke_test/data.csv", cached_video=True, cached_text=True, - load_original_video=False, + load_original_video=True, memory_efficient=False, vmaf=False, ) diff --git a/scripts/diffusion/gen_smoke_data.py b/scripts/diffusion/gen_smoke_data.py index b5e36b8..839f8f2 100644 --- a/scripts/diffusion/gen_smoke_data.py +++ b/scripts/diffusion/gen_smoke_data.py @@ -22,6 +22,7 @@ Shape 说明(256px 单帧,AE_SPATIAL_COMPRESSION=16,patch_size=2): import os import csv import torch +from PIL import Image N_SAMPLES = 4 # 生成 4 条假样本 BASE_DIR = "/data/train-input/smoke_test" @@ -35,21 +36,27 @@ def main(): os.makedirs(f"{BASE_DIR}/latents", exist_ok=True) os.makedirs(f"{BASE_DIR}/t5", exist_ok=True) os.makedirs(f"{BASE_DIR}/clip", exist_ok=True) + os.makedirs(f"{BASE_DIR}/images", exist_ok=True) os.makedirs(NULL_DIR, exist_ok=True) rows = [] for i in range(N_SAMPLES): + img_path = f"{BASE_DIR}/images/sample_{i}.png" lat_path = f"{BASE_DIR}/latents/sample_{i}.pt" t5_path = f"{BASE_DIR}/t5/sample_{i}.pt" clip_path = f"{BASE_DIR}/clip/sample_{i}.pt" + # 真实的 256×256 图片(用于 load_original_video=True 时的原始帧加载) + img = Image.new("RGB", (256, 256), color=(i * 60 % 256, 128, 200)) + img.save(img_path) + torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path) torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path) torch.save(torch.randn(CLIP_DIM), clip_path) rows.append({ "id": i, - "path": lat_path, # 占位(cached 模式不读原始视频) + "path": img_path, # 真实图片路径(dataset 读取原始帧用) "text": f"smoke test sample {i}", "num_frames": 1, # 单帧图像 "height": 256,