# Training configuration (Python module of settings dicts; ~115 lines, 2.4 KiB)
# Dataset settings
dataset = dict(
    type="video_text",  # paired video + text-caption dataset
    transform_name="resize_crop",  # spatial preprocessing applied to frames
    fps_max=24,  # the desired fps for training
    vmaf=True,  # load vmaf scores into text
)
grad_ckpt_settings = (8, 100)  # set the grad checkpoint settings

# Resolution buckets: {resolution: {num_frames: (probability, batch_size)}}
# NOTE(review): tuple semantics assumed from the comment style of similar
# bucket configs — confirm against the dataloader that consumes this.
bucket_config = {
    "256px": {1: (1.0, 50)},
    "768px": {1: (0.5, 11)},
    "1024px": {1: (0.5, 7)},
}

# Define model components
model = dict(
    type="flux",
    from_pretrained=None,  # no checkpoint: initialize weights from scratch
    strict_load=False,
    guidance_embed=False,
    fused_qkv=False,
    use_liger_rope=True,
    grad_ckpt_settings=grad_ckpt_settings,
    # model architecture
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,  # matches the T5-XXL text-embedding width
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,  # number of double-stream blocks
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],  # per-axis rotary embedding dims (sum = head dim)
    theta=10_000,  # presumably the RoPE base frequency — confirm in model code
    qkv_bias=True,
)

dropout_ratio = {  # probability for dropout text embedding
    "t5": 0.31622777,
    "clip": 0.31622777,
}

# Autoencoder used to map pixels <-> latents
ae = dict(
    type="hunyuan_vae",
    from_pretrained="./ckpts/hunyuan_vae.safetensors",
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    latent_channels=16,
    use_spatial_tiling=True,  # tile spatially to bound VAE memory
    use_temporal_tiling=False,
)
is_causal_vae = True

# Text encoders
t5 = dict(
    type="text_embedder",
    from_pretrained="google/t5-v1_1-xxl",
    cache_dir="/mnt/ddn/sora/tmp_load/huggingface/hub/",
    max_length=512,  # token budget for T5 captions
    shardformer=True,
)
clip = dict(
    type="text_embedder",
    from_pretrained="openai/clip-vit-large-patch14",
    cache_dir="/mnt/ddn/sora/tmp_load/huggingface/hub/",
    max_length=77,  # CLIP's fixed context length
)
# Optimization settings
lr = 1e-5
eps = 1e-15
optim = dict(
    cls="HybridAdam",  # ColossalAI CPU/GPU hybrid Adam — confirm against trainer
    lr=lr,
    eps=eps,
    weight_decay=0.0,
    adamw_mode=True,  # decoupled weight decay (AdamW behavior)
)
warmup_steps = 0
update_warmup_steps = True

grad_clip = 1.0  # global gradient-norm clip
accumulation_steps = 1
ema_decay = None  # EMA of weights disabled
# Acceleration settings
prefetch_factor = 2
num_workers = 12  # dataloader worker processes
num_bucket_build_workers = 64
dtype = "bf16"
plugin = "zero2"  # ZeRO stage-2 sharding plugin
grad_checkpoint = True
plugin_config = dict(
    reduce_bucket_size_in_m=128,
    overlap_allgather=False,
)
# Pre-allocated pinned-memory buffer sizes (bytes): 24 large + 4 small slots.
pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [
    (34 + 20) * 1024 * 1024
] * 4
async_io = False
# Other settings
seed = 42
outputs = "outputs"  # output directory for checkpoints/logs
epochs = 1000
log_every = 10  # steps between log lines
ckpt_every = 100  # steps between checkpoints
keep_n_latest = 20  # prune older checkpoints beyond this count
wandb_project = "mmdit"

save_master_weights = True
load_master_weights = True

# For debugging
# record_time = True
# record_barrier = True