# mysora/scripts/diffusion/inference.py
#
# 246 lines
# 9.3 KiB
# Python

import os
import time
import warnings
from pprint import pformat
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
import torch
import torch.distributed as dist
from colossalai.utils import set_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, build_module
from opensora.utils.cai import (
get_booster,
get_is_saving_process,
init_inference_environment,
)
from opensora.utils.config import parse_alias, parse_configs
from opensora.utils.inference import (
add_fps_info_to_text,
add_motion_score_to_text,
create_tmp_csv,
modify_option_to_t2i,
process_and_save,
)
from opensora.utils.logger import create_logger, is_main_process
from opensora.utils.misc import log_cuda_max_memory, to_torch_dtype
from opensora.utils.prompt_refine import refine_prompts
from opensora.utils.sampling import (
SamplingOption,
prepare_api,
prepare_models,
sanitize_sampling_option,
)
def _epoch_seed(base_seed, epoch):
    """Return the sampling seed to use for ``epoch``.

    Args:
        base_seed: Seed taken from the sampling option, or ``None`` when no
            seed is configured.
        epoch: Zero-based index of the current sampling pass.

    Returns:
        ``base_seed + epoch`` when a seed is configured, otherwise ``None``.
        The check is ``is not None`` (not truthiness) so that an explicit
        seed of ``0`` is honoured instead of being treated as "no seed".
    """
    if base_seed is None:
        return None
    return base_seed + epoch


@torch.inference_mode()
def main() -> None:
    """Run diffusion inference for every prompt in the configured dataset.

    The function parses the configuration, initialises the distributed
    inference environment, builds the dataset/dataloader and the models, and
    then iterates over the dataloader ``num_sample`` times. For each batch it
    optionally generates a conditioning image first (``use_t2i2v``), then
    calls the video sampling API and saves the result under ``cfg.save_dir``.
    """
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    torch.set_grad_enabled(False)

    # == parse configs ==
    cfg = parse_configs()
    cfg = parse_alias(cfg)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
    seed = cfg.get("seed", 1024)
    if seed is not None:
        set_seed(seed)

    # == init distributed env ==
    init_inference_environment()
    logger = create_logger()
    logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
    is_saving_process = get_is_saving_process(cfg)
    booster = get_booster(cfg)
    booster_ae = get_booster(cfg, ae=True)

    # Read once: this flag is consulted several times per batch below.
    offload_model = cfg.get("offload_model", False)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")

    # save directory
    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # == build dataset ==
    if cfg.get("prompt"):
        # Prompts given directly in the config are routed through a temporary
        # CSV so they can be loaded like a regular dataset.
        cfg.dataset.data_path = create_tmp_csv(
            save_dir, cfg.prompt, cfg.get("ref", None), create=is_main_process()
        )
    # NOTE(review): presumably this lets the main process finish writing the
    # CSV before other ranks read it -- confirm against create_tmp_csv.
    dist.barrier()
    dataset = build_module(cfg.dataset, DATASETS)

    # range selection
    start_index = cfg.get("start_index", 0)
    end_index = cfg.get("end_index", None)
    if end_index is None:
        # Without an explicit count the default reaches past the end of the
        # data, so the slice below keeps everything from start_index onward.
        end_index = start_index + cfg.get("num_samples", len(dataset.data) + 1)
    dataset.data = dataset.data[start_index:end_index]
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", 1),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    dataloader, _ = prepare_dataloader(**dataloader_args)

    # == prepare default params ==
    sampling_option = SamplingOption(**cfg.sampling_option)
    sampling_option = sanitize_sampling_option(sampling_option)

    cond_type = cfg.get("cond_type", "t2v")
    prompt_refine = cfg.get("prompt_refine", False)
    fps_save = cfg.get("fps_save", 16)
    # "num_sample" (number of passes over the dataloader) is a different key
    # from "num_samples" (dataset range size) used above.
    num_sample = cfg.get("num_sample", 1)

    type_name = "image" if cfg.sampling_option.num_frames == 1 else "video"
    sub_dir = f"{type_name}_{cfg.sampling_option.resolution}"
    os.makedirs(os.path.join(save_dir, sub_dir), exist_ok=True)

    use_t2i2v = cfg.get("use_t2i2v", False)
    img_sub_dir = os.path.join(sub_dir, "generated_condition")
    # Directory in which generated conditioning images are looked up again.
    ref_dir = os.path.join(save_dir, img_sub_dir)
    if use_t2i2v:
        os.makedirs(ref_dir, exist_ok=True)

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")

    # == build flux model ==
    model, model_ae, model_t5, model_clip, optional_models = prepare_models(
        cfg, device, dtype, offload_model=offload_model
    )
    log_cuda_max_memory("build model")

    if booster:
        model, _, _, _, _ = booster.boost(model=model)
        model = model.unwrap()
    if booster_ae:
        model_ae, _, _, _, _ = booster_ae.boost(model=model_ae)
        model_ae = model_ae.unwrap()

    api_fn = prepare_api(model, model_ae, model_t5, model_clip, optional_models)

    # prepare image flux model if t2i2v
    api_fn_img = None
    if use_t2i2v:
        api_fn_img = prepare_api(
            optional_models["img_flux"],
            optional_models["img_flux_ae"],
            model_t5,
            model_clip,
            optional_models,
        )

    # ======================================================
    # 4. inference
    # ======================================================
    for epoch in range(num_sample):  # generate multiple samples with different seeds
        # Seed 0 is a valid seed; only a missing seed yields None.
        epoch_seed = _epoch_seed(sampling_option.seed, epoch)
        with tqdm(
            dataloader,
            desc="Inference progress",
            disable=not is_main_process(),
            initial=0,
            total=len(dataloader),
        ) as pbar:
            for batch in pbar:
                original_text = batch.pop("text")

                if use_t2i2v:
                    batch["text"] = (
                        refine_prompts(original_text, type="t2i") if prompt_refine else original_text
                    )
                    sampling_option_t2i = modify_option_to_t2i(
                        sampling_option,
                        distilled=True,
                        img_resolution=cfg.get("img_resolution", "768px"),
                    )

                    if offload_model:
                        # Swap devices: video models to CPU, image models to the
                        # compute device.
                        model_move_start = time.time()
                        model = model.to("cpu", dtype)
                        model_ae = model_ae.to("cpu", dtype)
                        optional_models["img_flux"].to(device, dtype)
                        optional_models["img_flux_ae"].to(device, dtype)
                        logger.info(
                            "offload video diffusion model to cpu, load image flux model to gpu: %s s",
                            time.time() - model_move_start,
                        )

                    logger.info("Generating image condition by flux...")
                    x_cond = api_fn_img(
                        sampling_option_t2i,
                        "t2v",
                        seed=epoch_seed,
                        channel=cfg["img_flux"]["in_channels"],
                        **batch,
                    ).cpu()

                    # save image to disk
                    batch["name"] = process_and_save(
                        x_cond,
                        batch,
                        cfg,
                        img_sub_dir,
                        sampling_option_t2i,
                        epoch,
                        start_index,
                        saving=is_saving_process,
                    )
                    dist.barrier()

                    if offload_model:
                        # Swap back: video models to the compute device, image
                        # models to CPU.
                        model_move_start = time.time()
                        model = model.to(device, dtype)
                        model_ae = model_ae.to(device, dtype)
                        optional_models["img_flux"].to("cpu", dtype)
                        optional_models["img_flux_ae"].to("cpu", dtype)
                        logger.info(
                            "load video diffusion model to gpu, offload image flux model to cpu: %s s",
                            time.time() - model_move_start,
                        )

                    # Use the images just saved as references for the video pass.
                    batch["ref"] = [os.path.join(ref_dir, f"{x}.png") for x in batch["name"]]
                    cond_type = "i2v_head"

                batch["text"] = original_text
                if prompt_refine:
                    batch["text"] = refine_prompts(
                        original_text,
                        type="t2v" if cond_type == "t2v" else "t2i",
                        image_paths=batch.get("ref", None),
                    )
                batch["text"] = add_fps_info_to_text(batch.pop("text"), fps=fps_save)
                if "motion_score" in cfg:
                    batch["text"] = add_motion_score_to_text(batch.pop("text"), cfg.get("motion_score", 5))

                logger.info("Generating video...")
                x = api_fn(
                    sampling_option,
                    cond_type,
                    seed=epoch_seed,
                    patch_size=cfg.get("patch_size", 2),
                    save_prefix=cfg.get("save_prefix", ""),
                    channel=cfg["model"]["in_channels"],
                    **batch,
                ).cpu()

                if is_saving_process:
                    process_and_save(x, batch, cfg, sub_dir, sampling_option, epoch, start_index)
                # Every rank reaches the barrier, whether or not it saved.
                dist.barrier()

    logger.info("Inference finished.")
    log_cuda_max_memory("inference")
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()