mysora/opensora/models/vae/losses.py

224 lines
7.2 KiB
Python

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn
from opensora.models.vae.lpips import LPIPS
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Computes ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``:
    real logits are pushed above +1 and fake logits below -1.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Logistic ("vanilla" GAN) discriminator loss.

    Returns ``0.5 * (mean(softplus(-logits_real)) + mean(softplus(logits_fake)))``.
    ``softplus(-x)`` equals ``-log(sigmoid(x))``.
    """
    real_term = torch.mean(F.softplus(-logits_real))
    fake_term = torch.mean(F.softplus(logits_fake))
    return 0.5 * (real_term + fake_term)
def wgan_gp_loss(logits_real, logits_fake):
    """Wasserstein critic loss ``0.5 * (mean(logits_fake) - mean(logits_real))``.

    This function contains no gradient-penalty term itself.
    """
    real_mean = logits_real.mean()
    fake_mean = logits_fake.mean()
    return 0.5 * (-real_mean + fake_mean)
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` while ``global_step < threshold``; otherwise return ``weight``.

    Used as a warm-up gate: a loss term stays at ``value`` (0 by default)
    until training reaches ``threshold`` steps.
    """
    return value if global_step < threshold else weight
def measure_perplexity(predicted_indices, n_embed):
    """Codebook usage statistics for a set of predicted code indices.

    Adapted from https://github.com/karpathy/deep-vector-quantization/blob/main/model.py

    Args:
        predicted_indices: Integer tensor of code indices (any shape).
        n_embed: Number of codebook entries.

    Returns:
        ``(perplexity, cluster_use)``. ``perplexity`` is ``exp(entropy)`` of the
        empirical code distribution and equals ``n_embed`` when every code is used
        equally often. ``cluster_use`` counts codes with non-zero usage.
    """
    flat_one_hot = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    usage = flat_one_hot.mean(0)
    # 1e-10 keeps log finite for unused codes (their 0 * log term stays 0).
    entropy = -(usage * torch.log(usage + 1e-10)).sum()
    used_codes = torch.sum(usage > 0)
    return entropy.exp(), used_codes
def l1(x, y):
    """Element-wise absolute difference ``|x - y|`` (no reduction)."""
    difference = x - y
    return torch.abs(difference)
def l2(x, y):
    """Element-wise squared difference ``(x - y) ** 2`` (no reduction)."""
    difference = x - y
    return torch.pow(difference, 2)
def batch_mean(x):
    """Sum every element of ``x`` and divide by the size of its first dimension.

    This is a per-sample sum averaged over the batch, not an element-wise mean.
    """
    batch_size = x.shape[0]
    total = torch.sum(x)
    return total / batch_size
def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise, numerically stable sigmoid cross-entropy.

    With ``x = logits`` and ``z = labels`` this evaluates
    ``max(x, 0) - x * z + log(1 + exp(-|x|))``. That expression equals
    ``-z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x))`` but never exponentiates a
    large positive number.
    """
    zero = torch.zeros_like(logits, dtype=logits.dtype)
    is_nonneg = logits >= zero
    # -|x|, written with where so NaN/-0.0 handling mirrors the branchy formula.
    minus_abs = torch.where(is_nonneg, -logits, logits)
    softplus_term = torch.log1p(torch.exp(minus_abs))
    return torch.where(is_nonneg, logits, zero) - logits * labels + softplus_term
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
    """LeCam regularizer on discriminator predictions.

    Returns ``mean(relu(real_pred - ema_fake_pred)^2) + mean(relu(ema_real_pred - fake_pred)^2)``.

    NOTE(review): only ``real_pred`` and ``ema_fake_pred`` are asserted to be
    0-dim; the other two arguments are not checked.
    """
    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
    real_vs_ema_fake = F.relu(real_pred - ema_fake_pred).pow(2).mean()
    ema_real_vs_fake = F.relu(ema_real_pred - fake_pred).pow(2).mean()
    return real_vs_ema_fake + ema_real_vs_fake
def gradient_penalty_fn(images, output):
    """Return the gradient penalty ``mean((||d output / d images||_2 - 1) ** 2)``.

    Gradients are taken with ``create_graph=True``, so the returned penalty can
    itself be backpropagated. Each sample's gradient is flattened over all
    non-leading dimensions before its L2 norm is computed.

    Args:
        images: Tensor that ``output`` was computed from; must require grad.
            Its leading dimension is treated as the batch dimension.
        output: Tensor computed from ``images`` (e.g. discriminator logits).

    Returns:
        0-dim tensor: mean over the batch of the squared deviation of each
        per-sample gradient norm from 1.
    """
    import math

    gradients = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # ones_like gives the seed gradient output's own shape, dtype and device.
        # The previous torch.ones(...) always used the default dtype and
        # images.device.
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
    )[0]
    # (b, ...) -> (b, prod(...)): same result as einops "b ... -> b (...)",
    # including 1-D input (-> (b, 1)) and an empty batch.
    gradients = gradients.reshape(gradients.shape[0], math.prod(gradients.shape[1:]))
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
class VAELoss(nn.Module):
    """Reconstruction + perceptual (LPIPS) + KL objective for a video VAE.

    Args:
        logvar_init: Initial value of the learnable scalar ``logvar`` that
            rescales the NLL term as ``nll / exp(logvar) + logvar``.
        perceptual_loss_weight: Multiplier on the LPIPS term inside the NLL.
        kl_loss_weight: Multiplier applied to the batch-mean KL term.
        device: Device the LPIPS network is moved to.
        dtype: Dtype of the LPIPS network. Either a ``torch.dtype`` or one of
            the strings ``"bf16"``, ``"fp16"``, ``"fp32"``.

    Raises:
        NotImplementedError: If ``dtype`` is a string not listed above.
    """

    # String aliases accepted for the ``dtype`` constructor argument.
    _DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

    def __init__(
        self,
        logvar_init=0.0,
        perceptual_loss_weight=1.0,
        kl_loss_weight=5e-4,
        device="cpu",
        dtype="bf16",
    ):
        super().__init__()
        if isinstance(dtype, str):
            try:
                dtype = self._DTYPES[dtype]
            except KeyError:
                raise NotImplementedError(f"dtype: {dtype}") from None
        # KL loss
        self.kl_weight = kl_loss_weight
        # Perceptual loss: frozen LPIPS network in eval mode (never trained here).
        self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
        self.perceptual_loss_fn.requires_grad_(False)
        self.perceptual_loss_weight = perceptual_loss_weight
        # Learnable 0-dim log-variance used to balance the NLL term.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

    def forward(
        self,
        video,
        recon_video,
        posterior,
    ) -> dict:
        """Compute the VAE loss terms for a batch of videos.

        Args:
            video: Target tensor laid out as ``(b, c, t, h, w)``.
            recon_video: Reconstruction with the same layout as ``video``.
            posterior: ``None`` (KL term is zero) or an object exposing
                ``kl()``. Presumably ``kl()`` returns per-sample KL values with
                the batch as leading dimension; confirm against the encoder.

        Returns:
            dict with keys ``"nll_loss"``, ``"kl_loss"`` (already multiplied by
            ``kl_weight``), ``"recon_loss"`` and ``"perceptual_loss"`` (unweighted).
            ``"kl_loss"`` is reduced with ``batch_mean`` over its own leading
            dimension. The other three are sums over frames divided by
            ``b * t``.
        """
        # Treat every frame as an independent image: (b, c, t, h, w) -> (b*t, c, h, w).
        video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
        recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()

        # Unreduced per-element reconstruction loss.
        recon_loss = l1(video, recon_video)
        perceptual_loss = self.perceptual_loss_fn(video, recon_video)

        # NLL combines reconstruction and weighted perceptual loss, rescaled by logvar.
        nll_loss = recon_loss + perceptual_loss * self.perceptual_loss_weight
        nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar

        # Sum over elements, average over the (b*t) frames.
        nll_loss = batch_mean(nll_loss)
        recon_loss = batch_mean(recon_loss)
        # NOTE(review): presumably LPIPS yields one value per frame, which
        # broadcasts over c*h*w in nll_loss. Scaling by the per-frame element
        # count reports it on that same summed scale; confirm LPIPS output shape.
        numel_elements = video.numel() // video.size(0)
        perceptual_loss = batch_mean(perceptual_loss) * numel_elements

        # KL loss
        if posterior is None:
            kl_loss = torch.tensor(0.0).to(video.device, video.dtype)
        else:
            kl_loss = batch_mean(posterior.kl())
        weighted_kl_loss = kl_loss * self.kl_weight

        return {
            "nll_loss": nll_loss,
            "kl_loss": weighted_kl_loss,
            "recon_loss": recon_loss,
            "perceptual_loss": perceptual_loss,
        }
class GeneratorLoss(nn.Module):
    """Adversarial loss for the generator, with adaptive weighting.

    The raw generator loss is ``-mean(logits_fake)``. It is scaled by an
    adaptive weight (see ``calculate_adaptive_weight``) and by ``disc_factor``.
    ``disc_factor`` is held at 0 until ``global_step`` reaches ``gen_start``.

    Args:
        gen_start: First global step at which the adversarial term is active.
        disc_factor: Multiplier for the adversarial term. ``None`` disables
            the term (weighted loss is 0). Any value ``<= 0`` skips the
            adaptive weight (uses 1.0).
        disc_weight: Extra multiplier applied to the adaptive weight.
    """

    def __init__(self, gen_start=2001, disc_factor=1.0, disc_weight=0.5):
        super().__init__()
        self.disc_factor = disc_factor
        self.gen_start = gen_start
        self.disc_weight = disc_weight

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Balance the GAN gradient against the NLL gradient at ``last_layer``.

        Returns ``disc_weight * clamp(||grad nll|| / (||grad g|| + 1e-4), 0, 1e4)``.
        The weight is detached, so it acts as a constant during backprop.
        ``retain_graph=True`` keeps both graphs alive for the caller's backward.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        # 1e-4 keeps the ratio finite if the GAN gradient vanishes.
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.disc_weight
        return d_weight

    def forward(
        self,
        logits_fake,
        nll_loss,
        last_layer,
        global_step,
        is_training=True,
    ):
        """Return ``(weighted_gen_loss, g_loss)``.

        Args:
            logits_fake: Discriminator logits for generated samples.
            nll_loss: Reconstruction objective used for the adaptive weight.
            last_layer: Tensor/parameter that both losses depend on (used for
                gradient norms).
            global_step: Current training step (gates the term via ``gen_start``).
            is_training: Unused; kept for call-site compatibility.
        """
        g_loss = -torch.mean(logits_fake)
        if self.disc_factor is not None and self.disc_factor > 0.0:
            d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer)
        else:
            d_weight = torch.tensor(1.0)
        # None means "disabled". Without this mapping, adopt_weight would return
        # None once global_step >= gen_start, and the product below would raise.
        base_factor = 0.0 if self.disc_factor is None else self.disc_factor
        disc_factor = adopt_weight(base_factor, global_step, threshold=self.gen_start)
        weighted_gen_loss = d_weight * disc_factor * g_loss
        return weighted_gen_loss, g_loss
class DiscriminatorLoss(nn.Module):
    """Discriminator loss with a warm-up gate.

    Args:
        disc_start: First global step at which the loss is non-zero.
        disc_factor: Multiplier for the loss. ``None`` or a non-positive value
            disables it (``forward`` returns the int ``0``).
        disc_loss_type: One of ``"hinge"``, ``"vanilla"``, ``"wgan-gp"``.

    Raises:
        ValueError: If ``disc_loss_type`` is not one of the supported names.
    """

    def __init__(self, disc_start=2001, disc_factor=1.0, disc_loss_type="hinge"):
        super().__init__()
        loss_fns = {
            "hinge": hinge_d_loss,
            "vanilla": vanilla_d_loss,
            "wgan-gp": wgan_gp_loss,
        }
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if disc_loss_type not in loss_fns:
            raise ValueError(f"Unknown GAN loss '{disc_loss_type}'.")
        self.disc_factor = disc_factor
        self.disc_start = disc_start
        self.disc_loss_type = disc_loss_type
        self.loss_fn = loss_fns[disc_loss_type]

    def forward(
        self,
        real_logits,
        fake_logits,
        global_step,
    ):
        """Return ``disc_factor * loss_fn(real_logits, fake_logits)``.

        The factor is 0 before ``disc_start``. Returns the int ``0`` when the
        loss is disabled via ``disc_factor``.
        """
        if self.disc_factor is not None and self.disc_factor > 0.0:
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.disc_start)
            disc_loss = self.loss_fn(real_logits, fake_logits)
            weighted_discriminator_loss = disc_factor * disc_loss
        else:
            weighted_discriminator_loss = 0
        return weighted_discriminator_loss