# mysora/opensora/utils/train.py (459 lines, 20 KiB, Python)
import random
import warnings
from collections import OrderedDict
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from einops import rearrange
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm
from opensora.acceleration.parallel_states import (
set_data_parallel_group,
set_sequence_parallel_group,
set_tensor_parallel_group,
)
from opensora.utils.optimizer import LinearWarmupLR
def set_lr(
    optimizer: torch.optim.Optimizer,
    lr_scheduler: _LRScheduler,
    lr: float,
    initial_lr: float = None,
):
    """Override the learning rate on an optimizer and its warmup scheduler.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose param groups are updated.
        lr_scheduler (_LRScheduler): Scheduler to keep in sync; only
            ``LinearWarmupLR`` instances are touched.
        lr (float): New learning rate applied to every param group.
        initial_lr (float, optional): If given (and the scheduler is a
            ``LinearWarmupLR``), also overrides the scheduler's starting LR.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr
    # Keep the scheduler's notion of the base LR consistent with the optimizer.
    if isinstance(lr_scheduler, LinearWarmupLR):
        lr_scheduler.base_lrs = [lr for _ in lr_scheduler.base_lrs]
        if initial_lr is not None:
            lr_scheduler.initial_lr = initial_lr
def set_warmup_steps(
    lr_scheduler: _LRScheduler,
    warmup_steps: int,
):
    """Override the warmup length of a ``LinearWarmupLR`` scheduler.

    Args:
        lr_scheduler (_LRScheduler): Scheduler to update; other scheduler
            types are left untouched.
        warmup_steps (int): New number of warmup steps.
    """
    if not isinstance(lr_scheduler, LinearWarmupLR):
        return
    lr_scheduler.warmup_steps = warmup_steps
def set_eps(
    optimizer: torch.optim.Optimizer,
    eps: float = None,
):
    """Override the numerical-stability epsilon of every optimizer param group.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer to update.
        eps (float, optional): New epsilon; ``None`` leaves the optimizer unchanged.
    """
    if eps is None:
        return
    for group in optimizer.param_groups:
        group["eps"] = eps
def setup_device() -> tuple[torch.device, DistCoordinator]:
    """
    Setup the device and the distributed coordinator.

    Initializes the NCCL process group and pins each rank to a GPU
    (round-robin over the visible devices).

    Returns:
        tuple[torch.device, DistCoordinator]: The device and the distributed coordinator.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    # NOTE: a very generous timeout avoids killing ranks that finish collective setup late
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    local_device_idx = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_device_idx)
    coordinator = DistCoordinator()
    return get_current_device(), coordinator
def create_colossalai_plugin(
    plugin: str,
    dtype: str,
    grad_clip: float,
    **kwargs,
) -> LowLevelZeroPlugin | HybridParallelPlugin:
    """
    Create a ColossalAI booster plugin and register its process groups.

    Args:
        plugin (str): The plugin name ("zero1", "zero2", or "hybrid").
        dtype (str): The mixed-precision data type string.
        grad_clip (float): The gradient clip value.
        **kwargs: Extra plugin options; these override the defaults below.

    Returns:
        LowLevelZeroPlugin | HybridParallelPlugin: The plugin.

    Raises:
        ValueError: If ``plugin`` is not a recognized name.
    """
    cfg = {
        "precision": dtype,
        "initial_scale": 2**16,
        "max_norm": grad_clip,
        "overlap_allgather": True,
        "cast_inputs": False,
        "reduce_bucket_size_in_m": 20,
    }
    cfg.update(kwargs)
    sp_size = cfg.get("sp_size", 1)
    if plugin in ("zero1", "zero2"):
        assert sp_size == 1, "Zero plugin does not support sequence parallelism"
        plugin = LowLevelZeroPlugin(
            stage=1 if plugin == "zero1" else 2,
            **cfg,
        )
        # ZeRO shards over the whole world: data parallelism only.
        set_data_parallel_group(dist.group.WORLD)
    elif plugin == "hybrid":
        cfg["find_unused_parameters"] = True
        # HybridParallelPlugin names the bucket-size option differently.
        bucket_size = cfg.pop("reduce_bucket_size_in_m")
        cfg.setdefault("zero_bucket_size_in_m", bucket_size)
        cfg.pop("cast_inputs")
        cfg["enable_metadata_cache"] = False
        # A custom policy is passed as a class and instantiated here.
        policy_cls = cfg.pop("custom_policy", None)
        plugin = HybridParallelPlugin(
            custom_policy=policy_cls() if policy_cls is not None else None,
            **cfg,
        )
        set_tensor_parallel_group(plugin.tp_group)
        set_sequence_parallel_group(plugin.sp_group)
        set_data_parallel_group(plugin.dp_group)
    else:
        raise ValueError(f"Unknown plugin {plugin}")
    return plugin
@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
):
    """
    Step the EMA model towards the current model.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``.

    Args:
        ema_model (torch.nn.Module): The EMA model (updated in place).
        model (torch.nn.Module): The current model.
        optimizer (torch.optim.Optimizer): ZeRO optimizer holding fp32 master
            weights; only consulted when ``sharded`` and the working param is
            low-precision.
        decay (float): The decay rate.
        sharded (bool): Whether the model is sharded.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Positional embeddings are fixed and frozen weights never drift.
        if name == "pos_embed" or not param.requires_grad:
            continue
        source = param.data
        if sharded and source.dtype != torch.float32:
            # Low-precision working shard: average against the fp32 master copy instead.
            source = optimizer.get_working_to_master_map()[id(param)].data
        ema_params[name].mul_(decay).add_(source, alpha=1 - decay)
def dropout_condition(prob: float, txt: torch.Tensor, null_txt: torch.Tensor) -> torch.Tensor:
    """
    Randomly replace whole samples of the text condition with the null embedding.

    Used for classifier-free-guidance training: each batch element is
    independently replaced by ``null_txt`` with probability ``prob``.

    Args:
        prob (float): The per-sample dropout probability.
        txt (torch.Tensor): The text tensor, shape [B, ...].
        null_txt (torch.Tensor): The null text tensor, broadcastable to ``txt``.

    Returns:
        torch.Tensor: The text tensor with dropout applied.
    """
    if prob == 0:
        # BUGFIX: the warning used to claim dropout was skipped while the code
        # still drew RNG and rebuilt the tensor; now we genuinely skip.
        warnings.warn("Dropout probability is 0, skipping dropout")
        return txt
    drop_ids = torch.rand(txt.shape[0], device=txt.device) < prob
    # Reshape the mask to [B, 1, ..., 1] so it broadcasts over non-batch dims.
    drop_ids = drop_ids.view((drop_ids.shape[0],) + (1,) * (txt.ndim - 1))
    new_txt = torch.where(drop_ids, null_txt, txt)
    return new_txt
def prepare_visual_condition_uncausal(
    x: torch.Tensor, condition_config: dict, model_ae: torch.nn.Module, pad: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare the visual condition (mask + condition latent) for an uncausal VAE.

    For each batch element a conditioning mode (t2v / i2v_* / v2v_*) is sampled
    from ``condition_config`` (mode name -> sampling weight). Conditioned latent
    frames are flagged in a binary mask, and their condition latents are encoded
    separately from the full video to avoid temporal information leakage.

    Args:
        x: (torch.Tensor): The input video tensor, shape [B, C, T, H, W] (pixel space).
        condition_config (dict): Maps condition names to sampling weights.
            NOTE: mutated in place — v2v entries that need more frames than the
            clip provides are popped.
        model_ae (torch.nn.Module): The video encoder module.
        pad (bool): If True, re-pad the pixel video so that conditioned boundary
            latent frames correspond to a still image before encoding.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``x_0``, the clean video latent, and
        ``cond``, the concatenated [mask, masked latent] condition tensor with
        1 + latent_channels channels.
    """
    # x has shape [b, c, t, h, w], where b is the batch size
    B = x.shape[0]
    C = model_ae.cfg.latent_channels
    T, H, W = model_ae.get_latent_size(x.shape[-3:])
    # Initialize masks tensor to match the latent shape; only the time dimension is masked
    masks = torch.zeros(B, 1, T, H, W).to(
        x.device, x.dtype
    )  # broadcasting over channel; concatenated with the latent into 1 + C channels
    # to prevent information leakage, image must be encoded separately and copied to latent
    latent = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
    x_0 = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
    if T > 1:  # video
        # certain v2v conditions are not applicable for short videos
        if T <= 32 // model_ae.time_compression_ratio:
            condition_config.pop("v2v_head", None)  # given first 32 frames
            condition_config.pop("v2v_tail", None)  # given last 32 frames
            condition_config.pop("v2v_head_easy", None)  # given first 64 frames
            condition_config.pop("v2v_tail_easy", None)  # given last 64 frames
        if T <= 64 // model_ae.time_compression_ratio:
            condition_config.pop("v2v_head_easy", None)  # given first 64 frames
            condition_config.pop("v2v_tail_easy", None)  # given last 64 frames
        mask_cond_options = list(condition_config.keys())  # list of mask conditions
        mask_cond_weights = list(condition_config.values())  # corresponding probabilities
        for i in range(B):
            # Randomly select a mask condition based on the provided probabilities
            mask_cond = random.choices(mask_cond_options, weights=mask_cond_weights, k=1)[0]
            # Apply the selected mask condition directly on the masks tensor
            if mask_cond == "i2v_head":  # NOTE: modify video, mask first latent frame
                # pad the video so that the first latent frame corresponds to the image only
                masks[i, :, 0, :, :] = 1
                if pad:
                    pad_num = model_ae.time_compression_ratio - 1  # 32 --> new video: 7 + (1+31-7)
                    # repeat the first frame pad_num times and drop the same number from the end
                    padded_x = torch.cat([x[i, :, :1]] * pad_num + [x[i, :, :-pad_num]], dim=1).unsqueeze(0)
                    x_0[i] = model_ae.encode(padded_x)[0]
                else:
                    x_0[i] = model_ae.encode(x[i : i + 1])[0]
                # condition: encode the image only
                latent[i, :, :1, :, :] = model_ae.encode(
                    x[i, :, :1, :, :].unsqueeze(0)
                )  # since the first dimension of the right hand side is singleton, torch auto-ignores it
            elif mask_cond == "i2v_loop":  # NOTE: modify video, mask first and last latent frame
                # pad video such that first and last latent frame correspond to image only
                masks[i, :, 0, :, :] = 1
                masks[i, :, -1, :, :] = 1
                if pad:
                    pad_num = model_ae.time_compression_ratio - 1
                    # repeat both boundary frames; the last pad_num * 2 frames are
                    # removed from the end of the video to keep the length fixed
                    padded_x = torch.cat(
                        [x[i, :, :1]] * pad_num
                        + [x[i, :, : -pad_num * 2]]
                        + [x[i, :, -pad_num * 2 - 1].unsqueeze(1)] * pad_num,
                        dim=1,
                    ).unsqueeze(
                        0
                    )  # remove the last pad_num * 2 frames from the end of the video
                    x_0[i] = model_ae.encode(padded_x)[0]
                    # condition: encode the image only
                    latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
                    latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -pad_num * 2 - 1, :, :].unsqueeze(1).unsqueeze(0))
                else:
                    x_0[i] = model_ae.encode(x[i : i + 1])[0]
                    latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
                    latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
            elif mask_cond == "i2v_tail":  # mask the last latent frame
                masks[i, :, -1, :, :] = 1
                if pad:
                    pad_num = model_ae.time_compression_ratio - 1
                    # shift the video forward and repeat the last frame pad_num times
                    padded_x = torch.cat([x[i, :, pad_num:]] + [x[i, :, -1:]] * pad_num, dim=1).unsqueeze(0)
                    x_0[i] = model_ae.encode(padded_x)[0]
                    # NOTE(review): the -pad_num * 2 - 1 index looks inherited from the
                    # i2v_loop branch; the tail padding above repeats x[i, :, -1], so -1
                    # may be the intended source frame here — confirm against the sampler.
                    latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -pad_num * 2 - 1, :, :].unsqueeze(1).unsqueeze(0))
                else:
                    x_0[i] = model_ae.encode(x[i : i + 1])[0]
                    latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
            elif mask_cond == "v2v_head":  # conditioned on the first 32 video frames
                assert T > 32 // model_ae.time_compression_ratio
                conditioned_t = 32 // model_ae.time_compression_ratio
                masks[i, :, :conditioned_t, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                # condition latents are copied from the full-video latent (no boundary image)
                latent[i, :, :conditioned_t, :, :] = x_0[i, :, :conditioned_t, :, :]
            elif mask_cond == "v2v_tail":  # conditioned on the last 32 video frames
                assert T > 32 // model_ae.time_compression_ratio
                conditioned_t = 32 // model_ae.time_compression_ratio
                masks[i, :, -conditioned_t:, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                latent[i, :, -conditioned_t:, :, :] = x_0[i, :, -conditioned_t:, :, :]
            elif mask_cond == "v2v_head_easy":  # conditioned on the first 64 video frames
                assert T > 64 // model_ae.time_compression_ratio
                conditioned_t = 64 // model_ae.time_compression_ratio
                masks[i, :, :conditioned_t, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                latent[i, :, :conditioned_t, :, :] = x_0[i, :, :conditioned_t, :, :]
            elif mask_cond == "v2v_tail_easy":  # conditioned on the last 64 video frames
                assert T > 64 // model_ae.time_compression_ratio
                conditioned_t = 64 // model_ae.time_compression_ratio
                masks[i, :, -conditioned_t:, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                latent[i, :, -conditioned_t:, :, :] = x_0[i, :, -conditioned_t:, :, :]
            # elif mask_cond == "v2v_head": # mask from the beginning to a random point
            #     masks[i, :, : random.randint(1, T - 2), :, :] = 1
            # elif mask_cond == "v2v_tail": # mask from a random point to the end
            #     masks[i, :, -random.randint(1, T - 2) :, :, :] = 1
            else:
                # "t2v" is the fallback case where no specific condition is specified
                assert mask_cond == "t2v", f"Unknown mask condition {mask_cond}"
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
    else:  # image
        x_0 = model_ae.encode(x)  # latent image; no visual conditioning applies
    latent = masks * latent  # condition latent: zero out everything outside the mask
    # merge the masks and the masked latent into a single condition tensor
    cond = torch.cat((masks, latent), dim=1)
    return x_0, cond
def prepare_visual_condition_causal(x: torch.Tensor, condition_config: dict, model_ae: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare the visual condition (mask + condition latent) for a causal VAE.

    Causal variant of :func:`prepare_visual_condition_uncausal`: a causal VAE
    maps ``1 + k * time_compression_ratio`` pixel frames to ``1 + k`` latent
    frames, so no pixel-space re-padding is needed — boundary images and frame
    spans can be encoded directly.

    Args:
        x: (torch.Tensor): The input video tensor, shape [B, C, T, H, W] (pixel space).
        condition_config (dict): Maps condition names to sampling weights.
            NOTE: mutated in place — v2v entries that need more frames than the
            clip provides are popped.
        model_ae (torch.nn.Module): The video encoder module.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``x_0``, the clean video latent, and
        ``cond``, the concatenated [mask, masked latent] condition tensor with
        1 + latent_channels channels.
    """
    # x has shape [b, c, t, h, w], where b is the batch size
    B = x.shape[0]
    C = model_ae.cfg.latent_channels
    T, H, W = model_ae.get_latent_size(x.shape[-3:])
    # Initialize masks tensor to match the latent shape; only the time dimension is masked
    masks = torch.zeros(B, 1, T, H, W).to(
        x.device, x.dtype
    )  # broadcasting over channel; concatenated with the latent into 1 + C channels
    # to prevent information leakage, image must be encoded separately and copied to latent
    latent = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
    x_0 = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
    if T > 1:  # video
        # certain v2v conditions are not applicable for short videos
        # (the +1 accounts for the extra leading latent frame of a causal VAE)
        if T <= (32 // model_ae.time_compression_ratio) + 1:
            condition_config.pop("v2v_head", None)  # given first 33 frames
            condition_config.pop("v2v_tail", None)  # given last 33 frames
            condition_config.pop("v2v_head_easy", None)  # given first 65 frames
            condition_config.pop("v2v_tail_easy", None)  # given last 65 frames
        if T <= (64 // model_ae.time_compression_ratio) + 1:
            condition_config.pop("v2v_head_easy", None)  # given first 65 frames
            condition_config.pop("v2v_tail_easy", None)  # given last 65 frames
        mask_cond_options = list(condition_config.keys())  # list of mask conditions
        mask_cond_weights = list(condition_config.values())  # corresponding probabilities
        for i in range(B):
            # Randomly select a mask condition based on the provided probabilities
            mask_cond = random.choices(mask_cond_options, weights=mask_cond_weights, k=1)[0]
            # Apply the selected mask condition directly on the masks tensor
            if mask_cond == "i2v_head":  # NOTE: mask first latent frame
                masks[i, :, 0, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                # condition: encode the image only
                latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
            elif mask_cond == "i2v_loop":  # NOTE: mask first and last latent frame
                masks[i, :, 0, :, :] = 1
                masks[i, :, -1, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                # condition: encode the boundary images only
                latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
                latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
            elif mask_cond == "i2v_tail":  # mask the last latent frame
                masks[i, :, -1, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                # condition: encode the last image only
                latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
            elif "v2v_head" in mask_cond:  # conditioned on the first 33 (65 for "easy") pixel frames
                ref_t = 33 if not "easy" in mask_cond else 65
                assert (ref_t - 1) % model_ae.time_compression_ratio == 0
                conditioned_t = (ref_t - 1) // model_ae.time_compression_ratio + 1
                masks[i, :, :conditioned_t, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                # encode the first ref_t-frame clip separately (avoids information leakage)
                latent[i, :, :conditioned_t, :, :] = model_ae.encode(x[i, :, :ref_t, :, :].unsqueeze(0))
            elif "v2v_tail" in mask_cond:  # conditioned on the last 33 (65 for "easy") pixel frames
                ref_t = 33 if not "easy" in mask_cond else 65
                assert (ref_t - 1) % model_ae.time_compression_ratio == 0
                conditioned_t = (ref_t - 1) // model_ae.time_compression_ratio + 1
                masks[i, :, -conditioned_t:, :, :] = 1
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
                # encode the last ref_t-frame clip separately (avoids information leakage)
                latent[i, :, -conditioned_t:, :, :] = model_ae.encode(x[i, :, -ref_t:, :, :].unsqueeze(0))
            else:
                # "t2v" is the fallback case where no specific condition is specified
                assert mask_cond == "t2v", f"Unknown mask condition {mask_cond}"
                x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
    else:  # image
        x_0 = model_ae.encode(x)  # latent image; no visual conditioning applies
    latent = masks * latent  # condition latent: zero out everything outside the mask
    # merge the masks and the masked latent into a single condition tensor
    cond = torch.cat((masks, latent), dim=1)
    return x_0, cond
def get_batch_loss(model_pred, v_t, masks=None):
    """
    Compute the per-batch MSE loss between patchified predictions and targets.

    Args:
        model_pred: Predicted tensor, shape [B, T*H*W, C*ph*pw] (patch layout).
        v_t: Target tensor in the same layout as ``model_pred``.
        masks: Optional condition masks of shape [B, T, H, W]. When given, a
            leading/trailing latent frame that only replicates a conditioning
            image (I2V front/tail padding) is excluded from the loss.

    Returns:
        torch.Tensor: Scalar loss; each batch element contributes equally.
    """
    if masks is None:
        # use reduction mean so that each batch has the same level of influence
        return F.mse_loss(model_pred.float(), v_t.float(), reduction="mean")
    num_frames, height, width = masks.shape[-3:]
    frame_flags = masks[:, :, 0, 0]  # mask is spatially constant: keep only [B, T]
    # un-patchify both tensors back to [B, C, T, H, W] (2x2 spatial patches)
    pattern = "b (t h w) (c ph pw) -> b c t (h ph) (w pw)"
    axes = dict(h=height // 2, w=width // 2, t=num_frames, ph=2, pw=2)
    model_pred = rearrange(model_pred, pattern, **axes)
    v_t = rearrange(v_t, pattern, **axes)
    batch_size = model_pred.size(0)
    total = 0
    for i in range(batch_size):
        pred = model_pred[i]
        target = v_t[i]
        flags = frame_flags[i]
        # Only pure head/tail conditioning (no interior frames masked) is trimmed.
        interior_clear = 1 not in flags[1:-1]
        if flags[0] == 1 and interior_clear:  # front padding frame
            pred, target = pred[:, 1:], target[:, 1:]
        if flags[-1] == 1 and interior_clear:  # tail padding frame
            pred, target = pred[:, :-1], target[:, :-1]
        total += F.mse_loss(pred.float(), target.float(), reduction="mean")
    return total / batch_size
@torch.no_grad()
def warmup_ae(model_ae: nn.Module, shapes: list[tuple[int, ...]], device: torch.device, dtype: torch.dtype):
    """
    Run the autoencoder once on every input shape it will see during training.

    This triggers any shape-dependent one-time work (kernel selection,
    allocation) before real training steps are timed.

    Args:
        model_ae (nn.Module): The autoencoder to warm up.
        shapes (list[tuple[int, ...]]): Full input shapes to exercise.
        device (torch.device): Device to allocate dummy inputs on.
        dtype (torch.dtype): Dtype of the dummy inputs.
    """
    # Only rank 0 shows the progress bar.
    hide_bar = dist.get_rank() != 0
    for shape in tqdm(shapes, desc="Warmup AE", disable=hide_bar):
        dummy = torch.randn(*shape, device=device, dtype=dtype)
        _ = model_ae.encode(dummy)