# Source listing metadata (from the file viewer, not part of the config): Python, 78 lines, 1.8 KiB
# Inherit every setting from the base image config; the values below override it.
_base_ = ["image.py"]

# ===== Smoke Test Config =====
# cached_video + cached_text: skip VAE / T5 / CLIP and use precomputed latents directly.
# Uses only 1 GPU and 3 training steps to verify that the training loop runs end to end.
# NOTE(review): nothing in this file limits training to 3 steps — presumably that comes
# from the size of the smoke-test dataset or a launcher argument; confirm.

# --- Dataset ---
# Reads the smoke-test CSV and uses cached (precomputed) video and text features.
dataset = dict(
    type="cached_video_text",
    data_path="/data/train-input/smoke_test/data.csv",
    cached_video=True,
    cached_text=True,
    load_original_video=True,
    memory_efficient=False,
    vmaf=False,
)

# --- Bucket: 256px, single frame, batch_size=1 ---
bucket_config = {
    # Presumably the mmengine-style marker that replaces the inherited
    # bucket_config instead of merging into it — confirm against the config loader.
    "_delete_": True,
    "256px": {
        # Key 1 = single frame; the tuple's last element is the batch size (1).
        # NOTE(review): the first element (1.0) looks like a keep probability — confirm.
        1: (1.0, 1),
    },
}

# --- Skip loading VAE / T5 / CLIP ---
cached_video = True
cached_text = True

# --- Null vector paths (embeddings of the empty prompt, used for dropout) ---
# In cached_text mode, train.py reads these two paths from hardcoded locations.
# Fake null vectors are placed under /mnt/ddn/sora/tmp_load/.

# --- Model: random initialization, no pretrained weights loaded ---
model = dict(
    from_pretrained=None,  # no checkpoint path, so weights stay randomly initialized
    strict_load=False,
    grad_ckpt_settings=(1, 100),
)

# --- Optimizer ---
lr = 1e-5
optim = dict(
    cls="HybridAdam",
    lr=lr,  # reuse the module-level learning rate defined just above
    eps=1e-15,
    weight_decay=0.0,
    adamw_mode=True,
)

# --- Training parameters ---
epochs = 1
ckpt_every = 0  # do not save checkpoints
log_every = 1
warmup_steps = 0
grad_clip = 1.0
accumulation_steps = 1
ema_decay = None  # do not use EMA

# --- Acceleration ---
dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
    reduce_bucket_size_in_m=128,
    overlap_allgather=False,
)
pin_memory_cache_pre_alloc_numels = None  # disable pin-memory pre-allocation

# --- Miscellaneous ---
num_workers = 0
prefetch_factor = None
num_bucket_build_workers = 1
seed = 42
outputs = "/data/train-output/smoke_test_outputs"
grad_checkpoint = False  # turn off activation checkpointing to reduce complexity

# The dropout ratio is left unchanged (0.31...); the null vectors live under /mnt/ddn/sora/tmp_load/.