152 lines
5.0 KiB
Python
152 lines
5.0 KiB
Python
import math
|
|
import os
|
|
|
|
# Candidate aspect ratios, each written as "width:height".
# get_aspect_ratios_dict turns every entry into a (height, width) bucket and,
# for each entry it keeps, also adds the mirrored "height:width" key. In
# training mode an entry whose bucket size equals an earlier entry's is
# dropped, so list order matters.
# The trailing number on each line is the long-side / short-side ratio.
ASPECT_RATIO_LD_LIST = [  # width:height
    "2.39:1",  # cinemascope, 2.39
    "2:1",  # rare, 2
    "16:9",  # widescreen, 1.78
    "1.85:1",  # american widescreen, 1.85
    "9:16",  # portrait (vertical 16:9), 1.78
    "5:8",  # portrait (vertical 8:5), 1.6
    "3:2",  # rare, 1.5
    "4:3",  # classic, 1.33
    "1:1",  # square, 1
]
|
|
|
|
|
|
def get_ratio(name: str) -> float:
    """Return height / width for an aspect-ratio string.

    The string is read as "width:height", so "16:9" yields 0.5625 and
    "1:1" yields 1.0.
    """
    width, height = (float(part) for part in name.split(":"))
    ratio = height / width
    return ratio
|
|
|
|
|
|
def _snap_to_pixel_budget(
    height: int, width: int, total_pixels: int, step: int
) -> tuple[int, int]:
    """Move one side of (height, width) by one ``step`` if that gets closer to ``total_pixels``.

    All four candidates are derived from the input size. The candidate with the
    smallest |h * w - total_pixels| wins; ties keep the earlier option, and the
    input size itself wins any tie with a candidate.
    """
    best = (height, width)
    best_diff = abs(height * width - total_pixels)
    candidates = (
        (height - step, width),
        (height + step, width),
        (height, width - step),
        (height, width + step),
    )
    for h, w in candidates:
        diff = abs(h * w - total_pixels)
        if diff < best_diff:
            best, best_diff = (h, w), diff
    return best


def get_aspect_ratios_dict(
    total_pixels: int = 256 * 256,
    training: bool = True,
    aspect_ratio_list: "list[str] | None" = None,
) -> dict[str, tuple[int, int]]:
    """Build a (height, width) bucket for every aspect ratio at a pixel budget.

    Both sides are rounded down to multiples of ``D``, read from the
    ``AE_SPATIAL_COMPRESSION`` environment variable (default 16), and
    height * width is kept close to ``total_pixels``. For every entry that is
    kept, the mirrored "height:width" key with the transposed size is added too.

    Args:
        total_pixels: target pixel count (height * width) per bucket.
        training: when True, each size may be nudged by one ``D`` step on one
            side if that lands closer to ``total_pixels``. An entry whose final
            size equals an earlier entry's is then dropped, together with its
            mirrored key. When False, sizes are plain floor-rounded and every
            entry is kept.
        aspect_ratio_list: "width:height" strings to bucket. Defaults to
            ASPECT_RATIO_LD_LIST.

    Returns:
        dict mapping ratio strings to (height, width) in pixels. Listed keys
        come first in list order, followed by newly added mirrored keys. A
        mirrored entry overwrites a listed key of the same name (e.g. the
        mirror of "9:16" replaces "16:9").

    Raises:
        ValueError: if ``total_pixels`` is so small that some ratio's width
            rounds down to 0.
    """
    ratios = ASPECT_RATIO_LD_LIST if aspect_ratio_list is None else aspect_ratio_list
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    aspect_ratios_dict = {}
    aspect_ratios_vertical_dict = {}
    for ratio in ratios:
        width_ratio, height_ratio = map(float, ratio.split(":"))
        # round down to a multiple of D (presumably the autoencoder's spatial
        # downsampling factor, given the env var name)
        width = int(math.sqrt(total_pixels * (width_ratio / height_ratio)) // D) * D
        if width == 0:
            # otherwise the division below fails with an opaque ZeroDivisionError
            raise ValueError(
                f"total_pixels={total_pixels} is too small for aspect ratio "
                f"{ratio}: width rounds down to 0 with AE_SPATIAL_COMPRESSION={D}"
            )
        height = int((total_pixels / width) // D) * D

        if training:
            # adjust size so the pixel count better matches total_pixels
            height, width = _snap_to_pixel_budget(height, width, total_pixels, D)

        # remove duplicated sizes (training only); earlier list entries win
        if (height, width) not in aspect_ratios_dict.values() or not training:
            aspect_ratios_dict[ratio] = (height, width)
            vertical_ratio = ":".join(ratio.split(":")[::-1])
            aspect_ratios_vertical_dict[vertical_ratio] = (width, height)

    aspect_ratios_dict.update(aspect_ratios_vertical_dict)
    return aspect_ratios_dict
|
|
|
|
|
|
def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Return the pixel count (height * width) of every aspect-ratio bucket."""
    counts = {}
    for key, (rows, cols) in aspect_ratios_dict.items():
        counts[key] = rows * cols
    return counts
|
|
|
|
|
|
def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Return the token count of every aspect-ratio bucket.

    The factor ``D`` is read from the ``AE_SPATIAL_COMPRESSION`` environment
    variable (default 16). Each bucket's pixel count is floor-divided by it
    twice, i.e. ``height * width // D // D``.
    """
    factor = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    tokens = {}
    for key, (rows, cols) in aspect_ratios_dict.items():
        tokens[key] = rows * cols // factor // factor
    return tokens
|
|
|
|
|
|
def get_num_pexels_from_name(resolution: str) -> int:
    """Translate a resolution name into a pixel budget.

    Only the part before the first "_" is used.

    - "<N>px" means an N x N square (N * N pixels).
    - "<N>p" means a 16:9 frame with short side N (N * N / 9 * 16 pixels,
      truncated to int).

    Raises:
        ValueError: if the name ends in neither "px" nor "p", or N is not an
            integer.
    """
    resolution = resolution.partition("_")[0]
    if resolution.endswith("px"):
        side = int(resolution.removesuffix("px"))
        return side * side
    if not resolution.endswith("p"):
        raise ValueError(f"Invalid resolution {resolution}")
    side = int(resolution.removesuffix("p"))
    return int(side * side / 9 * 16)
|
|
|
|
|
|
def get_resolution_with_aspect_ratio(
    resolution: str,
) -> tuple[int, dict[str, tuple[int, int]]]:
    """Resolve a resolution spec into a pixel budget and its aspect-ratio buckets.

    Args:
        resolution (str): "{name}" or "{name}_{setting}". ``name`` is anything
            get_num_pexels_from_name accepts (e.g. "256px", "360p").
            ``setting`` is either "max", which keeps only the bucket with the
            most pixels (first one on ties), or "ar<ratio>" such as "ar1:1",
            which keeps only that bucket.

    Returns:
        tuple[int, dict[str, tuple[int, int]]]: the pixel budget, and a dict
        mapping ratio keys to (height, width). Without a setting, every bucket
        from get_aspect_ratios_dict (default training=True) is returned.
    """
    parts = resolution.split("_")
    if len(parts) > 1:
        # more than one "_" makes this unpacking raise ValueError
        resolution, setting = parts
        assert setting == "max" or setting.startswith(
            "ar"
        ), f"Invalid setting {setting}"
    else:
        resolution, setting = parts[0], ""

    num_pexels = get_num_pexels_from_name(resolution)
    buckets = get_aspect_ratios_dict(num_pexels)

    if setting == "max":
        aspect_ratio = max(buckets, key=lambda key: buckets[key][0] * buckets[key][1])
    elif setting.startswith("ar"):
        aspect_ratio = setting[2:]
        assert aspect_ratio in buckets, f"Aspect ratio {aspect_ratio} not found"
    else:
        return num_pexels, buckets
    return num_pexels, {aspect_ratio: buckets[aspect_ratio]}
|
|
|
|
|
|
def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
    """Return the key of ``ratios`` whose height/width is nearest to height / width.

    Keys are "width:height" strings parsed with get_ratio. Distance is the
    absolute difference of the two height/width values. Ties go to the key
    that comes first in ``ratios``.
    """
    target = height / width

    def distance(key: str) -> float:
        return abs(target - get_ratio(key))

    return min(ratios, key=distance)
|
|
|
|
|
|
def get_image_size(
    resolution: str, ar_ratio: str, training: bool = True
) -> tuple[int, int]:
    """Look up the (height, width) bucket of ``ar_ratio`` at ``resolution``.

    Raises:
        AssertionError: (when assertions are enabled) if ``ar_ratio`` is not
            a key produced by get_aspect_ratios_dict for this resolution and
            ``training`` flag.
    """
    buckets = get_aspect_ratios_dict(get_num_pexels_from_name(resolution), training)
    assert ar_ratio in buckets, f"Aspect ratio {ar_ratio} not found"
    return buckets[ar_ratio]
|
|
|
|
|
|
def bucket_to_shapes(bucket_config, batch_size=None):
    """Expand a bucket config into a flat list of input shapes.

    Args:
        bucket_config: mapping of resolution name (e.g. "256px", "360p") to a
            mapping ``num_frames -> (ignored, bs)``. Only the second element of
            each pair is used here.
        batch_size: if not None, overrides every per-bucket batch size.

    Returns:
        list of ``(bs, 3, num_frames, height, width)`` tuples, one per
        (resolution, num_frames, aspect-ratio bucket) combination, in config
        order. The constant 3 is presumably the channel count (TODO confirm).
    """
    shapes = []
    for resolution, infos in bucket_config.items():
        if not infos:
            # nothing to expand; the resolution name is not even parsed
            continue
        # buckets depend only on the resolution, so compute them once per resolution
        sizes = list(
            get_aspect_ratios_dict(get_num_pexels_from_name(resolution)).values()
        )
        for num_frames, (_, bs) in infos.items():
            if batch_size is not None:
                bs = batch_size
            shapes.extend((bs, 3, num_frames, height, width) for height, width in sizes)
    return shapes
|