# mysora/scripts/vae/stats.py — VAE latent statistics script
# (119 lines, 4.0 KiB, Python)
from pprint import pformat
import colossalai
import torch
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config import parse_configs
from opensora.utils.logger import create_logger, is_distributed, is_main_process
from opensora.utils.misc import log_cuda_max_memory, log_model_params, to_torch_dtype
def compute_latent_stats(sum_mean: float, sum_sq: float, count: int) -> tuple:
    """Return ``(mean, std)`` from accumulated per-batch statistics.

    Args:
        sum_mean: Sum of per-batch means ``E[z]`` accumulated over ``count`` batches.
        sum_sq: Sum of per-batch mean-squares ``E[z^2]`` over the same batches.
        count: Number of batches accumulated (must be > 0).

    Uses the identity ``var = E[z^2] - E[z]^2``, which is exact when every
    batch contributes with equal weight. ``max(..., 0.0)`` guards against a
    tiny negative variance from floating-point cancellation.
    """
    mean = sum_mean / count
    var = max(sum_sq / count - mean * mean, 0.0)
    return mean, var ** 0.5


@torch.inference_mode()
def main():
    """Compute the mean (shift) and std (scale) of VAE latents over a dataset.

    These statistics are the values typically used to normalize latents before
    diffusion training. Configuration is read via ``parse_configs()``; results
    are logged at the end.
    """
    torch.set_grad_enabled(False)

    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs()

    # == get dtype & device ==
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if is_distributed():
        colossalai.launch_from_torch({})
        device = get_current_device()
    set_seed(cfg.get("seed", 1024))

    # == init logger ==
    logger = create_logger()
    logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
    verbose = cfg.get("verbose", 1)

    # ======================================================
    # build model & loss
    # ======================================================
    if cfg.get("ckpt_path", None) is not None:
        cfg.model.from_pretrained = cfg.ckpt_path
    logger.info("Building models...")
    model = build_module(cfg.model, MODELS, device_map=device, torch_dtype=dtype).eval()
    log_model_params(model)

    # ======================================================
    # build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    if cfg.get("eval_setting", None) is not None:
        # e.g. "32x256x256", "1x1024x1024" — frames x height x width.
        # NOTE(review): only the first (frames) and last (width) components are
        # used; height is assumed equal to width via the "ar1:1" bucket key.
        num_frames = int(cfg.eval_setting.split("x")[0])
        resolution = cfg.eval_setting.split("x")[-1]  # already a str; str() was redundant
        bucket_config = {
            resolution + "px_ar1:1": {num_frames: (1.0, 1)},
        }
        # Consistency fix: use the logger (as everywhere else) instead of print.
        logger.info("eval setting:\n%s", bucket_config)
    else:
        bucket_config = cfg.get("bucket_config", None)
    dataloader, _ = prepare_dataloader(
        bucket_config=bucket_config,
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    dataiter = iter(dataloader)
    num_steps_per_epoch = len(dataloader)

    # ======================================================
    # inference
    # ======================================================
    num_samples = 0
    running_sum = running_sq = 0.0
    # Robustness: define shift/scale up front so the final log line cannot
    # raise NameError when the dataloader yields no batches.
    shift = scale = 0.0
    # Iter over the dataset
    with tqdm(
        enumerate(dataiter),
        disable=not is_main_process() or verbose < 1,
        total=num_steps_per_epoch,
        initial=0,
    ) as pbar:
        for _, batch in pbar:
            # == load data ==
            x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
            # == vae encoding ==
            z = model.encode(x)
            # BUGFIX: the original accumulated squared deviations from a
            # *drifting* running mean (recomputed every step), which biases
            # the variance upward — early batches are measured against a poor
            # mean estimate. Instead accumulate E[z] and E[z^2] per batch and
            # derive var = E[z^2] - E[z]^2, exact for equally-weighted batches.
            num_samples += 1
            running_sum += z.mean().item()
            running_sq += z.pow(2).mean().item()
            shift, scale = compute_latent_stats(running_sum, running_sq, num_samples)
            pbar.set_postfix({"mean": shift, "std": scale})
    logger.info("Mean: %.4f, std: %.4f", shift, scale)
    log_cuda_max_memory("inference")
if __name__ == "__main__":
    # Script entry point: compute VAE latent mean/std statistics.
    main()