feat: add smoke test config and synthetic data generator
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
bdeb2870d4
commit
ae3e448c8a
|
|
@ -0,0 +1,77 @@
|
|||
_base_ = ["image.py"]

# ===== Smoke Test Config =====
# cached_video + cached_text: skip VAE / T5 / CLIP and consume precomputed latents directly.
# Single GPU, a handful of training steps — only verifies the training loop end to end.

# --- Dataset ---
dataset = dict(
    type="cached_video_text",
    data_path="/data/train-input/smoke_test/data.csv",
    cached_video=True,
    cached_text=True,
    load_original_video=False,
    memory_efficient=False,
    vmaf=False,
)

# --- Bucket: 256px single frame, batch_size=1 ---
bucket_config = {
    "_delete_": True,  # drop the bucket table inherited from image.py
    "256px": {
        1: (1.0, 1),  # 1 frame: (keep probability, batch size)
    },
}

# --- Skip loading VAE / T5 / CLIP ---
cached_video = True
cached_text = True

# --- Null-vector paths (embeddings of the empty prompt, used for dropout) ---
# In cached_text mode, train.py hardcodes reads of these two paths;
# we place fake null vectors under /mnt/ddn/sora/tmp_load/.

# --- Model: random init, no pretrained weights ---
model = dict(
    from_pretrained=None,
    strict_load=False,
    grad_ckpt_settings=(1, 100),
)

# --- Optimizer ---
lr = 1e-5
optim = dict(
    cls="HybridAdam",
    lr=lr,
    eps=1e-15,
    weight_decay=0.0,
    adamw_mode=True,
)

# --- Training parameters ---
epochs = 1
ckpt_every = 0  # never save checkpoints
log_every = 1
warmup_steps = 0
grad_clip = 1.0
accumulation_steps = 1
ema_decay = None  # no EMA

# --- Acceleration ---
dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
    reduce_bucket_size_in_m=128,
    overlap_allgather=False,
)
pin_memory_cache_pre_alloc_numels = None  # disable pin-memory pre-allocation

# --- Misc ---
num_workers = 0
prefetch_factor = None
num_bucket_build_workers = 1
seed = 42
outputs = "/data/train-output/smoke_test_outputs"
grad_checkpoint = False  # disable activation checkpointing to reduce complexity

# dropout ratio is kept as-is (0.31...); null vectors live under /mnt/ddn/sora/tmp_load/
|
@ -0,0 +1,75 @@
|
|||
"""Generate the synthetic data (fake latents) needed for the smoke test.

Data layout:
    /data/train-input/smoke_test/
        data.csv
        latents/sample_{i}.pt  # video latent [16, 1, 32, 32]
        t5/sample_{i}.pt       # T5 embedding [256, 4096]
        clip/sample_{i}.pt     # CLIP embedding [768]

    /mnt/ddn/sora/tmp_load/
        null_t5.pt    # null T5 [1, 256, 4096]
        null_clip.pt  # null CLIP [1, 768]

Shape notes (256px single frame, AE_SPATIAL_COMPRESSION=16, patch_size=2):
    VAE latent: [C=16, T=1, H_lat=32, W_lat=32]
        32 = patch_size(2) * ceil(256 / AE_SPATIAL_COMPRESSION(16)) = 2 * 16
    T5:          [seq=256, dim=4096]
    CLIP pooled: [dim=768]
"""

import os
import csv
import torch

N_SAMPLES = 4  # number of fake samples to generate
BASE_DIR = "/data/train-input/smoke_test"
NULL_DIR = "/mnt/ddn/sora/tmp_load"

LAT_C, LAT_T, LAT_H, LAT_W = 16, 1, 32, 32  # video latent shape
T5_SEQ, T5_DIM = 256, 4096                  # T5 embedding
CLIP_DIM = 768                              # CLIP pooled

# CSV columns written to data.csv.  Declared explicitly (instead of
# rows[0].keys()) so a zero-sample run still emits a valid header
# rather than raising IndexError.
_FIELDNAMES = ("path", "text", "latents_path", "text_t5_path", "text_clip_path")


def main(base_dir: str = BASE_DIR, null_dir: str = NULL_DIR,
         n_samples: int = N_SAMPLES) -> None:
    """Write ``n_samples`` fake latent/T5/CLIP tensors plus ``data.csv``
    under ``base_dir``, and the null (empty-prompt) embeddings under
    ``null_dir``.

    All arguments default to the hard-coded smoke-test locations, but can
    be overridden (e.g. to point at a temporary directory).
    """
    for sub in ("latents", "t5", "clip"):
        os.makedirs(os.path.join(base_dir, sub), exist_ok=True)
    os.makedirs(null_dir, exist_ok=True)

    rows = []
    for i in range(n_samples):
        lat_path = f"{base_dir}/latents/sample_{i}.pt"
        t5_path = f"{base_dir}/t5/sample_{i}.pt"
        clip_path = f"{base_dir}/clip/sample_{i}.pt"

        torch.save(torch.randn(LAT_C, LAT_T, LAT_H, LAT_W), lat_path)
        torch.save(torch.randn(T5_SEQ, T5_DIM), t5_path)
        torch.save(torch.randn(CLIP_DIM), clip_path)

        rows.append({
            "path": lat_path,  # placeholder (cached mode never reads the original video)
            "text": f"smoke test sample {i}",
            "latents_path": lat_path,
            "text_t5_path": t5_path,
            "text_clip_path": clip_path,
        })

    csv_path = f"{base_dir}/data.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Wrote {n_samples} samples to {csv_path}")

    # Null vectors: embeddings of the empty prompt, used for cfg dropout.
    torch.save(torch.zeros(1, T5_SEQ, T5_DIM), f"{null_dir}/null_t5.pt")
    torch.save(torch.zeros(1, CLIP_DIM), f"{null_dir}/null_clip.pt")
    print(f"Wrote null vectors to {null_dir}/")

    print("Done.")


if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue