fix: provide real image files for video key; set load_original_video=True
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
2cb9a1f29e
commit
7077ac7129
|
|
@ -10,7 +10,7 @@ dataset = dict(
|
||||||
data_path="/data/train-input/smoke_test/data.csv",
|
data_path="/data/train-input/smoke_test/data.csv",
|
||||||
cached_video=True,
|
cached_video=True,
|
||||||
cached_text=True,
|
cached_text=True,
|
||||||
load_original_video=False,
|
load_original_video=True,
|
||||||
memory_efficient=False,
|
memory_efficient=False,
|
||||||
vmaf=False,
|
vmaf=False,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ Shape 说明(256px 单帧,AE_SPATIAL_COMPRESSION=16,patch_size=2):
|
||||||
import os
|
import os
|
||||||
import csv
|
import csv
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
N_SAMPLES = 4 # 生成 4 条假样本
|
N_SAMPLES = 4 # 生成 4 条假样本
|
||||||
BASE_DIR = "/data/train-input/smoke_test"
|
BASE_DIR = "/data/train-input/smoke_test"
|
||||||
|
|
@ -35,21 +36,27 @@ def main():
|
||||||
os.makedirs(f"{BASE_DIR}/latents", exist_ok=True)
|
os.makedirs(f"{BASE_DIR}/latents", exist_ok=True)
|
||||||
os.makedirs(f"{BASE_DIR}/t5", exist_ok=True)
|
os.makedirs(f"{BASE_DIR}/t5", exist_ok=True)
|
||||||
os.makedirs(f"{BASE_DIR}/clip", exist_ok=True)
|
os.makedirs(f"{BASE_DIR}/clip", exist_ok=True)
|
||||||
|
os.makedirs(f"{BASE_DIR}/images", exist_ok=True)
|
||||||
os.makedirs(NULL_DIR, exist_ok=True)
|
os.makedirs(NULL_DIR, exist_ok=True)
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
for i in range(N_SAMPLES):
|
for i in range(N_SAMPLES):
|
||||||
|
img_path = f"{BASE_DIR}/images/sample_{i}.png"
|
||||||
lat_path = f"{BASE_DIR}/latents/sample_{i}.pt"
|
lat_path = f"{BASE_DIR}/latents/sample_{i}.pt"
|
||||||
t5_path = f"{BASE_DIR}/t5/sample_{i}.pt"
|
t5_path = f"{BASE_DIR}/t5/sample_{i}.pt"
|
||||||
clip_path = f"{BASE_DIR}/clip/sample_{i}.pt"
|
clip_path = f"{BASE_DIR}/clip/sample_{i}.pt"
|
||||||
|
|
||||||
|
# 真实的 256×256 图片(用于 load_original_video=True 时的原始帧加载)
|
||||||
|
img = Image.new("RGB", (256, 256), color=(i * 60 % 256, 128, 200))
|
||||||
|
img.save(img_path)
|
||||||
|
|
||||||
torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
|
torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
|
||||||
torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
|
torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
|
||||||
torch.save(torch.randn(CLIP_DIM), clip_path)
|
torch.save(torch.randn(CLIP_DIM), clip_path)
|
||||||
|
|
||||||
rows.append({
|
rows.append({
|
||||||
"id": i,
|
"id": i,
|
||||||
"path": lat_path, # 占位(cached 模式不读原始视频)
|
"path": img_path, # 真实图片路径(dataset 读取原始帧用)
|
||||||
"text": f"smoke test sample {i}",
|
"text": f"smoke test sample {i}",
|
||||||
"num_frames": 1, # 单帧图像
|
"num_frames": 1, # 单帧图像
|
||||||
"height": 256,
|
"height": 256,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue