# Modified from Flux
#
# Copyright 2024 Black Forest Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn
from torch.nn.functional import silu as swish

from opensora.registry import MODELS
from opensora.utils.ckpt import load_checkpoint

from .utils import DiagonalGaussianDistribution


@dataclass
class AutoEncoderConfig:
    """Hyper-parameters for :class:`AutoEncoder`."""

    # Checkpoint to load after construction (falsy -> keep random init).
    from_pretrained: str | None
    # Forwarded to ``load_checkpoint`` as ``cache_dir``.
    cache_dir: str | None
    # Nominal input resolution; the decoder derives ``z_shape`` from it.
    resolution: int
    # Channels of the encoder input.
    in_channels: int
    # Base channel width; per-level widths are ``ch * ch_mult[level]``.
    ch: int
    # Channels of the decoder output.
    out_ch: int
    # Per-level channel multipliers; ``len(ch_mult) - 1`` downsampling steps.
    ch_mult: list[int]
    # Residual blocks per encoder level (the decoder uses one more per level).
    num_res_blocks: int
    # Latent channels (the encoder emits ``2 * z_channels``).
    z_channels: int
    # Latent normalization: z_norm = scale_factor * (z - shift_factor).
    scale_factor: float
    shift_factor: float
    # True -> ``posterior.sample()`` in encode; False -> ``posterior.mode()``.
    sample: bool = True


class AttnBlock(nn.Module):
    """Residual single-head self-attention over the spatial positions of a 2D map.

    ``forward(x)`` returns ``x + proj_out(attention(x))``.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        # GroupNorm with 32 groups: ``in_channels`` must be divisible by 32.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        # 1x1 convolutions act as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize ``h_``, attend across all ``h * w`` positions, restore ``(b, c, h, w)``."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        _, _, h, w = q.shape
        # Flatten the spatial grid into a sequence with a single head (dim of size 1).
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> SiLU -> 3x3 conv) x 2 plus a skip.

    When the channel count changes, the skip path goes through a 1x1 conv
    (``nin_shortcut``); otherwise it is the identity.
    """

    def __init__(self, in_channels: int, out_channels: int | None = None):
        """
        Args:
            in_channels: Input channels (must be divisible by 32 for GroupNorm).
            out_channels: Output channels; ``None`` keeps ``in_channels``.
        """
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    """Halve H and W (rounding down) with a stride-2 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        # No built-in padding: forward pads asymmetrically instead.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        # (left, right, top, bottom) = (0, 1, 0, 1): zero-pad only right and bottom.
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(x)


class Upsample(nn.Module):
    """Double H and W by nearest-neighbour interpolation, then apply a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)


class Encoder(nn.Module):
    """2D convolutional encoder mapping images to ``2 * z_channels`` feature maps.

    There are ``len(ch_mult)`` resolution levels, with a :class:`Downsample`
    after every level except the last, followed by a
    ResnetBlock -> AttnBlock -> ResnetBlock middle stage.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels

        # downsampling
        self.conv_in = nn.Conv2d(config.in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # Level i consumes ``ch * in_ch_mult[i]`` and produces ``ch * ch_mult[i]``.
        in_ch_mult = (1,) + tuple(config.ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # Kept (empty) so the forward pass and the state_dict layout match upstream.
            attn = nn.ModuleList()
            block_in = config.ch * in_ch_mult[i_level]
            block_out = config.ch * config.ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        # 2 * z_channels: presumably the mean / log-variance halves consumed by
        # DiagonalGaussianDistribution -- confirm against .utils.
        self.conv_out = nn.Conv2d(block_in, 2 * config.z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode a 4D batch ``(N, in_channels, H, W)`` into ``(N, 2 * z_channels, H', W')``."""
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """2D convolutional decoder mapping ``z_channels`` latents back to ``out_ch`` images.

    This mirrors :class:`Encoder`: a middle stage, then ``len(ch_mult)`` levels
    of ``num_res_blocks + 1`` residual blocks, with an :class:`Upsample` after
    every level except the last one processed.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        # Total spatial up-scaling factor of the decoder.
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = config.ch * config.ch_mult[self.num_resolutions - 1]
        curr_res = config.resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, config.z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(config.z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            # Kept (empty) so the forward pass and the state_dict layout match upstream.
            attn = nn.ModuleList()
            block_out = config.ch * config.ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, config.out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode a 4D latent ``(N, z_channels, h, w)`` into ``(N, out_ch, H, W)``."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)


class AutoEncoder(nn.Module):
    """Frame-wise 2D VAE applied to 5D ``(b, c, t, h, w)`` video tensors.

    Each of the ``t`` frames is encoded/decoded independently. Latents are
    normalized as ``scale_factor * (z - shift_factor)`` on encode, and that
    normalization is inverted on decode.
    """

    def __init__(self, config: AutoEncoderConfig):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.scale_factor = config.scale_factor
        self.shift_factor = config.shift_factor
        self.sample = config.sample

    def encode_(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution]:
        """Encode ``x`` and return ``(z, posterior)``.

        Args:
            x: Tensor laid out as ``(b, c, t, h, w)``.

        Returns:
            z: ``posterior.sample()`` if ``self.sample`` else ``posterior.mode()``,
                then normalized as ``scale_factor * (z - shift_factor)``;
                laid out as ``(b, c, t, h', w')``.
            posterior: DiagonalGaussianDistribution built from the encoder output.
        """
        num_frames = x.shape[2]
        # Fold time into batch so the 2D encoder processes individual frames.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        params = self.encoder(x)
        params = rearrange(params, "(b t) c h w -> b c t h w", t=num_frames)
        posterior = DiagonalGaussianDistribution(params)
        z = posterior.sample() if self.sample else posterior.mode()
        z = self.scale_factor * (z - self.shift_factor)
        return z, posterior

    def encode(self, x: Tensor) -> Tensor:
        """Return only the normalized latent from :meth:`encode_`."""
        return self.encode_(x)[0]

    def decode(self, z: Tensor) -> Tensor:
        """Undo latent normalization and decode ``(b, c, t, h, w)`` latents frame by frame."""
        num_frames = z.shape[2]
        z = rearrange(z, "b c t h w -> (b t) c h w")
        # Inverse of the normalization applied in encode_.
        z = z / self.scale_factor + self.shift_factor
        x = self.decoder(z)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_frames)
        return x

    def forward(self, x: Tensor) -> tuple[Tensor, DiagonalGaussianDistribution, Tensor]:
        """Round-trip ``x``; return ``(reconstruction, posterior, z)``."""
        z, posterior = self.encode_(x)
        x_rec = self.decode(z)
        return x_rec, posterior, z

    def get_last_layer(self):
        """Return the weight of the decoder's final conv."""
        return self.decoder.conv_out.weight


@MODELS.register_module("autoencoder_2d")
def AutoEncoderFlux(
    from_pretrained: str | None,
    cache_dir=None,
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult: list[int] | None = None,
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,
    shift_factor=0.1159,
    device_map: str | torch.device = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    sample: bool = True,
) -> AutoEncoder:
    """Build an :class:`AutoEncoder` and optionally load pretrained weights.

    Args:
        from_pretrained: Checkpoint passed to ``load_checkpoint``; falsy skips loading.
        cache_dir: Forwarded to ``load_checkpoint``.
        resolution, in_channels, ch, out_ch, num_res_blocks, z_channels:
            See :class:`AutoEncoderConfig`.
        ch_mult: Per-level channel multipliers; ``None`` means ``[1, 2, 4, 4]``.
        scale_factor, shift_factor: Latent normalization constants.
        device_map: Used as a ``torch.device`` context while building modules,
            and forwarded to ``load_checkpoint``.
        torch_dtype: dtype the model is cast to after construction.
        sample: If False, encoding returns the posterior mode instead of a sample.

    Returns:
        The model returned by ``load_checkpoint`` when loading, else the freshly
        built model.
    """
    # Fresh list per call: avoids a mutable default shared across calls and configs.
    if ch_mult is None:
        ch_mult = [1, 2, 4, 4]
    config = AutoEncoderConfig(
        from_pretrained=from_pretrained,
        cache_dir=cache_dir,
        resolution=resolution,
        in_channels=in_channels,
        ch=ch,
        out_ch=out_ch,
        ch_mult=ch_mult,
        num_res_blocks=num_res_blocks,
        z_channels=z_channels,
        scale_factor=scale_factor,
        shift_factor=shift_factor,
        sample=sample,
    )
    with torch.device(device_map):
        model = AutoEncoder(config).to(torch_dtype)
    if from_pretrained:
        model = load_checkpoint(model, from_pretrained, cache_dir=cache_dir, device_map=device_map)
    return model