# mysora/opensora/utils/inference.py
# Inference utilities for Open-Sora (originally 352 lines, 12 KiB, Python).
import copy
import csv
import os
import re
from enum import Enum

import torch
from torch import nn

from opensora.datasets import save_sample
from opensora.datasets.aspect import get_image_size
from opensora.datasets.utils import read_from_path, rescale_image_by_path
from opensora.utils.logger import log_message
from opensora.utils.prompt_refine import refine_prompts
class SamplingMethod(Enum):
    """Supported sampling back-ends."""

    # Open-Sora video generation (image-to-video conditioning).
    I2V = "i2v"
    # Distilled Flux model for fast image generation.
    DISTILLED = "distill"
def create_tmp_csv(save_dir: str, prompt: str, ref: str = None, create=True) -> str:
    """
    Create a temporary CSV file with the prompt text.

    Args:
        save_dir (str): The directory where the CSV file will be saved.
        prompt (str): The prompt text.
        ref (str, optional): Reference media path; when given, a "ref" column
            is written alongside "text".
        create (bool): When False, no file is written and only the path of the
            would-be CSV is returned.

    Returns:
        str: The path to the temporary CSV file.
    """
    tmp_file = os.path.join(save_dir, "prompt.csv")
    if not create:
        return tmp_file
    # Use the csv module so prompts containing double quotes, commas, or
    # newlines are escaped correctly — the previous manual f-string
    # formatting produced a corrupt CSV for any prompt with an embedded '"'.
    with open(tmp_file, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        if ref is not None:
            writer.writerow(["text", "ref"])
            writer.writerow([prompt, ref])
        else:
            writer.writerow(["text"])
            writer.writerow([prompt])
    return tmp_file
def modify_option_to_t2i(sampling_option, distilled: bool = False, img_resolution: str = "1080px"):
    """
    Derive a text-to-image sampling option from a video sampling option.

    A shallow copy is returned; the input option is left untouched.
    """
    opt = copy.copy(sampling_option)
    if distilled:
        opt.method = SamplingMethod.DISTILLED
    # a single frame turns the run into image generation
    opt.num_frames = 1
    opt.height, opt.width = get_image_size(img_resolution, sampling_option.aspect_ratio)
    opt.guidance = 4.0
    # remember the original target resolution so the image can be rescaled later
    opt.resized_resolution = sampling_option.resolution
    return opt
def get_save_path_name(
    save_dir,
    sub_dir,
    save_prefix="",
    name=None,
    fallback_name=None,
    index=None,
    num_sample_pos=None,  # idx for prompt as path
    prompt_as_path=False,  # save sample with same name as prompt
    prompt=None,
):
    """
    Build the (extension-less) save path for a generated sample.

    With prompt_as_path the trimmed prompt itself becomes the file name
    (vbench convention); otherwise the explicit name — or a fallback name
    plus zero-padded index — is used, with the sample position appended
    when it is non-zero.
    """
    if prompt_as_path:  # for vbench
        fname = f"{prompt.strip('.')}-{num_sample_pos}"
    else:
        if name is None:
            fname = f"{save_prefix + fallback_name}_{index:04d}"
        else:
            fname = save_prefix + name
        if num_sample_pos > 0:
            fname += f"_{num_sample_pos}"
    return os.path.join(save_dir, sub_dir, fname)
def get_names_from_path(path):
    """
    Get the base file name of a path, without directory or extension.

    Args:
        path (str): The path to the file.

    Returns:
        str: The file name with directory and extension stripped.
    """
    # NOTE: the previous docstring claimed a (filename, extension) tuple was
    # returned; the function has only ever returned the bare name.
    filename = os.path.basename(path)
    name, _ = os.path.splitext(filename)
    return name
def process_and_save(
    x: torch.Tensor,
    batch: dict,
    cfg: dict,
    sub_dir: str,
    generate_sampling_option,
    epoch: int,
    start_index: int,
    saving: bool = True,
):
    """
    Process the generated samples and save them to disk.

    Args:
        x (torch.Tensor): batch of generated samples, iterated item by item.
        batch (dict): dataloader batch; must contain "text", may contain
            "name" and "index" lists (one entry per sample).
        cfg (dict): run config with attribute access (cfg.save_dir,
            cfg.dataset.data_path) and dict-style `get` for optional keys.
        sub_dir (str): sub-directory under cfg.save_dir for this run.
        generate_sampling_option: sampling options; num_frames == 1 marks the
            samples as images, and resolution / resized_resolution drive the
            optional post-save rescaling for the t2i2v pipeline.
        epoch (int): current sampling round, appended to file names so
            repeated rounds do not overwrite each other.
        start_index (int): offset added to the per-batch sample indices.
        saving (bool): when False, nothing is written to disk; only the names
            of the would-be files are computed and returned.

    Returns:
        list[str]: the extension-less file names, one per sample.
    """
    # stem of the dataset file, used when the batch carries no sample names
    fallback_name = cfg.dataset.data_path.split("/")[-1].split(".")[0]
    prompt_as_path = cfg.get("prompt_as_path", False)
    fps_save = cfg.get("fps_save", 16)
    save_dir = cfg.save_dir
    names = batch["name"] if "name" in batch else [None] * len(x)
    indices = batch["index"] if "index" in batch else [None] * len(x)
    if "index" in batch:
        indices = [idx + start_index for idx in indices]
    prompts = batch["text"]
    ret_names = []
    is_image = generate_sampling_option.num_frames == 1
    for img, name, index, prompt in zip(x, names, indices, prompts):
        # == get save path ==
        save_path = get_save_path_name(
            save_dir,
            sub_dir,
            save_prefix=cfg.get("save_prefix", ""),
            name=name,
            fallback_name=fallback_name,
            index=index,
            num_sample_pos=epoch,
            prompt_as_path=prompt_as_path,
            prompt=prompt,
        )
        ret_name = get_names_from_path(save_path)
        ret_names.append(ret_name)
        if saving:
            # == write txt to disk ==
            with open(save_path + ".txt", "w", encoding="utf-8") as f:
                f.write(prompt)
            # == save samples ==
            save_sample(img, save_path=save_path, fps=fps_save)
            # == resize image for t2i2v ==
            if (
                cfg.get("use_t2i2v", False)
                and is_image
                and generate_sampling_option.resolution != generate_sampling_option.resized_resolution
            ):
                log_message("Rescaling image to %s...", generate_sampling_option.resized_resolution)
                height, width = get_image_size(
                    generate_sampling_option.resized_resolution, generate_sampling_option.aspect_ratio
                )
                # assumes save_sample wrote a .png for single-frame samples — TODO confirm
                rescale_image_by_path(save_path + ".png", width, height)
    return ret_names
def check_fps_added(sentence):
    """Return True if the sentence already ends with an FPS tag like "16 FPS."."""
    return re.search(r"\d+ FPS\.$", sentence) is not None
def ensure_sentence_ends_with_period(sentence: str):
    """Trim surrounding whitespace and guarantee a trailing period."""
    trimmed = sentence.strip()
    return trimmed if trimmed.endswith(".") else trimmed + "."
def add_fps_info_to_text(text: list[str], fps: int = 16):
    """
    Append an "N FPS." suffix to every prompt that does not already carry one.

    Each prompt is first normalized to end with a period.
    """
    result = []
    for sentence in text:
        sentence = ensure_sentence_ends_with_period(sentence)
        if check_fps_added(sentence):
            result.append(sentence)
        else:
            result.append(f"{sentence} {fps} FPS.")
    return result
def add_motion_score_to_text(text, motion_score: int | str):
"""
Add the motion score to the text.
"""
if motion_score == "dynamic":
ms = refine_prompts(text, type="motion_score")
return [f"{t} {ms[i]}." for i, t in enumerate(text)]
else:
return [f"{t} {motion_score} motion score." for t in text]
def add_noise_to_ref(masked_ref: torch.Tensor, masks: torch.Tensor, t: float, sigma_min: float = 1e-5):
    """
    Interpolate the masked reference latents toward Gaussian noise.

    At t=0 the (mask-weighted) reference is returned unchanged; at t=1 the
    output is almost pure noise. Positions where masks is zero are zeroed out.
    """
    noise = torch.randn_like(masked_ref)
    interpolated = masked_ref * (1 - (1 - sigma_min) * t) + noise * t
    return masks * interpolated
def collect_references_batch(
    reference_paths: list[str],
    cond_type: str,
    model_ae: nn.Module,
    image_size: tuple[int, int],
    is_causal=False,
):
    """
    Load and VAE-encode the reference media for every sample of a batch.

    Args:
        reference_paths: one string per batch item; "" means no reference,
            and multiple reference paths are separated by ";".
        cond_type: condition type — a "v2v*" variant (optionally containing
            "easy", "head" or "tail"), "i2v_head", "i2v_tail" or "i2v_loop".
        model_ae: autoencoder whose `encode` maps pixel tensors to latents;
            its first parameter determines the device/dtype used.
        image_size: spatial target size, passed straight to read_from_path
            with the "resize_crop" transform.
        is_causal: whether the autoencoder is causal; requires one extra
            reference frame for v2v conditioning.

    Returns:
        list: per-batch-item list of encoded references, each of shape
        [C, T, H, W] after encoding, or None for items with no reference.
    """
    refs_x = []  # refs_x: [batch, ref_num, C, T, H, W]
    device = next(model_ae.parameters()).device
    dtype = next(model_ae.parameters()).dtype
    for reference_path in reference_paths:
        if reference_path == "":
            # empty string marks "no reference" for this batch item
            refs_x.append(None)
            continue
        ref_path = reference_path.split(";")
        ref = []
        if "v2v" in cond_type:
            r = read_from_path(ref_path[0], image_size, transform_name="resize_crop")  # size [C, T, H, W]
            actual_t = r.size(1)
            target_t = (
                64 if (actual_t >= 64 and "easy" in cond_type) else 32
            )  # if reference not long enough, default to shorter ref
            if is_causal:
                target_t += 1
            assert actual_t >= target_t, f"need at least {target_t} reference frames for v2v generation"
            if "head" in cond_type:  # v2v head: condition on the leading frames
                r = r[:, :target_t]
            elif "tail" in cond_type:  # v2v tail: condition on the trailing frames
                r = r[:, -target_t:]
            else:
                raise NotImplementedError
            r_x = model_ae.encode(r.unsqueeze(0).to(device, dtype))
            r_x = r_x.squeeze(0)  # size [C, T, H, W]
            ref.append(r_x)
        elif cond_type == "i2v_head":  # take the 1st frame from first ref_path
            r = read_from_path(ref_path[0], image_size, transform_name="resize_crop")  # size [C, T, H, W]
            r = r[:, :1]
            r_x = model_ae.encode(r.unsqueeze(0).to(device, dtype))
            r_x = r_x.squeeze(0)  # size [C, T, H, W]
            ref.append(r_x)
        elif cond_type == "i2v_tail":  # take the last frame from last ref_path
            r = read_from_path(ref_path[-1], image_size, transform_name="resize_crop")  # size [C, T, H, W]
            r = r[:, -1:]
            r_x = model_ae.encode(r.unsqueeze(0).to(device, dtype))
            r_x = r_x.squeeze(0)  # size [C, T, H, W]
            ref.append(r_x)
        elif cond_type == "i2v_loop":
            # loop condition needs two references: first frame of the first
            # path and last frame of the last path
            # first frame
            r_head = read_from_path(ref_path[0], image_size, transform_name="resize_crop")  # size [C, T, H, W]
            r_head = r_head[:, :1]
            r_x_head = model_ae.encode(r_head.unsqueeze(0).to(device, dtype))
            r_x_head = r_x_head.squeeze(0)  # size [C, T, H, W]
            ref.append(r_x_head)
            # last frame
            r_tail = read_from_path(ref_path[-1], image_size, transform_name="resize_crop")  # size [C, T, H, W]
            r_tail = r_tail[:, -1:]
            r_x_tail = model_ae.encode(r_tail.unsqueeze(0).to(device, dtype))
            r_x_tail = r_x_tail.squeeze(0)  # size [C, T, H, W]
            ref.append(r_x_tail)
        else:
            raise NotImplementedError(f"Unknown condition type {cond_type}")
        refs_x.append(ref)
    return refs_x
def prepare_inference_condition(
    z: torch.Tensor,
    mask_cond: str,
    ref_list: list[list[torch.Tensor]] = None,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare the visual condition for the model, using causal vae.

    Args:
        z (torch.Tensor): The latent noise tensor, of shape [B, C, T, H, W]
        mask_cond (str): The condition type: "t2v", "i2v_head", "i2v_tail",
            "i2v_loop", or one of the "v2v_head"/"v2v_tail"(_easy) variants.
        ref_list: list of lists of media (image/video) for i2v and v2v
            condition, of shape [C, T', H, W]; len(ref_list)==B; ref_list[i]
            is the list of media for the generation in batch idx i, we use a
            list of media for each batch item so that it can have multiple
            references. For example, ref_list[i] could be
            [ref_image_1, ref_image_2] for i2v_loop condition.
        causal (bool): whether the VAE is causal; adds one latent frame to
            the v2v conditioning span.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (masks, masked_z) — the binary
        conditioning mask of shape [B, 1, T, H, W] and the masked latent
        condition of shape [B, C, T, H, W], both moved to z's device/dtype.
        (The original annotation claimed a single tensor; two are returned.)
    """
    # x has shape [b, c, t, h, w], where b is the batch size
    B, C, T, H, W = z.shape
    masks = torch.zeros(B, 1, T, H, W)
    masked_z = torch.zeros(B, C, T, H, W)
    if ref_list is None:
        # NOTE(review): with ref_list=None and B > 0, the loop below indexes
        # None and raises — callers appear to always pass a list; confirm.
        assert mask_cond == "t2v", f"reference is required for {mask_cond}"
    for i in range(B):
        ref = ref_list[i]
        # warning message
        if ref is None and mask_cond != "t2v":
            print("no reference found. will default to cond_type t2v!")
        if ref is not None and T > 1:  # video
            # Apply the selected mask condition directly on the masks tensor
            if mask_cond == "i2v_head":  # equivalent to masking the first timestep
                masks[i, :, 0, :, :] = 1
                masked_z[i, :, 0, :, :] = ref[0][:, 0, :, :]
            elif mask_cond == "i2v_tail":  # mask the last timestep
                masks[i, :, -1, :, :] = 1
                masked_z[i, :, -1, :, :] = ref[-1][:, -1, :, :]
            elif mask_cond == "v2v_head":
                # 8 conditioning latent frames (+1 for the causal VAE's extra frame)
                k = 8 + int(causal)
                masks[i, :, :k, :, :] = 1
                masked_z[i, :, :k, :, :] = ref[0][:, :k, :, :]
            elif mask_cond == "v2v_tail":
                k = 8 + int(causal)
                masks[i, :, -k:, :, :] = 1
                masked_z[i, :, -k:, :, :] = ref[0][:, -k:, :, :]
            elif mask_cond == "v2v_head_easy":
                # "easy" variants condition on twice as many frames
                k = 16 + int(causal)
                masks[i, :, :k, :, :] = 1
                masked_z[i, :, :k, :, :] = ref[0][:, :k, :, :]
            elif mask_cond == "v2v_tail_easy":
                k = 16 + int(causal)
                masks[i, :, -k:, :, :] = 1
                masked_z[i, :, -k:, :, :] = ref[0][:, -k:, :, :]
            elif mask_cond == "i2v_loop":  # mask first and last timesteps
                masks[i, :, 0, :, :] = 1
                masks[i, :, -1, :, :] = 1
                masked_z[i, :, 0, :, :] = ref[0][:, 0, :, :]
                masked_z[i, :, -1, :, :] = ref[-1][:, -1, :, :]  # last frame of last referenced content
            else:
                # "t2v" is the fallback case where no specific condition is specified
                assert mask_cond == "t2v", f"Unknown mask condition {mask_cond}"
    masks = masks.to(z.device, z.dtype)
    masked_z = masked_z.to(z.device, z.dtype)
    return masks, masked_z