# mysora/opensora/models/mmdit/model.py
# Modified from Flux
#
# Copyright 2024 Black Forest Labs
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.mmdit.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    LigerEmbedND,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)
from opensora.registry import MODELS
from opensora.utils.ckpt import load_checkpoint


@dataclass
class MMDiTConfig:
    model_type = "MMDiT"

    from_pretrained: str
    cache_dir: str
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool
    cond_embed: bool = False
    fused_qkv: bool = True
    grad_ckpt_settings: tuple[int, int] | None = None
    use_liger_rope: bool = False
    patch_size: int = 2

    # dict-style accessors so the config can stand in for a mapping
    def get(self, attribute_name, default=None):
        return getattr(self, attribute_name, default)

    def __contains__(self, attribute_name):
        return hasattr(self, attribute_name)
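
# For orientation only: a config with the publicly documented Flux-dev shape
# hyperparameters (illustrative values, not read from this repository) would
# satisfy the divisibility checks below, since sum(axes_dim) == 3072 // 24:
#
#     config = MMDiTConfig(
#         from_pretrained="", cache_dir="",
#         in_channels=64, vec_in_dim=768, context_in_dim=4096,
#         hidden_size=3072, mlp_ratio=4.0, num_heads=24,
#         depth=19, depth_single_blocks=38,
#         axes_dim=[16, 56, 56], theta=10_000,
#         qkv_bias=True, guidance_embed=True,
#     )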


class MMDiTModel(nn.Module):
    config_class = MMDiTConfig

    def __init__(self, config: MMDiTConfig):
        super().__init__()
        self.config = config
        self.in_channels = config.in_channels
        self.out_channels = self.in_channels
        self.patch_size = config.patch_size
        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} must be divisible by num_heads {config.num_heads}"
            )
        pe_dim = config.hidden_size // config.num_heads  # per-head dim carried by RoPE
        if sum(config.axes_dim) != pe_dim:
            raise ValueError(f"Got {config.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads

        pe_embedder_cls = LigerEmbedND if config.use_liger_rope else EmbedND
        self.pe_embedder = pe_embedder_cls(dim=pe_dim, theta=config.theta, axes_dim=config.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(config.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if config.guidance_embed
            else nn.Identity()
        )
        self.cond_in = (
            nn.Linear(self.in_channels + self.patch_size**2, self.hidden_size, bias=True)
            if config.cond_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(config.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    qkv_bias=config.qkv_bias,
                    fused_qkv=config.fused_qkv,
                )
                for _ in range(config.depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    fused_qkv=config.fused_qkv,
                )
                for _ in range(config.depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        self.initialize_weights()

        # Bind the forward variant up front: selective checkpointing when
        # grad_ckpt_settings is given, otherwise checkpoint every block.
        if self.config.grad_ckpt_settings:
            self.forward = self.forward_selective_ckpt
        else:
            self.forward = self.forward_ckpt
        self._input_requires_grad = False

    def initialize_weights(self):
        # Zero-init the conditioning projection so an untrained conditional
        # branch starts out as a no-op added to the image embedding.
        if self.config.cond_embed:
            nn.init.zeros_(self.cond_in.weight)
            nn.init.zeros_(self.cond_in.bias)

    def prepare_block_inputs(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,  # T5-encoded text context
        txt_ids: Tensor,
        timesteps: Tensor,
        y_vec: Tensor,  # CLIP-encoded pooled vector
        cond: Tensor | None = None,
        guidance: Tensor | None = None,
    ):
        """
        Returns the processed inputs shared by both forward variants:
            img: projected noisy image latent,
            txt: projected text context (from T5),
            vec: the modulation vector (timestep + optional guidance + CLIP),
            pe: positional embeddings for the concatenated txt and img tokens.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # project the image token sequence into the hidden size
        img = self.img_in(img)
        if self.config.cond_embed:
            if cond is None:
                raise ValueError("Didn't get conditional input for conditional model.")
            img = img + self.cond_in(cond)

        # build the modulation vector from timestep, guidance, and pooled text
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.config.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y_vec)

        txt = self.txt_in(txt)

        # positional ids for the joint sequence: text tokens first, then image tokens
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        if self._input_requires_grad:
            # LoRA is only applied to the double/single blocks, so enabling
            # grad on these inputs is sufficient for PEFT training.
            img.requires_grad_()
            txt.requires_grad_()

        return img, txt, vec, pe

    def enable_input_require_grads(self):
        """Called by PEFT when fitting LoRA adapters; not meant to be called manually."""
        self._input_requires_grad = True

    def forward_ckpt(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y_vec: Tensor,
        cond: Tensor | None = None,
        guidance: Tensor | None = None,
        **kwargs,
    ) -> Tensor:
        img, txt, vec, pe = self.prepare_block_inputs(
            img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance
        )

        # double-stream phase: separate img/txt streams with joint attention
        for block in self.double_blocks:
            img, txt = auto_grad_checkpoint(block, img, txt, vec, pe)

        # single-stream phase: one joint [txt | img] sequence
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = auto_grad_checkpoint(block, img, vec, pe)
        img = img[:, txt.shape[1]:, ...]  # drop the text tokens again

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
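
    # Note on sequence layout: the single-stream phase operates on [txt | img]
    # along dim=1, and the slice `img[:, txt.shape[1]:]` drops the text tokens
    # afterwards. For example, 512 text tokens and 1024 image tokens give a
    # joint sequence of length 1536, of which the last 1024 positions are kept.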

    def forward_selective_ckpt(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y_vec: Tensor,
        cond: Tensor | None = None,
        guidance: Tensor | None = None,
        **kwargs,
    ) -> Tensor:
        img, txt, vec, pe = self.prepare_block_inputs(
            img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance
        )

        # checkpoint only the first N double blocks, run the rest normally
        ckpt_depth_double = self.config.grad_ckpt_settings[0]
        for block in self.double_blocks[:ckpt_depth_double]:
            img, txt = auto_grad_checkpoint(block, img, txt, vec, pe)
        for block in self.double_blocks[ckpt_depth_double:]:
            img, txt = block(img, txt, vec, pe)

        # same pattern for the single-stream blocks
        ckpt_depth_single = self.config.grad_ckpt_settings[1]
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks[:ckpt_depth_single]:
            img = auto_grad_checkpoint(block, img, vec, pe)
        for block in self.single_blocks[ckpt_depth_single:]:
            img = block(img, vec, pe)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
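
    # grad_ckpt_settings = (d, s) checkpoints the first d double-stream and the
    # first s single-stream blocks and runs the remainder without activation
    # checkpointing; e.g. (2, 4) would checkpoint double_blocks[0:2] and
    # single_blocks[0:4].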


@MODELS.register_module("flux")
def Flux(
    cache_dir: str | None = None,
    from_pretrained: str | None = None,
    device_map: str | torch.device = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    strict_load: bool = False,
    **kwargs,
) -> MMDiTModel:
    config = MMDiTConfig(
        from_pretrained=from_pretrained,
        cache_dir=cache_dir,
        **kwargs,
    )

    # When a checkpoint will be loaded anyway, build the model directly in the
    # target dtype to avoid a throwaway full-precision initialization.
    low_precision_init = from_pretrained is not None and len(from_pretrained) > 0
    if low_precision_init:
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch_dtype)

    with torch.device(device_map):
        model = MMDiTModel(config)

    if low_precision_init:
        torch.set_default_dtype(default_dtype)
    else:
        model = model.to(torch_dtype)

    if from_pretrained:
        model = load_checkpoint(
            model,
            from_pretrained,
            cache_dir=cache_dir,
            device_map=device_map,
            strict=strict_load,
        )
    return model
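

# Minimal smoke-test sketch (assumed token counts, config values, and id
# layout; the real values come from the surrounding OpenSora pipeline, so
# treat this as illustrative rather than a documented entry point):
#
#     model = Flux(
#         in_channels=64, vec_in_dim=768, context_in_dim=4096,
#         hidden_size=3072, mlp_ratio=4.0, num_heads=24,
#         depth=19, depth_single_blocks=38,
#         axes_dim=[16, 56, 56], theta=10_000,
#         qkv_bias=True, guidance_embed=False, cond_embed=False,
#     )
#     img = torch.randn(1, 1024, 64, device="cuda", dtype=torch.bfloat16)
#     img_ids = torch.zeros(1, 1024, 3, device="cuda")
#     txt = torch.randn(1, 512, 4096, device="cuda", dtype=torch.bfloat16)
#     txt_ids = torch.zeros(1, 512, 3, device="cuda")
#     t = torch.ones(1, device="cuda")
#     y = torch.randn(1, 768, device="cuda", dtype=torch.bfloat16)
#     out = model(img, img_ids, txt, txt_ids, timesteps=t, y_vec=y)
#     assert out.shape == (1, 1024, 64)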