# opensora/datasets/sampler.py
from collections import OrderedDict, defaultdict
from typing import Iterator
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler
from opensora.utils.logger import log_message
from opensora.utils.misc import format_numel_str
from .aspect import get_num_pexels_from_name
from .bucket import Bucket
from .datasets import VideoTextDataset
from .parallel import pandarallel
from .utils import sync_object_across_devices
# use pandarallel to accelerate bucket processing
# NOTE: pandarallel should only access local variables
def apply(data, method=None, seed=None, num_bucket=None, fps_max=16):
    """Row adapter for ``parallel_apply``: forward one sample's fields to
    ``method`` (the bucket-id function) with a per-row derived seed."""
    # Derive a unique, reproducible seed per row from the base seed and row id.
    row_seed = seed + data["id"] * num_bucket
    args = (
        data["num_frames"],
        data["height"],
        data["width"],
        data["fps"],
        data["path"],
        row_seed,
        fps_max,
    )
    return method(*args)
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def reset(self) -> None:
self.start_index = 0
def state_dict(self, step) -> dict:
return {"start_index": step}
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)
class VariableVideoBatchSampler(DistributedSampler):
    """Bucket-aware distributed batch sampler for mixed-size image/video data.

    Samples are grouped into buckets (keyed by resolution/frame-count/aspect
    ratio via ``Bucket``), and every yielded element is one per-rank
    micro-batch drawn from a single bucket, so each batch is shape-homogeneous.
    Supports mid-epoch resume through ``last_micro_batch_access_index``.
    """

    def __init__(
        self,
        dataset: VideoTextDataset,
        bucket_config: dict,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
        num_bucket_build_workers: int = 1,
        num_groups: int = 1,
    ) -> None:
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
        )
        self.dataset = dataset
        assert dataset.bucket_class == "Bucket", "Only support Bucket class for now"
        self.bucket = Bucket(bucket_config)
        self.verbose = verbose
        # Global (across all ranks) count of micro-batches already handed out;
        # drives mid-epoch resume in __iter__.
        self.last_micro_batch_access_index = 0
        self.num_bucket_build_workers = num_bucket_build_workers
        # Bucket assignment is expensive; computed once per epoch and cached.
        self._cached_bucket_sample_dict = None
        self._cached_num_total_batch = None
        self.num_groups = num_groups
        # Only rank 0 builds the buckets (see group_by_bucket), so only rank 0
        # needs the pandarallel worker pool.
        if dist.get_rank() == 0:
            pandarallel.initialize(
                nb_workers=self.num_bucket_build_workers,
                progress_bar=False,
                verbose=0,
                use_memory_fs=False,
            )

    def __iter__(self) -> Iterator[list[int]]:
        """Yield this rank's micro-batches; each element of a batch is an
        "idx-T-H-W" string encoding the sample index and its bucket shape."""
        bucket_sample_dict, _ = self.group_by_bucket()
        # Drop the cache so the next epoch re-buckets with a new epoch seed.
        self.clear_cache()
        g = torch.Generator()
        # Deterministic but epoch-dependent shuffling.
        g.manual_seed(self.seed + self.epoch)
        bucket_micro_batch_count = OrderedDict()
        bucket_last_consumed = OrderedDict()
        # process the samples
        for bucket_id, data_list in bucket_sample_dict.items():
            # handle droplast
            bs_per_gpu = self.bucket.get_batch_size(bucket_id)
            remainder = len(data_list) % bs_per_gpu
            if remainder > 0:
                if not self.drop_last:
                    # if there is remainder, we pad to make it divisible
                    data_list += data_list[: bs_per_gpu - remainder]
                else:
                    # we just drop the remainder to make it divisible
                    data_list = data_list[:-remainder]
            bucket_sample_dict[bucket_id] = data_list
            # handle shuffle
            if self.shuffle:
                data_indices = torch.randperm(len(data_list), generator=g).tolist()
                data_list = [data_list[i] for i in data_indices]
                bucket_sample_dict[bucket_id] = data_list
            # compute how many micro-batches each bucket has
            num_micro_batches = len(data_list) // bs_per_gpu
            bucket_micro_batch_count[bucket_id] = num_micro_batches
        # compute the bucket access order
        # each bucket may have more than one batch of data
        # thus bucket_id may appear more than 1 time
        bucket_id_access_order = []
        for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
            bucket_id_access_order.extend([bucket_id] * num_micro_batch)
        # randomize the access order
        if self.shuffle:
            bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
            bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices]
        # make the number of bucket accesses divisible by dp size
        remainder = len(bucket_id_access_order) % self.num_replicas
        if remainder > 0:
            if self.drop_last:
                bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
            else:
                bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]
        # prepare each batch from its bucket
        # according to the predefined bucket access order
        num_iters = len(bucket_id_access_order) // self.num_replicas
        start_iter_idx = self.last_micro_batch_access_index // self.num_replicas
        # re-compute the micro-batch consumption
        # this is useful when resuming from a state dict with a different number of GPUs
        self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
        # Replay consumption bookkeeping for the skipped (already-trained)
        # micro-batches so the per-bucket cursors are correct on resume.
        for i in range(self.last_micro_batch_access_index):
            bucket_id = bucket_id_access_order[i]
            bucket_bs = self.bucket.get_batch_size(bucket_id)
            if bucket_id in bucket_last_consumed:
                bucket_last_consumed[bucket_id] += bucket_bs
            else:
                bucket_last_consumed[bucket_id] = bucket_bs
        for i in range(start_iter_idx, num_iters):
            # One bucket access per rank for this step.
            bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
            self.last_micro_batch_access_index += self.num_replicas
            # compute the data samples consumed by each access
            bucket_access_boundaries = []
            for bucket_id in bucket_access_list:
                bucket_bs = self.bucket.get_batch_size(bucket_id)
                last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
                bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs])
                # update consumption
                if bucket_id in bucket_last_consumed:
                    bucket_last_consumed[bucket_id] += bucket_bs
                else:
                    bucket_last_consumed[bucket_id] = bucket_bs
            # compute the range of data accessed by each GPU
            bucket_id = bucket_access_list[self.rank]
            boundary = bucket_access_boundaries[self.rank]
            cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]]
            # encode t, h, w into the sample index
            real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
            cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
            yield cur_micro_batch
        self.reset()

    def __len__(self) -> int:
        # NOTE(review): division by num_groups presumably accounts for process
        # groups that share the same data stream — confirm against the trainer.
        return self.get_num_batch() // self.num_groups

    def get_num_batch(self) -> int:
        """Return the total micro-batch count across all ranks for this epoch."""
        _, num_total_batch = self.group_by_bucket()
        return num_total_batch

    def clear_cache(self):
        """Invalidate the cached bucket assignment (forces a rebuild)."""
        self._cached_bucket_sample_dict = None
        # NOTE(review): reset to 0 here but initialized to None in __init__;
        # harmless, since only `_cached_bucket_sample_dict is None` gates the cache.
        self._cached_num_total_batch = 0

    def group_by_bucket(self) -> tuple[dict, int]:
        """
        Group the dataset samples into buckets.
        This method will set `self._cached_bucket_sample_dict` to the bucket sample dict.
        Returns:
            tuple: (dictionary with bucket id as key and a list of sample
            indices as value, total number of micro-batches across buckets)
        """
        if self._cached_bucket_sample_dict is not None:
            return self._cached_bucket_sample_dict, self._cached_num_total_batch
        # use pandarallel to accelerate bucket processing
        log_message("Building buckets using %d workers...", self.num_bucket_build_workers)
        bucket_ids = None
        if dist.get_rank() == 0:
            # Work on a copy so the dataset's dataframe is not mutated.
            data = self.dataset.data.copy(deep=True)
            data["id"] = data.index
            bucket_ids = data.parallel_apply(
                apply,
                axis=1,
                method=self.bucket.get_bucket_id,
                seed=self.seed + self.epoch,
                num_bucket=self.bucket.num_bucket,
                fps_max=self.dataset.fps_max,
            )
        dist.barrier()
        # Share rank 0's bucket ids with every rank (presumably a broadcast —
        # see sync_object_across_devices) so bucketing is identical everywhere.
        bucket_ids = sync_object_across_devices(bucket_ids)
        dist.barrier()
        # group by bucket
        # each data sample is put into a bucket with a similar image/video size
        bucket_sample_dict = defaultdict(list)
        bucket_ids_np = np.array(bucket_ids)
        # `!= None` is an elementwise test on the object array: samples that
        # fit no bucket get a None id and are skipped.
        valid_indices = np.where(bucket_ids_np != None)[0]
        for i in valid_indices:
            bucket_sample_dict[bucket_ids_np[i]].append(i)
        # cache the bucket sample dict
        self._cached_bucket_sample_dict = bucket_sample_dict
        # num total batch
        num_total_batch = self.print_bucket_info(bucket_sample_dict)
        self._cached_num_total_batch = num_total_batch
        return bucket_sample_dict, num_total_batch

    def print_bucket_info(self, bucket_sample_dict: dict) -> int:
        """Log per-bucket statistics (rank 0, verbose only) and return the
        total number of micro-batches summed over all buckets.

        Bucket-id layout inferred from usage: k[0] is a resolution name
        (e.g. "256px"), k[1] the frame count (1 == image), k[-1] the aspect
        ratio, and k[:-1] the (resolution, frames) key — TODO confirm in Bucket.
        """
        # collect statistics
        num_total_samples = num_total_batch = 0
        num_total_img_samples = num_total_vid_samples = 0
        num_total_img_batch = num_total_vid_batch = 0
        num_total_vid_batch_256 = num_total_vid_batch_768 = 0
        # Each value is a [#sample, #batch] pair.
        num_aspect_dict = defaultdict(lambda: [0, 0])
        num_hwt_dict = defaultdict(lambda: [0, 0])
        for k, v in bucket_sample_dict.items():
            size = len(v)
            num_batch = size // self.bucket.get_batch_size(k[:-1])
            num_total_samples += size
            num_total_batch += num_batch
            if k[1] == 1:
                # Single-frame bucket: counted as images.
                num_total_img_samples += size
                num_total_img_batch += num_batch
            else:
                if k[0] == "256px":
                    num_total_vid_batch_256 += num_batch
                elif k[0] == "768px":
                    num_total_vid_batch_768 += num_batch
                num_total_vid_samples += size
                num_total_vid_batch += num_batch
            num_aspect_dict[k[-1]][0] += size
            num_aspect_dict[k[-1]][1] += num_batch
            num_hwt_dict[k[:-1]][0] += size
            num_hwt_dict[k[:-1]][1] += num_batch
        # sort
        num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0]))
        # Largest resolution first, then by frame count.
        num_hwt_dict = dict(
            sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True)
        )
        num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1}
        num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1}
        # log
        if dist.get_rank() == 0 and self.verbose:
            log_message("Bucket Info:")
            log_message("Bucket [#sample, #batch] by aspect ratio:")
            for k, v in num_aspect_dict.items():
                log_message("(%s): #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
            log_message("===== Image Info =====")
            log_message("Image Bucket by HxWxT:")
            for k, v in num_hwt_img_dict.items():
                log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
            log_message("--------------------------------")
            log_message(
                "#image sample: %s, #image batch: %s",
                format_numel_str(num_total_img_samples),
                format_numel_str(num_total_img_batch),
            )
            log_message("===== Video Info =====")
            log_message("Video Bucket by HxWxT:")
            for k, v in num_hwt_vid_dict.items():
                log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
            log_message("--------------------------------")
            log_message(
                "#video sample: %s, #video batch: %s",
                format_numel_str(num_total_vid_samples),
                format_numel_str(num_total_vid_batch),
            )
            log_message("===== Summary =====")
            log_message("#non-empty buckets: %s", len(bucket_sample_dict))
            log_message(
                "Img/Vid sample ratio: %.2f",
                num_total_img_samples / num_total_vid_samples if num_total_vid_samples > 0 else 0,
            )
            log_message(
                "Img/Vid batch ratio: %.2f", num_total_img_batch / num_total_vid_batch if num_total_vid_batch > 0 else 0
            )
            log_message(
                "vid batch 256: %s, vid batch 768: %s", format_numel_str(num_total_vid_batch_256), format_numel_str(num_total_vid_batch_768)
            )
            log_message(
                "Vid batch ratio (256px/768px): %.2f", num_total_vid_batch_256 / num_total_vid_batch_768 if num_total_vid_batch_768 > 0 else 0
            )
            log_message(
                "#training sample: %s, #training batch: %s",
                format_numel_str(num_total_samples),
                format_numel_str(num_total_batch),
            )
        return num_total_batch

    def reset(self):
        """Rewind the epoch-resume cursor."""
        self.last_micro_batch_access_index = 0

    def set_step(self, start_step: int):
        """Position the sampler as if `start_step` iterations were already run."""
        self.last_micro_batch_access_index = start_step * self.num_replicas

    def state_dict(self, num_steps: int) -> dict:
        # the last_micro_batch_access_index in the __iter__ is often
        # not accurate during multi-workers and data prefetching
        # thus, we need the user to pass the actual steps which have been executed
        # to calculate the correct last_micro_batch_access_index
        return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}

    def load_state_dict(self, state_dict: dict) -> None:
        self.__dict__.update(state_dict)
class BatchDistributedSampler(DistributedSampler):
    """
    Used with BatchDataset;
    Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
           | buffer {i}          | buffer {i+1}
    ------ | ------------------- | -------------------
    rank 0 | 0, 1, 2, 3, 4,      | 5, 6, 7, 8, 9
    rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
    rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
    """

    def __init__(self, dataset: Dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        # Local (per-rank) sample offset used for mid-epoch resume.
        self.start_index = 0

    def __iter__(self):
        # Each rank owns a contiguous slab of num_buffers // num_replicas
        # buffers; skip the first `start_index` local samples on resume.
        buffers_per_rank = self.dataset.num_buffers // self.num_replicas
        samples_per_rank = self.dataset.len_buffer * buffers_per_rank
        first = self.rank * samples_per_rank
        return iter(list(range(first + self.start_index, first + samples_per_rank)))

    def reset(self):
        """Rewind to the beginning of this rank's slab."""
        self.start_index = 0

    def state_dict(self, step) -> dict:
        return {"start_index": step}

    def load_state_dict(self, state_dict: dict):
        # Resume one past the recorded index (the saved step already completed).
        self.start_index = state_dict["start_index"] + 1