_base_ = ["image.py"]

# ===== Smoke-test configuration =====
# cached_video + cached_text: skip VAE / T5 / CLIP and feed precomputed latents.
# Single GPU, a handful of training steps — only verifies the training loop wiring.

# --- Dataset: everything served from cached tensors ---
dataset = dict(
    type="cached_video_text",
    data_path="/data/train-input/smoke_test/data.csv",
    cached_video=True,
    cached_text=True,
    load_original_video=False,
    memory_efficient=False,
    vmaf=False,
)

# --- Bucketing: one 256px single-frame bucket, batch size 1 ---
bucket_config = {
    "_delete_": True,  # drop buckets inherited from the base config
    "256px": {1: (1.0, 1)},
}

# --- Skip loading the VAE / T5 / CLIP weights entirely ---
cached_video = True
cached_text = True

# --- Null-vector paths (empty-prompt embeddings used for cfg dropout) ---
# In cached_text mode train.py hardcodes reads from /mnt/ddn/sora/tmp_load/,
# where fake null vectors are placed by the data-generation script.

# --- Model: random init, no pretrained checkpoint ---
model = dict(
    from_pretrained=None,
    strict_load=False,
    grad_ckpt_settings=(1, 100),
)

# --- Optimizer ---
lr = 1e-5
optim = dict(
    cls="HybridAdam",
    lr=lr,
    eps=1e-15,
    weight_decay=0.0,
    adamw_mode=True,
)

# --- Training schedule ---
epochs = 1
ckpt_every = 0  # never write checkpoints
log_every = 1
warmup_steps = 0
grad_clip = 1.0
accumulation_steps = 1
ema_decay = None  # EMA disabled

# --- Acceleration / parallelism ---
dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
    reduce_bucket_size_in_m=128,
    overlap_allgather=False,
)
pin_memory_cache_pre_alloc_numels = None  # disable pinned-memory pre-allocation

# --- Misc ---
num_workers = 0
prefetch_factor = None
num_bucket_build_workers = 1
seed = 42
outputs = "/data/train-output/smoke_test_outputs"
grad_checkpoint = False  # activation checkpointing off to reduce moving parts

# Prompt-dropout ratio is inherited unchanged (0.31...); null vectors live
# under /mnt/ddn/sora/tmp_load/.
"""Generate the synthetic data (fake latents) required by the smoke test.

Data layout:
    <base_dir>/  (default: /data/train-input/smoke_test)
        data.csv
        latents/sample_{i}.pt   # video latent [16, 1, 32, 32]
        t5/sample_{i}.pt        # T5 embedding [256, 4096]
        clip/sample_{i}.pt      # CLIP embedding [768]

    <null_dir>/  (default: /mnt/ddn/sora/tmp_load)
        null_t5.pt              # null T5 [1, 256, 4096]
        null_clip.pt            # null CLIP [1, 768]

Shape notes (256px single frame, AE_SPATIAL_COMPRESSION=16, patch_size=2):
    VAE latent: [C=16, T=1, H_lat=32, W_lat=32]
        32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16
    T5: [seq=256, dim=4096]
    CLIP pooled: [dim=768]
"""

import csv
import os

import torch

N_SAMPLES = 4  # number of fake samples to generate
BASE_DIR = "/data/train-input/smoke_test"
NULL_DIR = "/mnt/ddn/sora/tmp_load"

LAT_C, LAT_T, LAT_H, LAT_W = 16, 1, 32, 32  # video latent shape
T5_SEQ, T5_DIM = 256, 4096  # T5 embedding shape
CLIP_DIM = 768  # CLIP pooled dim


def main(base_dir=BASE_DIR, null_dir=NULL_DIR, n_samples=N_SAMPLES):
    """Write fake latent/embedding tensors, a manifest CSV, and null vectors.

    Args:
        base_dir: destination root for the latents/t5/clip subdirs and data.csv.
        null_dir: destination for null_t5.pt / null_clip.pt (empty-prompt
            embeddings used for cfg dropout; the trainer reads them from a
            hardcoded path, so keep the default in production).
        n_samples: number of fake samples to generate.

    Raises:
        ValueError: if n_samples < 1 (there would be no row to derive the
            CSV header from, and an empty manifest is useless anyway).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    for sub in ("latents", "t5", "clip"):
        os.makedirs(os.path.join(base_dir, sub), exist_ok=True)
    os.makedirs(null_dir, exist_ok=True)

    rows = []
    for i in range(n_samples):
        lat_path = os.path.join(base_dir, "latents", f"sample_{i}.pt")
        t5_path = os.path.join(base_dir, "t5", f"sample_{i}.pt")
        clip_path = os.path.join(base_dir, "clip", f"sample_{i}.pt")

        torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
        torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
        torch.save(torch.randn(CLIP_DIM), clip_path)

        rows.append({
            "path": lat_path,  # placeholder (cached mode never reads the raw video)
            "text": f"smoke test sample {i}",
            "latents_path": lat_path,
            "text_t5_path": t5_path,
            "text_clip_path": clip_path,
        })

    csv_path = os.path.join(base_dir, "data.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote {n_samples} samples to {csv_path}")

    # Null vectors: zero embeddings of the empty prompt, used for cfg dropout.
    torch.save(torch.zeros(1, T5_SEQ, T5_DIM), os.path.join(null_dir, "null_t5.pt"))
    torch.save(torch.zeros(1, CLIP_DIM), os.path.join(null_dir, "null_clip.pt"))
    print(f"Wrote null vectors to {null_dir}/")

    print("Done.")


if __name__ == "__main__":
    main()