# Modified from Flux
#
# Copyright 2024 Black Forest Labs

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import ClassVar

import torch
from torch import Tensor, nn

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.mmdit.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    LigerEmbedND,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)
from opensora.registry import MODELS
from opensora.utils.ckpt import load_checkpoint


@dataclass
class MMDiTConfig:
    # Class-level tag shared by all configs; ClassVar keeps it out of the
    # generated __init__.
    model_type: ClassVar[str] = "MMDiT"

    from_pretrained: str
    cache_dir: str
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool
    cond_embed: bool = False
    fused_qkv: bool = True
    grad_ckpt_settings: tuple[int, int] | None = None
    use_liger_rope: bool = False
    patch_size: int = 2

    def get(self, attribute_name, default=None):
        """Dict-style lookup so the config can be queried like a mapping."""
        return getattr(self, attribute_name, default)

    def __contains__(self, attribute_name):
        return hasattr(self, attribute_name)

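
# For reference, a config sketch with flux-dev-like hyperparameters (an
# assumption based on the published Flux model, not on this repo's configs);
# note that sum(axes_dim) must equal hidden_size // num_heads (128 here):
#
#     example_config = MMDiTConfig(
#         from_pretrained="", cache_dir="",
#         in_channels=64, vec_in_dim=768, context_in_dim=4096,
#         hidden_size=3072, mlp_ratio=4.0, num_heads=24,
#         depth=19, depth_single_blocks=38,
#         axes_dim=[16, 56, 56], theta=10_000,
#         qkv_bias=True, guidance_embed=True,
#     )
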
class MMDiTModel(nn.Module):
    config_class = MMDiTConfig

    def __init__(self, config: MMDiTConfig):
        super().__init__()

        self.config = config
        self.in_channels = config.in_channels
        self.out_channels = self.in_channels
        self.patch_size = config.patch_size

        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} must be divisible by num_heads {config.num_heads}"
            )

        # RoPE dims are split across the positional axes, so the per-head dim
        # must equal sum(axes_dim), e.g. 3072 // 24 = 128 = 16 + 56 + 56.
        pe_dim = config.hidden_size // config.num_heads
        if sum(config.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {config.axes_dim} but expected positional dim {pe_dim}"
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        pe_embedder_cls = LigerEmbedND if config.use_liger_rope else EmbedND
        self.pe_embedder = pe_embedder_cls(
            dim=pe_dim, theta=config.theta, axes_dim=config.axes_dim
        )

        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(config.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if config.guidance_embed
            else nn.Identity()
        )
        self.cond_in = (
            nn.Linear(
                self.in_channels + self.patch_size**2, self.hidden_size, bias=True
            )
            if config.cond_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(config.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    qkv_bias=config.qkv_bias,
                    fused_qkv=config.fused_qkv,
                )
                for _ in range(config.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    fused_qkv=config.fused_qkv,
                )
                for _ in range(config.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.initialize_weights()

        # Bind the forward variant once at construction: selective
        # checkpointing when grad_ckpt_settings is given, otherwise per-block
        # checkpointing via auto_grad_checkpoint.
        if self.config.grad_ckpt_settings:
            self.forward = self.forward_selective_ckpt
        else:
            self.forward = self.forward_ckpt
        self._input_requires_grad = False

    def initialize_weights(self):
        if self.config.cond_embed:
            # Zero-init so the conditional branch starts as a no-op and
            # training can grow its contribution from zero.
            nn.init.zeros_(self.cond_in.weight)
            nn.init.zeros_(self.cond_in.bias)

    def prepare_block_inputs(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,  # T5-encoded text context
        txt_ids: Tensor,
        timesteps: Tensor,
        y_vec: Tensor,  # CLIP-encoded pooled vector
        cond: Tensor | None = None,
        guidance: Tensor | None = None,
    ):
        """Project the raw inputs into the transformer's token space.

        Returns:
            img: projected noisy image latent tokens,
            txt: projected text context (from T5),
            vec: modulation vector combining timestep, guidance, and CLIP embeddings,
            pe: positional embeddings for the concatenated [txt, img] sequence.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # project the image token sequence
        img = self.img_in(img)
        if self.config.cond_embed:
            if cond is None:
                raise ValueError("Didn't get conditional input for conditional model.")
            img = img + self.cond_in(cond)

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.config.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y_vec)

        txt = self.txt_in(txt)

        # positional ids for the concatenated [txt, img] sequence
        # (concat: 4096 + t*h*2/4)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        if self._input_requires_grad:
            # LoRA is applied only to the double/single blocks, so enabling
            # grad on their inputs is sufficient for PEFT training.
            img.requires_grad_()
            txt.requires_grad_()

        return img, txt, vec, pe
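
    # Expected shapes (inferred from the projections above, for B batch items,
    # L_img image tokens, and L_txt text tokens; the 3 rope axes are an
    # assumption matching an axes_dim of length 3):
    #   img:   (B, L_img, in_channels)      img_ids:   (B, L_img, 3)
    #   txt:   (B, L_txt, context_in_dim)   txt_ids:   (B, L_txt, 3)
    #   y_vec: (B, vec_in_dim)              timesteps: (B,)
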
    def enable_input_require_grads(self):
        """Enable gradients on block inputs for PEFT/LoRA fine-tuning.

        Called by the PEFT integration; not meant to be invoked manually.
        """
        self._input_requires_grad = True

    def forward_ckpt(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y_vec: Tensor,
        cond: Tensor | None = None,
        guidance: Tensor | None = None,
        **kwargs,
    ) -> Tensor:
        img, txt, vec, pe = self.prepare_block_inputs(
            img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance
        )

        for block in self.double_blocks:
            img, txt = auto_grad_checkpoint(block, img, txt, vec, pe)

        # single-stream blocks run on the concatenated [txt, img] sequence
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = auto_grad_checkpoint(block, img, vec, pe)
        # drop the text tokens, keeping only the image part of the sequence
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward_selective_ckpt(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y_vec: Tensor,
        cond: Tensor | None = None,
        guidance: Tensor | None = None,
        **kwargs,
    ) -> Tensor:
        img, txt, vec, pe = self.prepare_block_inputs(
            img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance
        )

        # checkpoint only the first grad_ckpt_settings[0] double-stream blocks
        ckpt_depth_double = self.config.grad_ckpt_settings[0]
        for block in self.double_blocks[:ckpt_depth_double]:
            img, txt = auto_grad_checkpoint(block, img, txt, vec, pe)
        for block in self.double_blocks[ckpt_depth_double:]:
            img, txt = block(img, txt, vec, pe)

        # checkpoint only the first grad_ckpt_settings[1] single-stream blocks
        ckpt_depth_single = self.config.grad_ckpt_settings[1]
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks[:ckpt_depth_single]:
            img = auto_grad_checkpoint(block, img, vec, pe)
        for block in self.single_blocks[ckpt_depth_single:]:
            img = block(img, vec, pe)

        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

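
# A usage note (inferred from forward_selective_ckpt above, not from repo
# docs): grad_ckpt_settings=(d, s) recomputes activations for the first d
# double-stream and first s single-stream blocks and keeps full activations
# for the rest, trading extra compute for memory on only part of the network.
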
@MODELS.register_module("flux")
def Flux(
    cache_dir: str | None = None,
    from_pretrained: str | None = None,
    device_map: str | torch.device = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    strict_load: bool = False,
    **kwargs,
) -> MMDiTModel:
    config = MMDiTConfig(
        from_pretrained=from_pretrained,
        cache_dir=cache_dir,
        **kwargs,
    )
    # When loading pretrained weights, build the model directly in the target
    # dtype to avoid materializing a full-precision copy first.
    low_precision_init = from_pretrained is not None and len(from_pretrained) > 0
    if low_precision_init:
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch_dtype)
    with torch.device(device_map):
        model = MMDiTModel(config)
    if low_precision_init:
        torch.set_default_dtype(default_dtype)
    else:
        model = model.to(torch_dtype)
    if from_pretrained:
        model = load_checkpoint(
            model,
            from_pretrained,
            cache_dir=cache_dir,
            device_map=device_map,
            strict=strict_load,
        )
    return model
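
# A minimal usage sketch (hypothetical values, not from this repo's tests;
# the hyperparameters assume the published flux-dev sizes):
#
#     model = Flux(
#         in_channels=64, vec_in_dim=768, context_in_dim=4096,
#         hidden_size=3072, mlp_ratio=4.0, num_heads=24,
#         depth=19, depth_single_blocks=38,
#         axes_dim=[16, 56, 56], theta=10_000,
#         qkv_bias=True, guidance_embed=False,
#     )
#     img = torch.randn(1, 1024, 64, device="cuda", dtype=torch.bfloat16)
#     img_ids = torch.zeros(1, 1024, 3, device="cuda")
#     txt = torch.randn(1, 512, 4096, device="cuda", dtype=torch.bfloat16)
#     txt_ids = torch.zeros(1, 512, 3, device="cuda")
#     t = torch.rand(1, device="cuda")
#     y = torch.randn(1, 768, device="cuda", dtype=torch.bfloat16)
#     out = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
#                 timesteps=t, y_vec=y)  # -> (1, 1024, 64)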