mysora/opensora/models/vae/utils.py

258 lines
9.6 KiB
Python

import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
# Element-count threshold used by the chunked helpers in this file: they compute
# the number of chunks as ceil(numel / NUMEL_LIMIT) and only split a tensor when
# that count is not 1.
NUMEL_LIMIT = 2**30
def ceil_to_divisible(n: int, dividend: int) -> int:
    """Round a requested chunk count ``n`` up for splitting ``dividend`` items.

    For ``1 <= n <= dividend`` this returns
    ``ceil(dividend / (dividend // n))``, i.e. the number of chunks obtained
    when every chunk holds ``dividend // n`` items. The result is always
    ``>= n``; it is an exact divisor of ``dividend`` whenever
    ``dividend // n`` divides ``dividend`` evenly.

    Out-of-range requests are clamped instead of raising
    ``ZeroDivisionError``: ``n <= 1`` yields 1 and ``n >= dividend`` yields
    ``dividend`` (one item per chunk), never less than 1.

    Args:
        n: Minimum number of chunks wanted.
        dividend: Number of items to split (e.g. a channel count).

    Returns:
        The adjusted chunk count.
    """
    if n <= 1:
        return 1
    if n >= dividend:
        # Cannot split finer than one item per chunk; ``dividend // n`` would be 0.
        return max(dividend, 1)
    return math.ceil(dividend / (dividend // n))
def chunked_avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
    """Apply ``F.avg_pool1d``, splitting large inputs along dim 0.

    The input is split into ``ceil(input.numel() / NUMEL_LIMIT)`` chunks along
    dim 0, each chunk is pooled independently, and the results are
    concatenated back along dim 0. Pooling only acts on the last dimension, so
    splitting along dim 0 does not change the values.

    All arguments after ``input`` are forwarded unchanged to
    ``F.avg_pool1d``; ``stride=None`` therefore keeps its usual meaning there
    (the chunked path no longer does arithmetic on ``stride`` itself).

    Returns:
        The pooled tensor.
    """
    n_chunks = math.ceil(input.numel() / NUMEL_LIMIT)
    if n_chunks <= 1:
        # Small (or empty) input: a single call is enough.
        return F.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
    pooled = [
        F.avg_pool1d(piece, kernel_size, stride, padding, ceil_mode, count_include_pad)
        for piece in input.chunk(n_chunks, dim=0)
    ]
    return torch.cat(pooled, dim=0)
def chunked_interpolate(input, scale_factor):
    """Apply ``F.interpolate`` (default mode), splitting along dim 1 if needed.

    The number of chunks is derived from the element count of the *output*:
    dims 0 and 1 keep their size and every later dim is scaled by
    ``int(size * factor)``. When more than one chunk is needed, the input is
    split along dim 1, each piece is interpolated separately and the results
    are concatenated along dim 1.

    Args:
        input: Tensor to resize; dims from index 2 onward are scaled.
        scale_factor: Either one number applied to every scaled dim, or a
            tuple/list with one factor per scaled dim (both forms are accepted
            by ``F.interpolate``).

    Returns:
        The interpolated tensor.
    """
    scaled_sizes = list(input.shape[2:])
    if isinstance(scale_factor, (list, tuple)):
        factors = list(scale_factor)
    else:
        factors = [scale_factor] * len(scaled_sizes)
    output_shape = list(input.shape[:2]) + [int(size * factor) for size, factor in zip(scaled_sizes, factors)]

    n_chunks = math.ceil(torch.Size(output_shape).numel() / NUMEL_LIMIT)
    if n_chunks <= 1:
        return F.interpolate(input, scale_factor=scale_factor)

    # One more chunk than strictly computed is requested, as in the original
    # implementation (presumably as headroom below the limit).
    n_chunks += 1
    resized = [F.interpolate(piece, scale_factor=scale_factor) for piece in input.chunk(n_chunks, dim=1)]
    return torch.cat(resized, dim=1)
def get_conv3d_output_shape(
    input_shape: torch.Size, out_channels: int, kernel_size: list, stride: list, padding: list, dilation: list
) -> list:
    """Compute the output shape of a 3-D convolution without running it.

    Each of the last three input dims is mapped with the standard formula
    ``(d + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1``.

    Args:
        input_shape: Shape of the convolution input. With 5 dims the leading
            dim is carried over to the result; otherwise only the last three
            dims are used.
        out_channels: Number of output channels.
        kernel_size: Per-dimension kernel size (3 entries), or a single int.
        stride: Per-dimension stride (3 entries), or a single int.
        padding: Per-dimension padding (3 entries), or a single int.
        dilation: Per-dimension dilation (3 entries), or a single int.

    Returns:
        ``[N, out_channels, D, H, W]`` for 5-D input shapes, otherwise
        ``[out_channels, D, H, W]``.
    """

    def _per_dim(value):
        # A bare int applies to all three convolved dimensions.
        return (value,) * 3 if isinstance(value, int) else tuple(value)

    kernel_size = _per_dim(kernel_size)
    stride = _per_dim(stride)
    padding = _per_dim(padding)
    dilation = _per_dim(dilation)

    output_shape = [out_channels]
    if len(input_shape) == 5:
        output_shape.insert(0, input_shape[0])
    for i, d in enumerate(input_shape[-3:]):
        # Integer floor division gives the same result as floor(x / s + 1)
        # without a float round trip.
        d_out = (d + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i] + 1
        output_shape.append(d_out)
    return output_shape
def get_conv3d_n_chunks(numel: int, n_channels: int, numel_limit: int):
    """Return how many channel chunks to use for a tensor of ``numel`` elements.

    The minimum count ``ceil(numel / numel_limit)`` is passed through
    ``ceil_to_divisible`` together with the channel count ``n_channels``.
    """
    minimum_chunks = math.ceil(numel / numel_limit)
    return ceil_to_divisible(minimum_chunks, n_channels)
def channel_chunk_conv3d(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: list,
    padding: list,
    dilation: list,
    groups: int,
    numel_limit: int,
):
    """Run ``F.conv3d`` in channel chunks so each call stays under ``numel_limit``.

    Output channels are split into ``n_out_chunks`` groups (by chunking
    ``weight``/``bias`` along dim 0). For each group the input channels are
    split into ``n_in_chunks`` shards; the partial convolutions of all shards
    are accumulated in float32, cast back to ``input.dtype``, and the bias is
    added once. The per-group results are concatenated along the channel axis.

    Args:
        input: ``(N, C, D, H, W)`` or unbatched ``(C, D, H, W)`` tensor.
        weight: Convolution weight, ``(out_channels, in_channels, kD, kH, kW)``.
        bias: Optional bias of shape ``(out_channels,)``; may be ``None``.
        stride: Per-dimension stride.
        padding: Per-dimension padding.
        dilation: Per-dimension dilation.
        groups: Convolution groups. Sharding input channels is only valid for
            ``groups == 1``; any other value falls back to one plain
            ``F.conv3d`` call.
        numel_limit: Element-count threshold used to derive the chunk counts.

    Returns:
        The convolution result, with the same dtype as ``input``.
    """
    if groups != 1:
        # Splitting input channels would regroup them incorrectly.
        return F.conv3d(input, weight, bias, stride, padding, dilation, groups)

    out_channels, in_channels = weight.shape[:2]
    kernel_size = weight.shape[2:]
    output_shape = get_conv3d_output_shape(input.shape, out_channels, kernel_size, stride, padding, dilation)

    # Exact Python-int product; avoids fixed-width integer overflow.
    output_numel = 1
    for size in output_shape:
        output_numel *= size

    n_in_chunks = get_conv3d_n_chunks(input.numel(), in_channels, numel_limit)
    n_out_chunks = get_conv3d_n_chunks(output_numel, out_channels, numel_limit)
    if n_in_chunks == 1 and n_out_chunks == 1:
        return F.conv3d(input, weight, bias, stride, padding, dilation, groups)

    # The channel axis is the 4th from the end for both batched and unbatched input.
    channel_dim = -4
    input_shards = input.chunk(n_in_chunks, dim=channel_dim)
    weight_chunks = weight.chunk(n_out_chunks)
    if bias is not None:
        bias_chunks = bias.chunk(n_out_chunks)
    else:
        bias_chunks = [None] * len(weight_chunks)

    # Results are collected and concatenated rather than copied into a
    # preallocated buffer, because that in-place copy cannot be used during training.
    output_list = []
    for weight_, bias_ in zip(weight_chunks, bias_chunks):
        weight_shards = weight_.chunk(n_in_chunks, dim=1)
        o = None
        for x, w in zip(input_shards, weight_shards):
            partial = F.conv3d(x, w, None, stride, padding, dilation, groups).float()
            if o is None:
                o = partial
            else:
                o += partial
        o = o.to(input.dtype)
        if bias_ is not None:
            # (C, 1, 1, 1) broadcasts over (N, C, D, H, W) and (C, D, H, W) alike.
            o += bias_.reshape(-1, 1, 1, 1)
        output_list.append(o)
    return torch.cat(output_list, dim=channel_dim)
class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance, parameterized by mean and log-variance.

    ``parameters`` is split into two equal halves along dim 1: the first half
    is the mean, the second the log-variance (clamped to ``[-30, 20]``).

    NOTE(review): the original docstring pointed at
    https://github.com/richzhang/PerceptualSimilarity/tree/master/models as the
    source of this class; that link looks unrelated - confirm provenance.
    """

    def __init__(
        self,
        parameters,
        deterministic=False,
    ):
        """Build the distribution.

        Args:
            parameters: Tensor holding mean and log-variance stacked along dim 1.
            deterministic: If True, the variance and standard deviation are
                zero, so ``sample()`` returns the mean and ``kl()`` returns 0.
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps`` standard-normal noise."""
        # The noise is created by torch.randn without a device argument and then
        # moved/cast, exactly as before, so seeded runs keep the same random stream.
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence to ``other`` (standard normal if None).

        The divergence is summed over dims 1, 3 and 4 and the remaining dims
        are flattened into one (this indexing needs 5-D parameters). In
        deterministic mode a single zero is returned, on the same device and
        with the same dtype as the mean.
        """
        if self.deterministic:
            return self.mean.new_zeros(1)
        if other is None:  # SCH: assumes other is a standard normal distribution
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 3, 4]).flatten(0)
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 3, 4],
        ).flatten(0)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
class ChannelChunkConv3d(nn.Conv3d):
    """``nn.Conv3d`` that splits very large convolutions into channel chunks.

    Inputs whose per-sample element count is below ``CONV3D_NUMEL_LIMIT`` use
    the regular ``nn.Conv3d`` forward. Larger inputs are convolved in pieces:
    output channels are chunked via the weight/bias, input channels are
    sharded, and the partial results are summed and concatenated.
    """

    # Element-count threshold above which the chunked path is considered.
    CONV3D_NUMEL_LIMIT = 2**31

    def _get_output_numel(self, input_shape: torch.Size) -> int:
        """Return the number of elements the convolution output will have."""
        numel = self.out_channels
        if len(input_shape) == 5:
            numel *= input_shape[0]
        for i, d in enumerate(input_shape[-3:]):
            d_out = (d + 2 * self.padding[i] - self.dilation[i] * (self.kernel_size[i] - 1) - 1) // self.stride[i] + 1
            numel *= d_out
        return numel

    def _get_n_chunks(self, numel: int, n_channels: int):
        """Return the chunk count for ``numel`` elements spread over ``n_channels``."""
        n_chunks = math.ceil(numel / self.CONV3D_NUMEL_LIMIT)
        return ceil_to_divisible(n_chunks, n_channels)

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input``, chunking over channels when it is very large."""
        # Empty input is guarded explicitly so the per-sample division cannot hit zero.
        if input.numel() == 0 or input.numel() // input.size(0) < self.CONV3D_NUMEL_LIMIT:
            return super().forward(input)
        # The chunked path below assumes ungrouped convolution with explicit
        # zero padding; anything else is delegated to the regular forward.
        if self.groups != 1 or self.padding_mode != "zeros" or isinstance(self.padding, str):
            return super().forward(input)

        n_in_chunks = self._get_n_chunks(input.numel(), self.in_channels)
        n_out_chunks = self._get_n_chunks(self._get_output_numel(input.shape), self.out_channels)
        if n_in_chunks == 1 and n_out_chunks == 1:
            return super().forward(input)

        # The channel axis is the 4th from the end for batched and unbatched input.
        channel_dim = -4
        input_shards = input.chunk(n_in_chunks, dim=channel_dim)
        weight_chunks = self.weight.chunk(n_out_chunks)
        if self.bias is None:
            bias_chunks = [None] * len(weight_chunks)
        else:
            bias_chunks = self.bias.chunk(n_out_chunks)

        outputs = []
        for weight, bias in zip(weight_chunks, bias_chunks):
            weight_shards = weight.chunk(n_in_chunks, dim=1)
            o = None
            for x, w in zip(input_shards, weight_shards):
                if o is None:
                    # The bias is added once, together with the first shard.
                    o = F.conv3d(x, w, bias, self.stride, self.padding, self.dilation, self.groups)
                else:
                    o += F.conv3d(x, w, None, self.stride, self.padding, self.dilation, self.groups)
            outputs.append(o)
        return torch.cat(outputs, dim=channel_dim)
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def pad_for_conv3d(x: torch.Tensor, width_pad: int, height_pad: int, time_pad: int) -> torch.Tensor:
if width_pad > 0 or height_pad > 0:
x = F.pad(x, (width_pad, width_pad, height_pad, height_pad), mode="constant", value=0)
if time_pad > 0:
x = F.pad(x, (0, 0, 0, 0, time_pad, time_pad), mode="replicate")
return x
def pad_for_conv3d_kernel_3x3x3(x: torch.Tensor) -> torch.Tensor:
    """Pad ``x`` by one element on each side of its last three dimensions.

    The last two dimensions are zero-padded; the third-to-last is padded by
    replicating its edge values. When ``ceil(x.numel() / NUMEL_LIMIT)`` is not
    1, the tensor is split along dim 1 into that many chunks plus one, each
    chunk is padded on its own, and the pieces are concatenated along dim 1.
    """

    def _pad_one(piece: torch.Tensor) -> torch.Tensor:
        piece = F.pad(piece, (1, 1, 1, 1), mode="constant", value=0)
        return F.pad(piece, (0, 0, 0, 0, 1, 1), mode="replicate")

    n_chunks = math.ceil(x.numel() / NUMEL_LIMIT)
    if n_chunks == 1:
        return _pad_one(x)
    padded_pieces = [_pad_one(piece) for piece in x.chunk(n_chunks + 1, dim=1)]
    return torch.cat(padded_pieces, dim=1)
class PadConv3D(nn.Module):
    """3-D convolution preceded by explicit padding.

    For a 3x3x3 kernel the input is padded with
    ``pad_for_conv3d_kernel_3x3x3`` (zeros on the last two dimensions, edge
    replication on both ends of the third-to-last one) before an unpadded,
    stride-1 ``nn.Conv3d``. For a 1x1x1 kernel no padding is applied.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        """Create the padding function and the convolution.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size, as an int or a 3-tuple; only cubic
                kernels of size 1 or 3 are supported.
        """
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        self.kernel_size = kernel_size

        # == specific padding ==
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        assert time_kernel_size == height_kernel_size == width_kernel_size, "only support cubic kernel size"
        if time_kernel_size == 3:
            self.pad = pad_for_conv3d_kernel_3x3x3
        else:
            assert time_kernel_size == 1, f"only support kernel size 1/3 for now, got {kernel_size}"
            # A named static method (rather than a lambda) keeps the module picklable.
            self.pad = self._identity

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
        )

    @staticmethod
    def _identity(x: Tensor) -> Tensor:
        """Return ``x`` unchanged (used when no padding is needed)."""
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Pad ``x`` and apply the convolution."""
        x = self.pad(x)
        x = self.conv(x)
        return x
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
class ChannelChunkPadConv3D(PadConv3D):
    """``PadConv3D`` whose convolution is a ``ChannelChunkConv3d``.

    The padding behaviour is inherited from ``PadConv3D``; only the ``conv``
    submodule is replaced.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__(in_channels, out_channels, kernel_size)
        # Replace the convolution created by the parent constructor with the
        # channel-chunked variant (same kernel, stride 1, no padding, dilation 1).
        chunked_conv = ChannelChunkConv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1)
        self.conv = chunked_conv