184 lines
7.6 KiB
Python
184 lines
7.6 KiB
Python
from pprint import pformat
|
|
|
|
import colossalai
|
|
import torch
|
|
import torch.distributed as dist
|
|
from colossalai.cluster import DistCoordinator
|
|
from mmengine.runner import set_random_seed
|
|
from tqdm import tqdm
|
|
|
|
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
|
|
from opensora.datasets.dataloader import prepare_dataloader
|
|
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
|
from opensora.utils.config_utils import parse_configs
|
|
from opensora.utils.misc import create_logger, to_torch_dtype
|
|
from opensora.utils.train_utils import MaskGenerator
|
|
|
|
|
|
def _build_dataloader(cfg, resolution, num_frames, batch_size):
    """Build a fresh dataloader restricted to one ``(resolution, num_frames)`` bucket.

    Args:
        cfg: parsed evaluation config; ``cfg.dataset`` and ``cfg.num_workers`` are read.
        resolution: resolution key of the bucket to evaluate.
        num_frames: frame-count key of the bucket to evaluate.
        batch_size: batch size for that bucket.

    Returns:
        ``(dataloader, num_steps_per_epoch, num_batch)``. ``num_batch`` is what the
        sampler reports, and ``num_steps_per_epoch`` is ``num_batch // world_size``
        (the number of batches each rank iterates).
    """
    # Single-entry bucket config. The first tuple element is 1.0 as in the original
    # script; presumably a sampling probability (TODO confirm against the sampler).
    single_bucket = {resolution: {num_frames: (1.0, batch_size)}}
    dataset = build_module(cfg.dataset, DATASETS)
    dataloader, sampler = prepare_dataloader(
        bucket_config=single_bucket,
        dataset=dataset,
        batch_size=None,  # batch size comes from the bucket config instead
        num_workers=cfg.num_workers,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    num_batch = sampler.get_num_batch()
    num_steps_per_epoch = num_batch // dist.get_world_size()
    return dataloader, num_steps_per_epoch, num_batch


def _mean_loss_at_timestep(
    t,
    dataloader,
    num_steps,
    *,
    model,
    vae,
    text_encoder,
    scheduler,
    mask_generator,
    device,
    dtype,
    desc,
):
    """Return the mean diffusion loss at timestep ``t`` over ``num_steps`` batches.

    Each rank iterates ``num_steps`` batches of ``dataloader``. Per-rank loss sums and
    sample counts are then all-reduced over the data-parallel group, so every rank
    returns the same global per-sample mean.

    Raises:
        RuntimeError: if no samples were seen on any rank.
    """
    loss_sum = 0.0
    num_samples = 0
    batches = iter(dataloader)
    for _ in tqdm(range(num_steps), desc=desc):
        batch = next(batches)
        x = batch.pop("video").to(device, dtype)
        batch.pop("text")  # raw captions are dropped; tokenized ids are used below
        x = vae.encode(x)
        input_ids = batch.pop("input_ids")
        attention_mask = batch.pop("attention_mask")
        model_args = text_encoder.encode(input_ids, attention_mask=attention_mask)

        # == mask ==
        mask = None
        if mask_generator is not None:
            mask = mask_generator.get_masks(x)
            model_args["x_mask"] = mask

        # == video meta info ==
        # Every remaining batch entry is forwarded to the model, cast to the model dtype.
        for key, value in batch.items():
            model_args[key] = value.to(device, dtype)

        # == diffusion loss computation ==
        timestep = torch.full((x.shape[0],), t, device=device, dtype=dtype)
        loss_dict = scheduler.training_losses(model, x, model_args, mask=mask, t=timestep)
        losses = loss_dict["loss"]  # (batch_size)
        num_samples += x.shape[0]
        loss_sum += losses.sum().item()

    # Combine across ranks. Without this, each rank would report a loss over only its
    # own shard of the data.
    totals = torch.tensor([loss_sum, num_samples], device=device, dtype=torch.float64)
    dist.all_reduce(totals, group=get_data_parallel_group())
    total_loss, total_samples = totals.tolist()
    if total_samples == 0:
        raise RuntimeError(f"No samples were evaluated ({desc})")
    return total_loss / total_samples


def main():
    """Evaluate the diffusion loss of a model over resolution / frame-count buckets.

    The config comes from ``parse_configs(training=False)``. For each selected
    ``(resolution, num_frames)`` bucket, the function evaluates the loss at
    ``num_eval_timesteps`` evenly spaced interior timesteps. It logs each per-timestep
    loss and their average.

    Task selection uses ``start_index`` / ``end_index``:
    - With several resolutions, the bounds index resolutions.
    - With a single resolution, they index that resolution's frame-count entries
      (entries whose batch size is None are not counted).
    - A missing or None ``end_index`` means "no upper bound".

    Raises:
        ValueError: on an unknown ``dtype``, a missing text encoder, a missing
            ``bucket_config``, or ``num_eval_timesteps < 1``.
    """
    torch.set_grad_enabled(False)

    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Validate the same value that is actually used. Previously the check defaulted to
    # "fp32" while the dtype in use defaulted to "bf16".
    cfg_dtype = cfg.get("dtype", "bf16")
    if cfg_dtype not in ("fp16", "bf16", "fp32"):
        raise ValueError(f"Unknown mixed precision {cfg_dtype}")
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == init distributed env ==
    colossalai.launch_from_torch({})
    DistCoordinator()  # return value unused; presumably kept for its init side effects
    set_random_seed(seed=cfg.get("seed", 1024))
    set_data_parallel_group(dist.group.WORLD)

    # == init logger ==
    logger = create_logger()
    logger.info("Eval loss configuration:\n %s", pformat(cfg.to_dict()))

    # ======================================================
    # build model & load weights
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    if text_encoder is None:
        # The evaluation loop tokenizes captions and attaches y_embedder to the text
        # encoder, so it cannot run without one. Fail before loading the heavy models.
        raise ValueError("Loss evaluation requires a text encoder (cfg.text_encoder is None)")
    cfg.dataset.tokenize_fn = text_encoder.tokenize_fn

    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    input_size = (None, None, None)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
            enable_sequence_parallelism=cfg.get("sp_size", 1) > 1,
        )
        .to(device, dtype)
        .eval()
    )
    text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance

    # == build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    mask_ratios = cfg.get("mask_ratios", None)
    mask_generator = MaskGenerator(mask_ratios) if mask_ratios is not None else None

    num_eval_timesteps = cfg.get("num_eval_timesteps", 10)
    if num_eval_timesteps < 1:
        raise ValueError(f"num_eval_timesteps must be >= 1, got {num_eval_timesteps}")

    # ======================================================
    # inference
    # ======================================================
    # start evaluation, prepare a dataset every time in the loop
    bucket_config = cfg.get("bucket_config", None)
    if bucket_config is None:
        raise ValueError("bucket_config is required for evaluation")
    resolution = cfg.get("resolution", None)
    if resolution is not None:
        bucket_config = {resolution: bucket_config[resolution]}
    logger.info("Evaluating bucket_config: %s", bucket_config)

    start = cfg.get("start_index", None)
    if start is None:
        start = 0
    end = cfg.get("end_index", None)  # None means "no upper bound"
    multi_res = len(bucket_config) > 1

    def _selected(index):
        # Half-open task range [start, end).
        return index >= start and (end is None or index < end)

    evaluation_losses = {}
    for res_index, (res, t_bucket) in enumerate(bucket_config.items()):
        if multi_res and not _selected(res_index):  # skip task
            logger.info("skipping: %s", t_bucket)
            continue

        frame_index = 0
        for num_frames, (_, batch_size) in t_bucket.items():
            if batch_size is None:
                continue
            current_index = frame_index
            frame_index += 1
            if not multi_res and not _selected(current_index):  # skip task
                logger.info("skipping: %s", num_frames)
                continue

            logger.info("Evaluating resolution: %s, num_frames: %s", res, num_frames)
            dataloader, num_steps_per_epoch, num_batch = _build_dataloader(
                cfg, res, num_frames, batch_size
            )
            # Check per-rank steps, not total batches: fewer batches than ranks also
            # yields zero steps, which would otherwise divide by zero.
            if num_steps_per_epoch == 0:
                logger.warning(
                    "No data for resolution: %s, num_frames: %s (%d batches for %d ranks)",
                    res,
                    num_frames,
                    num_batch,
                    dist.get_world_size(),
                )
                continue

            # Evenly spaced interior timesteps; both endpoints are excluded.
            timesteps = torch.linspace(
                0, scheduler.num_timesteps, num_eval_timesteps + 2
            )[1:-1].tolist()
            evaluation_t_losses = []
            for t in timesteps:
                loss_t = _mean_loss_at_timestep(
                    t,
                    dataloader,
                    num_steps_per_epoch,
                    model=model,
                    vae=vae,
                    text_encoder=text_encoder,
                    scheduler=scheduler,
                    mask_generator=mask_generator,
                    device=device,
                    dtype=dtype,
                    desc=f"res: {res}, num_frames: {num_frames}, t: {t:.2f}",
                )
                evaluation_t_losses.append(loss_t)
                logger.info(
                    "resolution: %s, num_frames: %s, timestep: %.2f, loss: %.4f",
                    res,
                    num_frames,
                    t,
                    loss_t,
                )

            evaluation_losses[(res, num_frames)] = sum(evaluation_t_losses) / len(evaluation_t_losses)
            logger.info(
                "Evaluation losses for resolution: %s, num_frames: %s, loss: %s\n %s",
                res,
                num_frames,
                evaluation_losses[(res, num_frames)],
                evaluation_t_losses,
            )
    logger.info("Evaluation losses: %s", evaluation_losses)
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
|