# mysora/opensora/datasets/bucket.py
# (file-viewer metadata captured with the source: 140 lines, 5.4 KiB, Python)

from collections import OrderedDict
import numpy as np
from opensora.utils.logger import log_message
from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio
from .utils import map_target_fps
class Bucket:
def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]):
"""
Args:
bucket_config (dict): A dictionary containing the bucket configuration.
The dictionary should be in the following format:
{
"bucket_name": {
"time": (probability, batch_size),
"time": (probability, batch_size),
...
},
...
}
Or in the following format:
{
"bucket_name": {
"time": ((probability, next_probability), batch_size),
"time": ((probability, next_probability), batch_size),
...
},
...
}
The bucket_name should be the name of the bucket, and the time should be the number of frames in the video.
The probability should be a float between 0 and 1, and the batch_size should be an integer.
If the probability is a tuple, the second value should be the probability to skip to the next time.
"""
aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()}
bucket_probs = OrderedDict()
bucket_bs = OrderedDict()
bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True)
for key in bucket_names:
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names}
self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names}
self.ar_criteria = {
k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()}
for k1 in bucket_names
}
bucket_id_cnt = num_bucket = 0
bucket_id = dict()
for k1, v1 in bucket_probs.items():
bucket_id[k1] = dict()
for k2, _ in v1.items():
bucket_id[k1][k2] = bucket_id_cnt
bucket_id_cnt += 1
num_bucket += len(aspect_ratios[k1][1])
self.bucket_probs = bucket_probs
self.bucket_bs = bucket_bs
self.bucket_id = bucket_id
self.num_bucket = num_bucket
log_message("Number of buckets: %s", num_bucket)
def get_bucket_id(
self,
T: int,
H: int,
W: int,
fps: float,
path: str | None = None,
seed: int | None = None,
fps_max: int = 16,
) -> tuple[str, int, int] | None:
approx = 0.8
_, sampling_interval = map_target_fps(fps, fps_max)
T = T // sampling_interval
resolution = H * W
rng = np.random.default_rng(seed)
# Reference to probabilities and criteria for faster access
bucket_probs = self.bucket_probs
hw_criteria = self.hw_criteria
ar_criteria = self.ar_criteria
# Start searching for the appropriate bucket
for hw_id, t_criteria in bucket_probs.items():
# if resolution is too low, skip
if resolution < hw_criteria[hw_id] * approx:
continue
# if sample is an image
if T == 1:
if 1 in t_criteria:
if rng.random() < t_criteria[1]:
return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1])
continue
# Look for suitable t_id for video
for t_id, prob in t_criteria.items():
if T >= t_id and t_id != 1:
# if prob is a tuple, use the second value as the threshold to skip
# to the next t_id
if isinstance(prob, tuple):
next_hw_prob, next_t_prob = prob
if next_t_prob >= 1 or rng.random() <= next_t_prob:
continue
else:
next_hw_prob = prob
if next_hw_prob >= 1 or rng.random() <= next_hw_prob:
ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id])
return hw_id, t_id, ar_id
else:
break
return None
def get_thw(self, bucket_idx: tuple[str, int, int]) -> tuple[int, int, int]:
assert len(bucket_idx) == 3
T = self.t_criteria[bucket_idx[0]][bucket_idx[1]]
H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]]
return T, H, W
def get_prob(self, bucket_idx: tuple[str, int]) -> float:
return self.bucket_probs[bucket_idx[0]][bucket_idx[1]]
def get_batch_size(self, bucket_idx: tuple[str, int]) -> int:
return self.bucket_bs[bucket_idx[0]][bucket_idx[1]]
def __len__(self) -> int:
return self.num_bucket