mysora/configs/diffusion/train/smoke_test.py

78 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

_base_ = ["image.py"]
# ===== Smoke-test config =====
# cached_video + cached_text: skip the VAE / T5 / CLIP encoders and feed the
# precomputed latents / text embeddings straight into training.
# Goal: a single-GPU run of a handful of steps that only proves the training
# loop works end to end.
# NOTE(review): nothing in this file caps the run at a fixed number of steps
# (only `epochs = 1` is set); presumably data.csv is tiny enough that one
# epoch is ~3 steps - confirm.

# --- Dataset ---
# NOTE(review): load_original_video=True also reads the raw videos although
# latents are cached; confirm the "cached_video_text" dataset needs them,
# otherwise this is avoidable I/O for a smoke test.
dataset = {
    "type": "cached_video_text",
    "data_path": "/data/train-input/smoke_test/data.csv",
    "cached_video": True,
    "cached_text": True,
    "load_original_video": True,
    "memory_efficient": False,
    "vmaf": False,
}

# --- Buckets: 256px, single frame, batch_size=1 ---
# "_delete_" is the config-inheritance marker for replacing (rather than
# merging with) the bucket_config inherited from image.py, leaving this as the
# only bucket. The entry maps num_frames=1 -> (presumably sampling
# probability, batch_size).
bucket_config = {
    "_delete_": True,
    "256px": {1: (1.0, 1)},
}

# --- Skip loading the VAE / T5 / CLIP encoders ---
cached_video = cached_text = True

# --- Null-vector paths (embeddings of the empty prompt, used for dropout) ---
# In cached_text mode train.py reads these two paths hard-coded, so they are
# not configured here; fake null vectors are placed under
# /mnt/ddn/sora/tmp_load/.
# --- Model: random init; no pretrained checkpoint is loaded ---
# No "_delete_" key, so these keys override the model config from image.py
# rather than replacing it.
# NOTE(review): grad_ckpt_settings=(1, 100) is kept even though
# `grad_checkpoint = False` is set later in this file to turn activation
# checkpointing off; if the model applies checkpointing from
# grad_ckpt_settings alone, it is still partly enabled - confirm which key
# train.py / the model actually honour.
model = {
    "from_pretrained": None,
    "strict_load": False,
    "grad_ckpt_settings": (1, 100),
}

# --- Optimizer ---
# NOTE(review): presumably ColossalAI's HybridAdam; adamw_mode=True selects
# decoupled (AdamW-style) weight decay, which has no effect at
# weight_decay=0.0.
lr = 1e-5
optim = {
    "cls": "HybridAdam",
    "lr": lr,
    "eps": 1e-15,
    "weight_decay": 0.0,
    "adamw_mode": True,
}
# --- Training schedule ---
epochs = 1
ckpt_every = 0  # intended: never save a checkpoint (relies on 0 meaning "off")
log_every = 1
warmup_steps = 0
grad_clip = 1.0
accumulation_steps = 1
ema_decay = None  # intended: no EMA model

# --- Acceleration ---
dtype = "bf16"
plugin = "zero2"
plugin_config = {
    "reduce_bucket_size_in_m": 128,
    "overlap_allgather": False,
}
# Disable the pinned-memory cache pre-allocation.
pin_memory_cache_pre_alloc_numels = None

# --- Misc ---
# Load data in the main process. If these are forwarded to
# torch.utils.data.DataLoader (presumably), prefetch_factor must be None
# whenever num_workers == 0, otherwise DataLoader raises.
num_workers = 0
prefetch_factor = None
num_bucket_build_workers = 1
seed = 42
outputs = "/data/train-output/smoke_test_outputs"
# Turn activation checkpointing off to keep the smoke test simple.
# NOTE(review): model["grad_ckpt_settings"] is still (1, 100) in this file;
# confirm this flag alone really disables checkpointing.
grad_checkpoint = False
# The dropout ratio is left as inherited from the base config (the original
# note read "0.31..."); the null vectors it relies on live under
# /mnt/ddn/sora/tmp_load/.