# mysora/opensora/models/vae/autoencoder_2d.py
# (source-viewer metadata from the original paste: 340 lines, 12 KiB, Python)

# Modified from Flux
#
# Copyright 2024 Black Forest Labs
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from torch.nn.functional import silu as swish
from opensora.registry import MODELS
from opensora.utils.ckpt import load_checkpoint
from .utils import DiagonalGaussianDistribution
@dataclass
class AutoEncoderConfig:
    """Hyper-parameters shared by :class:`AutoEncoder`, :class:`Encoder` and :class:`Decoder`.

    Field order defines the positional constructor, so do not reorder fields.
    """

    # Checkpoint source recorded alongside the architecture; not read by
    # Encoder/Decoder/AutoEncoder themselves.
    from_pretrained: str | None
    # Cache directory recorded alongside from_pretrained.
    cache_dir: str | None
    # Nominal input resolution; Decoder derives its (square) z_shape from it.
    resolution: int
    # Channels of the image fed to Encoder.conv_in.
    in_channels: int
    # Base channel width; each level's width is ch * ch_mult[level].
    ch: int
    # Channels produced by Decoder.conv_out.
    out_ch: int
    # Per-level width multipliers; len(ch_mult) is the number of resolution levels.
    ch_mult: list[int]
    # ResnetBlocks per encoder level (the decoder uses num_res_blocks + 1 per level).
    num_res_blocks: int
    # Latent channels; the encoder emits 2 * z_channels posterior parameters.
    z_channels: int
    # Latent normalization used by AutoEncoder: z_norm = scale_factor * (z - shift_factor).
    scale_factor: float
    shift_factor: float
    # If True, AutoEncoder.encode_ draws posterior.sample(); otherwise it uses posterior.mode().
    sample: bool = True
class AttnBlock(nn.Module):
    """Residual single-head self-attention over the spatial positions of a feature map.

    The ``(b, c, h, w)`` input is group-normalized, projected to query/key/value
    with 1x1 convolutions, flattened to ``h * w`` tokens on a single head, attended,
    reshaped back, projected by ``proj_out`` and added to the input.
    ``in_channels`` must be divisible by 32 (GroupNorm with 32 groups).
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        # 1x1 convolutions are per-pixel linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Return the attended map in ``(b, c, h, w)`` layout, before ``proj_out``."""
        normed = self.norm(h_)
        height, width = normed.shape[-2:]
        # Each projection becomes a (b, 1, h*w, c) token sequence: one attention head.
        query, key, value = (
            rearrange(projection(normed), "b c h w -> b 1 (h w) c").contiguous()
            for projection in (self.q, self.k, self.v)
        )
        attended = nn.functional.scaled_dot_product_attention(query, key, value)
        return rearrange(attended, "b 1 (h w) c -> b c h w", h=height, w=width)

    def forward(self, x: Tensor) -> Tensor:
        residual = self.proj_out(self.attention(x))
        return x + residual
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2, 3x3 convolution.

    The input is zero-padded by one pixel on the right and bottom only. This
    asymmetric padding is applied manually because ``nn.Conv2d``'s ``padding``
    argument pads both sides of each dimension equally.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        # pad spec is (left, right, top, bottom) over the last two dims.
        right_bottom_pad = (0, 1, 0, 1)
        padded = nn.functional.pad(x, right_bottom_pad, mode="constant", value=0)
        return self.conv(padded)
class Upsample(nn.Module):
    """Double the spatial resolution (nearest-neighbour) and refine with a 3x3 convolution."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        refined = self.conv(enlarged)
        return refined
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to latent posterior parameters.

    Layout: ``conv_in`` -> one level per entry of ``config.ch_mult`` (each with
    ``num_res_blocks`` ResnetBlocks, plus a Downsample on every level except the
    last) -> mid (ResnetBlock, AttnBlock, ResnetBlock) -> GroupNorm -> SiLU ->
    ``conv_out`` with ``2 * config.z_channels`` output channels.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        # downsampling
        self.conv_in = nn.Conv2d(config.in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        # Level i maps ch * in_ch_mult[i] channels to ch * ch_mult[i] channels.
        in_ch_mult = (1,) + tuple(config.ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch  # width seen by the mid block if ch_mult is empty
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # never populated: no attention inside the down levels
            block_in = config.ch * in_ch_mult[i_level]
            block_out = config.ch * config.ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * config.z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a map with ``2 * z_channels`` channels.

        The map is downsampled once per non-final level.
        """
        # Only the latest activation is ever consumed, so carry a single tensor
        # instead of accumulating every intermediate feature map in a list
        # (which kept them all alive until the end of forward).
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)
        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)
class Decoder(nn.Module):
    """Convolutional decoder mapping latents back to image space.

    Layout: ``conv_in`` -> mid (ResnetBlock, AttnBlock, ResnetBlock) -> one level
    per ``config.ch_mult`` entry, walked from the last (lowest-resolution) level
    down to level 0, each with ``num_res_blocks + 1`` ResnetBlocks and an Upsample
    on every level except level 0 -> GroupNorm -> SiLU -> ``conv_out`` with
    ``config.out_ch`` channels. ``self.up[i]`` holds level ``i``.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        deepest = self.num_resolutions - 1
        self.ffactor = 2**deepest
        lowest_res = config.resolution // 2**deepest
        self.z_shape = (1, config.z_channels, lowest_res, lowest_res)

        width = config.ch * config.ch_mult[deepest]
        # z to block_in
        self.conv_in = nn.Conv2d(config.z_channels, width, kernel_size=3, stride=1, padding=1)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width)

        # upsampling: levels are constructed deepest-first, then stored so that
        # index i corresponds to level i.
        built_levels = []
        for i_level in range(deepest, -1, -1):
            level_width = config.ch * config.ch_mult[i_level]
            blocks = nn.ModuleList()
            empty_attn = nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                blocks.append(ResnetBlock(in_channels=width, out_channels=level_width))
                width = level_width
            level = nn.Module()
            level.block = blocks
            level.attn = empty_attn
            if i_level != 0:
                level.upsample = Upsample(width)
            built_levels.append(level)
        self.up = nn.ModuleList(built_levels[::-1])

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=width, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(width, config.out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent ``z`` into an image-space tensor with ``out_ch`` channels."""
        h = self.conv_in(z)
        for mid_stage in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = mid_stage(h)
        for i_level in range(self.num_resolutions - 1, -1, -1):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != 0:
                h = level.upsample(h)
        return self.conv_out(swish(self.norm_out(h)))
class AutoEncoder(nn.Module):
    """Frame-wise 2D VAE applied to 5D ``(b, c, t, h, w)`` tensors.

    Time is folded into the batch dimension, so every frame is encoded and
    decoded independently by the 2D :class:`Encoder` / :class:`Decoder`.
    Latents are normalized as ``scale_factor * (z - shift_factor)``.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.scale_factor = config.scale_factor
        self.shift_factor = config.shift_factor
        self.sample = config.sample

    def encode_(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution]:
        """Encode ``x`` and return ``(normalized latent, posterior)``.

        The latent is ``posterior.sample()`` when ``self.sample`` is True,
        otherwise ``posterior.mode()``, before normalization.
        """
        num_frames = x.shape[2]
        frames = rearrange(x, "b c t h w -> (b t) c h w")
        params = self.encoder(frames)
        params = rearrange(params, "(b t) c h w -> b c t h w", t=num_frames)
        posterior = DiagonalGaussianDistribution(params)
        z = posterior.sample() if self.sample else posterior.mode()
        z = self.scale_factor * (z - self.shift_factor)
        return z, posterior

    def encode(self, x: Tensor) -> Tensor:
        """Return only the normalized latent of :meth:`encode_`."""
        return self.encode_(x)[0]

    def decode(self, z: Tensor) -> Tensor:
        """Undo latent normalization and decode each frame of ``z``."""
        num_frames = z.shape[2]
        z = rearrange(z, "b c t h w -> (b t) c h w")
        z = z / self.scale_factor + self.shift_factor
        x = self.decoder(z)
        return rearrange(x, "(b t) c h w -> b c t h w", t=num_frames)

    def forward(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution, Tensor]:
        """Round-trip ``x``; return ``(reconstruction, posterior, latent)``."""
        # The former bare ``x.shape[2]`` statement here was dead code and is removed.
        z, posterior = self.encode_(x)
        x_rec = self.decode(z)
        return x_rec, posterior, z

    def get_last_layer(self):
        """Return the weight of the decoder's final convolution.

        NOTE(review): presumably consumed by a training loss (e.g. adaptive
        loss weighting) — confirm against callers.
        """
        return self.decoder.conv_out.weight
@MODELS.register_module("autoencoder_2d")
def AutoEncoderFlux(
    from_pretrained: str | None,
    cache_dir=None,
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,
    shift_factor=0.1159,
    device_map: str | torch.device = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
) -> AutoEncoder:
    """Build the Flux-style 2D :class:`AutoEncoder`, optionally loading weights.

    Args:
        from_pretrained: checkpoint passed to ``load_checkpoint``; when falsy
            (``None``/empty) the model keeps its random initialization.
        cache_dir: forwarded to ``load_checkpoint``.
        resolution, in_channels, ch, out_ch, ch_mult, num_res_blocks,
        z_channels, scale_factor, shift_factor: architecture/normalization
            settings copied into :class:`AutoEncoderConfig`. ``ch_mult`` may be
            any sequence of ints.
        device_map: device used as the default while constructing the model,
            also forwarded to ``load_checkpoint``.
        torch_dtype: dtype the model is cast to after construction.

    Returns:
        The constructed (and possibly checkpoint-loaded) model.
    """
    config = AutoEncoderConfig(
        from_pretrained=from_pretrained,
        cache_dir=cache_dir,
        resolution=resolution,
        in_channels=in_channels,
        ch=ch,
        out_ch=out_ch,
        # Copy into a fresh list: avoids a shared mutable default and keeps
        # config.ch_mult a list as the dataclass annotation declares.
        ch_mult=list(ch_mult),
        num_res_blocks=num_res_blocks,
        z_channels=z_channels,
        scale_factor=scale_factor,
        shift_factor=shift_factor,
    )
    # torch.device as a context manager makes it the default device for tensors
    # created inside, so parameters are allocated directly on device_map.
    with torch.device(device_map):
        model = AutoEncoder(config).to(torch_dtype)
    if from_pretrained:
        model = load_checkpoint(model, from_pretrained, cache_dir=cache_dir, device_map=device_map)
    return model