# NOTE(review): removed scraped viewer metadata ("258 lines", "9.6 KiB",
# "Python") that was not valid Python source.
import math
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor, nn
|
|
|
|
NUMEL_LIMIT = 2**30
|
|
|
|
|
|
def ceil_to_divisible(n: int, dividend: int) -> int:
    """Round *n* up to a chunk count that splits *dividend* as evenly as possible.

    Uses q = dividend // n as the largest per-chunk size compatible with at
    least n chunks, then returns ceil(dividend / q).
    """
    per_chunk = dividend // n
    return math.ceil(dividend / per_chunk)
|
|
|
|
|
|
def chunked_avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
    """Memory-safe wrapper around ``F.avg_pool1d``.

    When ``input`` holds more than ``NUMEL_LIMIT`` elements, pooling is applied
    per batch chunk and the results concatenated along dim 0; otherwise a single
    ``F.avg_pool1d`` call is made. Arguments mirror ``F.avg_pool1d`` (with
    ``stride=None`` meaning "stride = kernel_size", handled by PyTorch itself).

    Fix: the old chunked branch precomputed an *unused* output length via
    ``... / stride ...``, which raised ``TypeError`` whenever ``stride`` was
    left as ``None`` (the default) on a large input; the dead computation
    (and the equally unused ``output_shape``) is removed.
    """
    n_chunks = math.ceil(input.numel() / NUMEL_LIMIT)
    if n_chunks == 1:
        return F.avg_pool1d(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
    # Pooling is independent across the batch dimension, so chunking along
    # dim 0 is exact (not an approximation).
    out_list = [
        F.avg_pool1d(inp_chunk, kernel_size, stride, padding, ceil_mode, count_include_pad)
        for inp_chunk in input.chunk(n_chunks, dim=0)
    ]
    return torch.cat(out_list, dim=0)
|
|
|
|
|
|
def chunked_interpolate(input, scale_factor):
    """``F.interpolate`` (default mode) that splits the channel dimension when
    the estimated output would exceed ``NUMEL_LIMIT`` elements.

    Interpolation acts only on the spatial dims, so chunking along dim 1
    (channels) is exact.
    """
    # Estimate the output element count: batch/channel dims unchanged,
    # spatial dims scaled (int() truncates, so this is a lower bound).
    est_shape = list(input.shape[:2]) + [int(d * scale_factor) for d in input.shape[2:]]
    n_chunks = math.ceil(torch.Size(est_shape).numel() / NUMEL_LIMIT)
    if n_chunks == 1:
        return F.interpolate(input, scale_factor=scale_factor)
    # One extra chunk as a safety margin, since the estimate above truncates.
    pieces = [
        F.interpolate(chunk, scale_factor=scale_factor)
        for chunk in input.chunk(n_chunks + 1, dim=1)
    ]
    return torch.cat(pieces, dim=1)
|
|
|
|
|
|
def get_conv3d_output_shape(
    input_shape: torch.Size, out_channels: int, kernel_size: list, stride: list, padding: list, dilation: list
) -> list:
    """Compute the output shape of a 3D convolution without running it.

    Args:
        input_shape: ``(N, C, D, H, W)`` or ``(C, D, H, W)``.
        out_channels: number of output channels.
        kernel_size / stride / padding / dilation: per-spatial-dimension
            (D, H, W) values, each an indexable sequence of length 3.

    Returns:
        ``[N, out_channels, D_out, H_out, W_out]`` when a batch dim is
        present, else ``[out_channels, D_out, H_out, W_out]``.

    Fix: ``padding`` was annotated ``int`` but is indexed per dimension in the
    loop below; the annotation now matches its actual use.
    """
    output_shape = [out_channels]
    if len(input_shape) == 5:  # batched input: keep the batch dim in front
        output_shape.insert(0, input_shape[0])
    for i, d in enumerate(input_shape[-3:]):
        # Standard conv output-size formula (see torch.nn.Conv3d docs).
        d_out = math.floor((d + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i] + 1)
        output_shape.append(d_out)
    return output_shape
|
|
|
|
|
|
def get_conv3d_n_chunks(numel: int, n_channels: int, numel_limit: int):
    """Chunk count that keeps each chunk under ``numel_limit`` elements,
    rounded via ``ceil_to_divisible`` so ``n_channels`` splits evenly."""
    raw_chunks = math.ceil(numel / numel_limit)
    return ceil_to_divisible(raw_chunks, n_channels)
|
|
|
|
|
|
def channel_chunk_conv3d(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: list,
    padding: list,
    dilation: list,
    groups: int,
    numel_limit: int,
):
    """Functional conv3d that falls back to channel-chunked evaluation when the
    input or output would exceed ``numel_limit`` elements.

    Output channels are processed per weight chunk; within each chunk, partial
    results over input-channel shards are accumulated in float32 and cast back
    to the input dtype before the bias slice is added (broadcast over
    ``(N, C, D, H, W)``). In-place ``copy_`` into a preallocated output is
    deliberately avoided so autograd keeps working during training.

    NOTE(review): chunking along dim 1 presumably assumes ``groups == 1`` —
    grouped convs would need group-aligned chunking; confirm with callers.
    """
    out_channels, in_channels = weight.shape[:2]
    kernel_size = weight.shape[2:]
    output_shape = get_conv3d_output_shape(input.shape, out_channels, kernel_size, stride, padding, dilation)
    n_in_chunks = get_conv3d_n_chunks(input.numel(), in_channels, numel_limit)
    n_out_chunks = get_conv3d_n_chunks(np.prod(output_shape), out_channels, numel_limit)
    # Small enough for a single fused call.
    if n_in_chunks == 1 and n_out_chunks == 1:
        return F.conv3d(input, weight, bias, stride, padding, dilation, groups)

    input_shards = input.chunk(n_in_chunks, dim=1)
    bias_chunks = bias.chunk(n_out_chunks) if bias is not None else [None] * n_out_chunks
    output_list = []
    for weight_chunk, bias_chunk in zip(weight.chunk(n_out_chunks), bias_chunks):
        acc = None
        for shard, w in zip(input_shards, weight_chunk.chunk(n_in_chunks, dim=1)):
            # Accumulate partial sums in float32, cast back once at the end.
            partial = F.conv3d(shard, w, None, stride, padding, dilation, groups).float()
            if acc is None:
                acc = partial
            else:
                acc += partial
        acc = acc.to(input.dtype)
        if bias_chunk is not None:
            acc += bias_chunk[None, :, None, None, None]
        output_list.append(acc)
    return torch.cat(output_list, dim=1)
|
|
|
|
|
|
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by (mean, logvar) stacked on dim 1.

    Stripped version of
    https://github.com/richzhang/PerceptualSimilarity/tree/master/models
    """

    def __init__(
        self,
        parameters,
        deterministic=False,
    ):
        self.parameters = parameters
        # `parameters` carries mean and log-variance concatenated on dim 1.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp log-variance so exp() below stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: zero spread, sample() reduces to mean.
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device, dtype=self.mean.dtype)

    def sample(self):
        """Reparameterized draw: mean + std * eps, eps ~ N(0, I)."""
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to `other`, or to a standard normal when `other` is
        None; summed over dims [1, 3, 4] then flattened.
        NOTE(review): the reduction dims assume 5D tensors — confirm callers.
        """
        if self.deterministic:
            return torch.Tensor([0.0])
        if other is None:  # SCH: assumes other is a standard normal distribution
            integrand = torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
        else:
            integrand = (
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var
                - 1.0
                - self.logvar
                + other.logvar
            )
        return 0.5 * torch.sum(integrand, dim=[1, 3, 4]).flatten(0)

    def mode(self):
        """Most likely value — the mean for a Gaussian."""
        return self.mean
|
|
|
|
|
|
class ChannelChunkConv3d(nn.Conv3d):
    """``nn.Conv3d`` that evaluates in channel chunks when the input or the
    computed output would exceed ``CONV3D_NUMEL_LIMIT`` elements.

    For each output-channel chunk, partial convolutions over input-channel
    shards are accumulated; the bias slice is applied exactly once (on the
    first shard). Small inputs take the normal ``nn.Conv3d`` path untouched.

    Fix: the chunked path called ``self.bias.chunk(...)`` unconditionally and
    crashed with ``AttributeError`` when the layer was built with
    ``bias=False`` (``self.bias is None``); bias chunks now fall back to
    ``None`` placeholders, matching the sibling ``channel_chunk_conv3d``.

    NOTE(review): chunking along dim 1 presumably assumes ``groups == 1`` —
    confirm with callers.
    """

    CONV3D_NUMEL_LIMIT = 2**31

    def _get_output_numel(self, input_shape: torch.Size) -> int:
        """Element count of this conv's output for ``input_shape`` (4D or 5D)."""
        numel = self.out_channels
        if len(input_shape) == 5:  # batched: include the batch dim
            numel *= input_shape[0]
        for i, d in enumerate(input_shape[-3:]):
            # Standard conv output-size formula (see torch.nn.Conv3d docs).
            d_out = math.floor(
                (d + 2 * self.padding[i] - self.dilation[i] * (self.kernel_size[i] - 1) - 1) / self.stride[i] + 1
            )
            numel *= d_out
        return numel

    def _get_n_chunks(self, numel: int, n_channels: int) -> int:
        """Chunk count under the numel limit that divides channels evenly."""
        n_chunks = math.ceil(numel / ChannelChunkConv3d.CONV3D_NUMEL_LIMIT)
        n_chunks = ceil_to_divisible(n_chunks, n_channels)
        return n_chunks

    def forward(self, input: Tensor) -> Tensor:
        # Fast path: per-sample size already under the limit.
        if input.numel() // input.size(0) < ChannelChunkConv3d.CONV3D_NUMEL_LIMIT:
            return super().forward(input)
        n_in_chunks = self._get_n_chunks(input.numel(), self.in_channels)
        n_out_chunks = self._get_n_chunks(self._get_output_numel(input.shape), self.out_channels)
        if n_in_chunks == 1 and n_out_chunks == 1:
            return super().forward(input)
        outputs = []
        input_shards = input.chunk(n_in_chunks, dim=1)
        # Fix: handle bias=False layers (self.bias is None).
        if self.bias is not None:
            bias_chunks = self.bias.chunk(n_out_chunks)
        else:
            bias_chunks = [None] * n_out_chunks
        for weight, bias in zip(self.weight.chunk(n_out_chunks), bias_chunks):
            weight_shards = weight.chunk(n_in_chunks, dim=1)
            o = None
            for x, w in zip(input_shards, weight_shards):
                if o is None:
                    # Bias (possibly None) is added once, on the first shard.
                    o = F.conv3d(x, w, bias, self.stride, self.padding, self.dilation, self.groups)
                else:
                    o += F.conv3d(x, w, None, self.stride, self.padding, self.dilation, self.groups)
            outputs.append(o)
        return torch.cat(outputs, dim=1)
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def pad_for_conv3d(x: torch.Tensor, width_pad: int, height_pad: int, time_pad: int) -> torch.Tensor:
    """Pad a tensor for a 3D conv: zero-pad W/H symmetrically, replicate-pad
    the temporal dim on both sides.

    NOTE(review): the replicate pad spec covers three dims, so x is presumably
    5D (N, C, T, H, W) — confirm with callers.
    """
    needs_spatial_pad = width_pad > 0 or height_pad > 0
    if needs_spatial_pad:
        # Last two dims (W then H), constant zeros.
        x = F.pad(x, (width_pad, width_pad, height_pad, height_pad), mode="constant", value=0)
    if time_pad > 0:
        # Third-from-last dim (T), edge replication.
        x = F.pad(x, (0, 0, 0, 0, time_pad, time_pad), mode="replicate")
    return x
|
|
|
|
|
|
def pad_for_conv3d_kernel_3x3x3(x: torch.Tensor) -> torch.Tensor:
    """Padding for a 3x3x3 conv: zero-pad H/W by 1, replicate-pad T by 1 on
    both sides. Falls back to channel-chunked padding when ``x`` exceeds
    ``NUMEL_LIMIT`` elements.
    """

    def _pad(t):
        t = F.pad(t, (1, 1, 1, 1), mode="constant", value=0)
        return F.pad(t, (0, 0, 0, 0, 1, 1), mode="replicate")

    n_chunks = math.ceil(x.numel() / NUMEL_LIMIT)
    if n_chunks == 1:
        return _pad(x)
    # One extra chunk as a safety margin; padding is independent per channel,
    # so splitting dim 1 is exact.
    return torch.cat([_pad(chunk) for chunk in x.chunk(n_chunks + 1, dim=1)], dim=1)
|
|
|
|
|
|
class PadConv3D(nn.Module):
    """3D conv preceded by explicit padding: zeros on H/W and replicated
    edges on the temporal dim for kernel 3; identity padding for kernel 1.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        self.kernel_size = kernel_size

        # == specific padding ==
        tk, hk, wk = kernel_size
        assert tk == hk == wk, "only support cubic kernel size"
        if tk == 3:
            self.pad = pad_for_conv3d_kernel_3x3x3
        else:
            assert tk == 1, f"only support kernel size 1/3 for now, got {kernel_size}"
            self.pad = lambda x: x  # kernel 1 needs no padding

        # Padding is done explicitly above, so the conv itself uses padding=0.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(self.pad(x))
|
|
|
|
|
|
class ChannelChunkPadConv3D(PadConv3D):
    """``PadConv3D`` whose convolution is a ``ChannelChunkConv3d``, which
    splits channels into chunks for very large activations.

    Fix: the class was decorated with ``@torch.compile``, but per the PyTorch
    docs ``torch.compile`` takes a callable / ``nn.Module`` *instance*, not a
    class object. Decorating the class rebinds the name to a compiled-function
    wrapper, so ``ChannelChunkPadConv3D`` stops behaving as a class
    (``isinstance`` checks and subclassing break) and only construction would
    be traced. The decorator is removed; compile an instance at the call site
    (``torch.compile(module)``) if compilation is wanted.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__(in_channels, out_channels, kernel_size)
        # Replace the plain Conv3d built by PadConv3D with the chunked variant.
        self.conv = ChannelChunkConv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1)
|