mysora/opensora/datasets/aspect.py

152 lines
5.0 KiB
Python

import math
import os
# Base aspect ratios as "width:height" strings (the trailing number is the
# long-side / short-side ratio). get_aspect_ratios_dict also derives the
# swapped "height:width" name for every kept entry, and when training it drops
# an entry whose rounded (height, width) equals an earlier entry's -- so the
# order of this list matters.
ASPECT_RATIO_LD_LIST = [  # width:height
    "2.39:1",  # cinemascope, 2.39
    "2:1",  # rare, 2
    "16:9",  # rare, 16/9 = 1.78 (landscape)
    "1.85:1",  # american widescreen, 1.85
    "9:16",  # popular, 16/9 = 1.78 (portrait; the mirror of "16:9" above)
    "5:8",  # rare, 8/5 = 1.6 (portrait)
    "3:2",  # rare, 1.5
    "4:3",  # classic, 1.33
    "1:1",  # square
]
def get_ratio(name: str) -> float:
    """Return height / width for a ``"width:height"`` ratio name.

    Example: ``"16:9"`` -> ``0.5625``.
    """
    width_text, height_text = name.split(":")
    w = float(width_text)
    h = float(height_text)
    return h / w
def get_aspect_ratios_dict(
    total_pixels: int = 256 * 256,
    training: bool = True,
    *,
    ratios: "list[str] | None" = None,
) -> dict[str, tuple[int, int]]:
    """Build the ratio-name -> (height, width) bucket table for a pixel budget.

    For each ``"width:height"`` name, a (height, width) pair is chosen such
    that both sides are multiples of ``D`` and the area is close to
    ``total_pixels``. ``D`` is read from the ``AE_SPATIAL_COMPRESSION``
    environment variable (default 16). Each side is at least ``D``, so tiny
    budgets yield a one-cell bucket instead of a zero-sized side or a
    division by zero. The swapped name (e.g. ``"1:2"`` for ``"2:1"``) is
    added with the transposed pair.

    Args:
        total_pixels: target pixel count per frame.
        training: when True, move one side by +/- D if that brings the area
            strictly closer to ``total_pixels``. Also drop any ratio (and its
            swapped name) whose (height, width) equals an earlier kept ratio's.
        ratios: ``"width:height"`` names to use. Defaults to
            ``ASPECT_RATIO_LD_LIST``.

    Returns:
        Mapping from ratio name to ``(height, width)``. The kept input ratios
        come first, in input order, followed by the swapped names. A swapped
        entry overwrites a same-named input entry; for example, with both
        "16:9" and "9:16" in ``ratios``, each key ends up holding the
        transpose of the other's pair.
    """
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    if ratios is None:
        ratios = ASPECT_RATIO_LD_LIST
    aspect_ratios_dict = {}
    aspect_ratios_vertical_dict = {}
    for ratio in ratios:
        width_ratio, height_ratio = map(float, ratio.split(":"))
        width = int(math.sqrt(total_pixels * (width_ratio / height_ratio)) // D) * D
        # Flooring to a multiple of D can reach 0 for small budgets, which
        # would divide by zero below (or emit an empty bucket); clamp to D.
        width = max(width, D)
        height = max(int((total_pixels / width) // D) * D, D)
        if training:
            # adjust aspect ratio to match total pixels
            diff = abs(height * width - total_pixels)
            # Candidates are built from the pre-adjustment size; the best
            # strictly-improving one wins (first one on ties).
            candidate = [
                (height - D, width),
                (height + D, width),
                (height, width - D),
                (height, width + D),
            ]
            for h, w in candidate:
                # The -D candidates can reach 0 when a side equals D; skip them.
                if h > 0 and w > 0 and abs(h * w - total_pixels) < diff:
                    height, width = h, w
                    diff = abs(h * w - total_pixels)
        # remove duplicated aspect ratio (only when training)
        if (height, width) not in aspect_ratios_dict.values() or not training:
            aspect_ratios_dict[ratio] = (height, width)
            vertical_ratio = ":".join(ratio.split(":")[::-1])
            aspect_ratios_vertical_dict[vertical_ratio] = (width, height)
    aspect_ratios_dict.update(aspect_ratios_vertical_dict)
    return aspect_ratios_dict
def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Map each ratio name to the pixel area (height * width) of its bucket."""
    pixels = {}
    for name, size in aspect_ratios_dict.items():
        height, width = size
        pixels[name] = height * width
    return pixels
def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Map each ratio name to its token count.

    The token count is the bucket's pixel area floor-divided twice by ``D``,
    which is read from the ``AE_SPATIAL_COMPRESSION`` environment variable
    (default 16).
    """
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    tokens = {}
    for name, (height, width) in aspect_ratios_dict.items():
        area = height * width
        tokens[name] = area // D // D
    return tokens
def get_num_pexels_from_name(resolution: str) -> int:
    """Return the target pixel count encoded in a resolution name.

    Only the text before the first ``_`` is parsed, so ``"360p_max"`` is
    treated as ``"360p"``:

    * ``"{N}px"`` -> ``N * N``.
    * ``"{N}p"``  -> ``N * N * 16 // 9``, e.g. ``"720p"`` -> 921600
      (= 1280 * 720).

    Args:
        resolution: resolution name such as ``"256px"`` or ``"360p_ar16:9"``.

    Returns:
        The number of pixels.

    Raises:
        ValueError: if the name has neither suffix, or its size is not a
            positive integer.
    """
    name = resolution.split("_")[0]
    if name.endswith("px"):
        digits, widescreen = name[:-2], False
    elif name.endswith("p"):
        digits, widescreen = name[:-1], True
    else:
        raise ValueError(f"Invalid resolution {name}")
    try:
        size = int(digits)
    except ValueError as err:
        raise ValueError(f"Invalid resolution {name}") from err
    if size <= 0:
        # A zero/negative size would only surface later as a division by zero.
        raise ValueError(f"Invalid resolution {name}")
    if widescreen:
        # Exact integer form of int(size * size / 9 * 16): same value without
        # going through floating point.
        return size * size * 16 // 9
    return size * size
def get_resolution_with_aspect_ratio(
    resolution: str,
) -> tuple[int, dict[str, tuple[int, int]]]:
    """Resolve a resolution name, with an optional setting, into buckets.

    Args:
        resolution (str): ``"{name}"`` or ``"{name}_{setting}"``.
            ``name`` is e.g. ``"256px"`` or ``"360p"``. ``setting`` is one of:

            * ``"max"``: keep only the bucket with the largest area (first one
              on ties).
            * ``"ar{ratio}"``, e.g. ``"ar1:1"``: keep only that ratio.

    Returns:
        tuple[int, dict[str, tuple[int, int]]]: the target pixel count, and
        the (possibly filtered) mapping from ratio name to (height, width).

    Raises:
        AssertionError: unknown setting, or ratio missing from the table.
            NOTE(review): these checks are ``assert``s and vanish under ``-O``.
        ValueError: malformed name, e.g. more than one ``_``.
    """
    keys = resolution.split("_")
    if len(keys) != 1:
        resolution, setting = keys
        assert setting == "max" or setting.startswith(
            "ar"
        ), f"Invalid setting {setting}"
    else:
        resolution, setting = keys[0], ""

    num_pexels = get_num_pexels_from_name(resolution)
    buckets = get_aspect_ratios_dict(num_pexels)

    if setting == "max":
        chosen = max(buckets, key=lambda name: buckets[name][0] * buckets[name][1])
        buckets = {chosen: buckets[chosen]}
    elif setting.startswith("ar"):
        chosen = setting[2:]
        assert chosen in buckets, f"Aspect ratio {chosen} not found"
        buckets = {chosen: buckets[chosen]}

    return num_pexels, buckets
def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
    """Return the key of ``ratios`` whose height/width is nearest ``height / width``.

    Keys are ``"width:height"`` names, parsed with ``get_ratio``. On ties,
    the first key in iteration order wins.
    """
    target = height / width

    def distance(name: str) -> float:
        return abs(target - get_ratio(name))

    return min(ratios.keys(), key=distance)
def get_image_size(
    resolution: str, ar_ratio: str, training: bool = True
) -> tuple[int, int]:
    """Look up the ``(height, width)`` bucket for ``ar_ratio`` at ``resolution``.

    Raises AssertionError if the ratio is not in the table. This can happen
    when ``training`` deduplication dropped it.
    """
    table = get_aspect_ratios_dict(get_num_pexels_from_name(resolution), training)
    assert ar_ratio in table, f"Aspect ratio {ar_ratio} not found"
    return table[ar_ratio]
def bucket_to_shapes(bucket_config, batch_size=None):
    """Enumerate every shape tuple a bucket config can produce.

    Args:
        bucket_config: mapping ``resolution name -> {num_frames: (_, bs)}``.
            Only the second element of each value (the batch size) is read.
        batch_size: if not None, overrides every per-bucket batch size.

    Returns:
        list of ``(bs, 3, num_frames, height, width)`` tuples, one per
        (resolution, num_frames, aspect ratio) combination, in config order.
    """
    shapes = []
    for resolution, infos in bucket_config.items():
        if not infos:
            # Nothing to emit; also keeps the name unparsed, as before.
            continue
        # The bucket table depends only on the resolution, so build it once
        # here rather than once per frame count.
        sizes = list(get_aspect_ratios_dict(get_num_pexels_from_name(resolution)).values())
        for num_frames, (_, bs) in infos.items():
            if batch_size is not None:
                bs = batch_size
            shapes.extend((bs, 3, num_frames, height, width) for height, width in sizes)
    return shapes