feat: add smoke test config and synthetic data generator
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
bdeb2870d4
commit
ae3e448c8a
|
|
@ -0,0 +1,77 @@
|
||||||
|
_base_ = ["image.py"]

# ===== Smoke Test Config =====
# cached_video + cached_text: skip VAE / T5 / CLIP and feed precomputed
# latents / embeddings directly.
# Intended for 1 GPU and 3 training steps, just to verify that the training
# loop runs end to end.
# NOTE(review): no explicit step limit is set in this file; confirm where the
# 3-step cap comes from (base config, CLI override, or dataset size).

# --- Dataset ---
dataset = dict(
    type="cached_video_text",
    data_path="/data/train-input/smoke_test/data.csv",
    cached_video=True,
    cached_text=True,
    load_original_video=False,
    memory_efficient=False,
    vmaf=False,
)

# --- Bucket: 256px, single frame, batch_size=1 ---
bucket_config = {
    # Presumably tells the config loader to drop the inherited bucket_config
    # instead of merging into it — confirm against the config loader in use.
    "_delete_": True,
    "256px": {
        1: (1.0, 1),
    },
}

# --- Skip loading VAE / T5 / CLIP ---
cached_video = True
cached_text = True

# --- Null-vector paths (empty-prompt embeddings, used for dropout) ---
# In cached_text mode train.py reads these two paths from hard-coded
# locations, so nothing is configured here; fake null vectors are placed
# under /mnt/ddn/sora/tmp_load/.

# --- Model: random initialization, no pretrained weights ---
model = dict(
    from_pretrained=None,
    strict_load=False,
    grad_ckpt_settings=(1, 100),
)

# --- Optimizer ---
lr = 1e-5
optim = dict(
    cls="HybridAdam",
    lr=lr,
    eps=1e-15,
    weight_decay=0.0,
    adamw_mode=True,
)

# --- Training parameters ---
epochs = 1
ckpt_every = 0  # do not save checkpoints
log_every = 1
warmup_steps = 0
grad_clip = 1.0
accumulation_steps = 1
ema_decay = None  # no EMA

# --- Acceleration ---
dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
    reduce_bucket_size_in_m=128,
    overlap_allgather=False,
)
pin_memory_cache_pre_alloc_numels = None  # disable pinned-memory pre-allocation

# --- Misc ---
num_workers = 0
prefetch_factor = None
num_bucket_build_workers = 1
seed = 42
outputs = "/data/train-output/smoke_test_outputs"
grad_checkpoint = False  # disable activation checkpointing to reduce complexity

# The dropout ratio is left as inherited (0.31...); the null vectors live
# under /mnt/ddn/sora/tmp_load/.
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
"""
Generate the synthetic data (fake latents) needed by the smoke test.

Data layout:
    /data/train-input/smoke_test/
        data.csv
        latents/sample_{i}.pt   # video latent   [16, 1, 32, 32]
        t5/sample_{i}.pt        # T5 embedding   [256, 4096]
        clip/sample_{i}.pt      # CLIP embedding [768]

    /mnt/ddn/sora/tmp_load/
        null_t5.pt              # null T5   [1, 256, 4096]
        null_clip.pt            # null CLIP [1, 768]

Shape notes (256px single frame, AE_SPATIAL_COMPRESSION=16, patch_size=2):
    VAE latent:  [C=16, T=1, H_lat=32, W_lat=32]
        32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16
    T5:          [seq=256, dim=4096]
    CLIP pooled: [dim=768]
"""

import os
import csv
import torch

N_SAMPLES = 4  # number of fake samples to generate
BASE_DIR = "/data/train-input/smoke_test"
NULL_DIR = "/mnt/ddn/sora/tmp_load"

LAT_C, LAT_T, LAT_H, LAT_W = 16, 1, 32, 32  # video latent shape
T5_SEQ, T5_DIM = 256, 4096  # T5 embedding
CLIP_DIM = 768  # CLIP pooled

# Column order of data.csv. Kept explicit so the header can still be written
# when no samples are generated.
_CSV_FIELDS = ["path", "text", "latents_path", "text_t5_path", "text_clip_path"]


def main(
    base_dir: str = BASE_DIR,
    null_dir: str = NULL_DIR,
    n_samples: int = N_SAMPLES,
) -> None:
    """Write random sample tensors, their CSV index, and zero null vectors.

    Args:
        base_dir: Directory receiving ``data.csv`` plus the ``latents/``,
            ``t5/`` and ``clip/`` sub-directories. Defaults to ``BASE_DIR``.
        null_dir: Directory receiving ``null_t5.pt`` and ``null_clip.pt``.
            Defaults to ``NULL_DIR``.
        n_samples: Number of samples to generate. ``0`` yields a CSV that
            contains only the header row. Defaults to ``N_SAMPLES``.

    Raises:
        OSError: If a directory cannot be created or a file cannot be written.
    """
    for sub_dir in ("latents", "t5", "clip"):
        os.makedirs(f"{base_dir}/{sub_dir}", exist_ok=True)
    os.makedirs(null_dir, exist_ok=True)

    rows = []
    for i in range(n_samples):
        lat_path = f"{base_dir}/latents/sample_{i}.pt"
        t5_path = f"{base_dir}/t5/sample_{i}.pt"
        clip_path = f"{base_dir}/clip/sample_{i}.pt"

        torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
        torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
        torch.save(torch.randn(CLIP_DIM), clip_path)

        rows.append({
            "path": lat_path,  # placeholder (cached mode does not read the raw video)
            "text": f"smoke test sample {i}",
            "latents_path": lat_path,
            "text_t5_path": t5_path,
            "text_clip_path": clip_path,
        })

    csv_path = f"{base_dir}/data.csv"
    # newline="" is what the csv module requires of the file object; the
    # explicit encoding keeps the output independent of the platform locale.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote {n_samples} samples to {csv_path}")

    # Null vectors (empty-prompt embeddings, used for cfg dropout).
    null_t5 = torch.zeros(1, T5_SEQ, T5_DIM)
    null_clip = torch.zeros(1, CLIP_DIM)
    torch.save(null_t5, f"{null_dir}/null_t5.pt")
    torch.save(null_clip, f"{null_dir}/null_clip.pt")
    print(f"Wrote null vectors to {null_dir}/")

    print("Done.")


if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in New Issue