# 394 lines | 16 KiB | Python
from collections import OrderedDict, defaultdict
|
|
from typing import Iterator
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.utils.data import Dataset, DistributedSampler
|
|
|
|
from opensora.utils.logger import log_message
|
|
from opensora.utils.misc import format_numel_str
|
|
|
|
from .aspect import get_num_pexels_from_name
|
|
from .bucket import Bucket
|
|
from .datasets import VideoTextDataset
|
|
from .parallel import pandarallel
|
|
from .utils import sync_object_across_devices
|
|
|
|
|
|
# use pandarallel to accelerate bucket processing
|
|
# NOTE: pandarallel should only access local variables
|
|
def apply(data, method=None, seed=None, num_bucket=None, fps_max=16):
    """Call ``method`` on one metadata row and return its result.

    This is the per-row callback handed to ``parallel_apply``; it only touches
    its own arguments (see the NOTE above about pandarallel and local variables).

    Args:
        data: mapping-like row providing ``num_frames``, ``height``, ``width``,
            ``fps``, ``path`` and ``id``.
        method: callable invoked as
            ``method(num_frames, height, width, fps, path, sample_seed, fps_max)``.
        seed: base seed; combined with the row id to form the per-sample seed.
        num_bucket: multiplier applied to the row id when forming the per-sample seed.
        fps_max: forwarded unchanged as the last positional argument of ``method``.

    Returns:
        Whatever ``method`` returns.
    """
    # Read the row fields first, in the same order they are passed on.
    row_fields = (
        data["num_frames"],
        data["height"],
        data["width"],
        data["fps"],
        data["path"],
    )
    # Per-sample seed: base seed shifted by id * num_bucket.
    sample_seed = seed + data["id"] * num_bucket
    return method(*row_fields, sample_seed, fps_max)
|
|
|
|
|
|
class StatefulDistributedSampler(DistributedSampler):
|
|
def __init__(
|
|
self,
|
|
dataset: Dataset,
|
|
num_replicas: int | None = None,
|
|
rank: int | None = None,
|
|
shuffle: bool = True,
|
|
seed: int = 0,
|
|
drop_last: bool = False,
|
|
) -> None:
|
|
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
|
self.start_index: int = 0
|
|
|
|
def __iter__(self) -> Iterator:
|
|
iterator = super().__iter__()
|
|
indices = list(iterator)
|
|
indices = indices[self.start_index :]
|
|
return iter(indices)
|
|
|
|
def __len__(self) -> int:
|
|
return self.num_samples - self.start_index
|
|
|
|
def reset(self) -> None:
|
|
self.start_index = 0
|
|
|
|
def state_dict(self, step) -> dict:
|
|
return {"start_index": step}
|
|
|
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
self.__dict__.update(state_dict)
|
|
|
|
|
|
class VariableVideoBatchSampler(DistributedSampler):
|
|
def __init__(
|
|
self,
|
|
dataset: VideoTextDataset,
|
|
bucket_config: dict,
|
|
num_replicas: int | None = None,
|
|
rank: int | None = None,
|
|
shuffle: bool = True,
|
|
seed: int = 0,
|
|
drop_last: bool = False,
|
|
verbose: bool = False,
|
|
num_bucket_build_workers: int = 1,
|
|
num_groups: int = 1,
|
|
) -> None:
|
|
super().__init__(
|
|
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
|
|
)
|
|
self.dataset = dataset
|
|
assert dataset.bucket_class == "Bucket", "Only support Bucket class for now"
|
|
self.bucket = Bucket(bucket_config)
|
|
self.verbose = verbose
|
|
self.last_micro_batch_access_index = 0
|
|
self.num_bucket_build_workers = num_bucket_build_workers
|
|
self._cached_bucket_sample_dict = None
|
|
self._cached_num_total_batch = None
|
|
self.num_groups = num_groups
|
|
|
|
if dist.get_rank() == 0:
|
|
pandarallel.initialize(
|
|
nb_workers=self.num_bucket_build_workers,
|
|
progress_bar=False,
|
|
verbose=0,
|
|
use_memory_fs=False,
|
|
)
|
|
|
|
def __iter__(self) -> Iterator[list[int]]:
|
|
bucket_sample_dict, _ = self.group_by_bucket()
|
|
self.clear_cache()
|
|
|
|
g = torch.Generator()
|
|
g.manual_seed(self.seed + self.epoch)
|
|
bucket_micro_batch_count = OrderedDict()
|
|
bucket_last_consumed = OrderedDict()
|
|
|
|
# process the samples
|
|
for bucket_id, data_list in bucket_sample_dict.items():
|
|
# handle droplast
|
|
bs_per_gpu = self.bucket.get_batch_size(bucket_id)
|
|
remainder = len(data_list) % bs_per_gpu
|
|
|
|
if remainder > 0:
|
|
if not self.drop_last:
|
|
# if there is remainder, we pad to make it divisible
|
|
data_list += data_list[: bs_per_gpu - remainder]
|
|
else:
|
|
# we just drop the remainder to make it divisible
|
|
data_list = data_list[:-remainder]
|
|
bucket_sample_dict[bucket_id] = data_list
|
|
|
|
# handle shuffle
|
|
if self.shuffle:
|
|
data_indices = torch.randperm(len(data_list), generator=g).tolist()
|
|
data_list = [data_list[i] for i in data_indices]
|
|
bucket_sample_dict[bucket_id] = data_list
|
|
|
|
# compute how many micro-batches each bucket has
|
|
num_micro_batches = len(data_list) // bs_per_gpu
|
|
bucket_micro_batch_count[bucket_id] = num_micro_batches
|
|
|
|
# compute the bucket access order
|
|
# each bucket may have more than one batch of data
|
|
# thus bucket_id may appear more than 1 time
|
|
bucket_id_access_order = []
|
|
for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
|
|
bucket_id_access_order.extend([bucket_id] * num_micro_batch)
|
|
|
|
# randomize the access order
|
|
if self.shuffle:
|
|
bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
|
|
bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices]
|
|
|
|
# make the number of bucket accesses divisible by dp size
|
|
remainder = len(bucket_id_access_order) % self.num_replicas
|
|
if remainder > 0:
|
|
if self.drop_last:
|
|
bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
|
|
else:
|
|
bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]
|
|
|
|
# prepare each batch from its bucket
|
|
# according to the predefined bucket access order
|
|
num_iters = len(bucket_id_access_order) // self.num_replicas
|
|
start_iter_idx = self.last_micro_batch_access_index // self.num_replicas
|
|
|
|
# re-compute the micro-batch consumption
|
|
# this is useful when resuming from a state dict with a different number of GPUs
|
|
self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
|
|
for i in range(self.last_micro_batch_access_index):
|
|
bucket_id = bucket_id_access_order[i]
|
|
bucket_bs = self.bucket.get_batch_size(bucket_id)
|
|
if bucket_id in bucket_last_consumed:
|
|
bucket_last_consumed[bucket_id] += bucket_bs
|
|
else:
|
|
bucket_last_consumed[bucket_id] = bucket_bs
|
|
|
|
for i in range(start_iter_idx, num_iters):
|
|
bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
|
|
self.last_micro_batch_access_index += self.num_replicas
|
|
|
|
# compute the data samples consumed by each access
|
|
bucket_access_boundaries = []
|
|
for bucket_id in bucket_access_list:
|
|
bucket_bs = self.bucket.get_batch_size(bucket_id)
|
|
last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
|
|
bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs])
|
|
|
|
# update consumption
|
|
if bucket_id in bucket_last_consumed:
|
|
bucket_last_consumed[bucket_id] += bucket_bs
|
|
else:
|
|
bucket_last_consumed[bucket_id] = bucket_bs
|
|
|
|
# compute the range of data accessed by each GPU
|
|
bucket_id = bucket_access_list[self.rank]
|
|
boundary = bucket_access_boundaries[self.rank]
|
|
cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]]
|
|
|
|
# encode t, h, w into the sample index
|
|
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
|
|
cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
|
|
yield cur_micro_batch
|
|
|
|
self.reset()
|
|
|
|
def __len__(self) -> int:
|
|
return self.get_num_batch() // self.num_groups
|
|
|
|
def get_num_batch(self) -> int:
|
|
_, num_total_batch = self.group_by_bucket()
|
|
return num_total_batch
|
|
|
|
def clear_cache(self):
|
|
self._cached_bucket_sample_dict = None
|
|
self._cached_num_total_batch = 0
|
|
|
|
def group_by_bucket(self) -> dict:
|
|
"""
|
|
Group the dataset samples into buckets.
|
|
This method will set `self._cached_bucket_sample_dict` to the bucket sample dict.
|
|
|
|
Returns:
|
|
dict: a dictionary with bucket id as key and a list of sample indices as value
|
|
"""
|
|
if self._cached_bucket_sample_dict is not None:
|
|
return self._cached_bucket_sample_dict, self._cached_num_total_batch
|
|
|
|
# use pandarallel to accelerate bucket processing
|
|
log_message("Building buckets using %d workers...", self.num_bucket_build_workers)
|
|
bucket_ids = None
|
|
if dist.get_rank() == 0:
|
|
data = self.dataset.data.copy(deep=True)
|
|
data["id"] = data.index
|
|
bucket_ids = data.parallel_apply(
|
|
apply,
|
|
axis=1,
|
|
method=self.bucket.get_bucket_id,
|
|
seed=self.seed + self.epoch,
|
|
num_bucket=self.bucket.num_bucket,
|
|
fps_max=self.dataset.fps_max,
|
|
)
|
|
dist.barrier()
|
|
bucket_ids = sync_object_across_devices(bucket_ids)
|
|
dist.barrier()
|
|
|
|
# group by bucket
|
|
# each data sample is put into a bucket with a similar image/video size
|
|
bucket_sample_dict = defaultdict(list)
|
|
bucket_ids_np = np.array(bucket_ids)
|
|
valid_indices = np.where(bucket_ids_np != None)[0]
|
|
for i in valid_indices:
|
|
bucket_sample_dict[bucket_ids_np[i]].append(i)
|
|
|
|
# cache the bucket sample dict
|
|
self._cached_bucket_sample_dict = bucket_sample_dict
|
|
|
|
# num total batch
|
|
num_total_batch = self.print_bucket_info(bucket_sample_dict)
|
|
self._cached_num_total_batch = num_total_batch
|
|
|
|
return bucket_sample_dict, num_total_batch
|
|
|
|
def print_bucket_info(self, bucket_sample_dict: dict) -> int:
|
|
# collect statistics
|
|
num_total_samples = num_total_batch = 0
|
|
num_total_img_samples = num_total_vid_samples = 0
|
|
num_total_img_batch = num_total_vid_batch = 0
|
|
num_total_vid_batch_256 = num_total_vid_batch_768 = 0
|
|
num_aspect_dict = defaultdict(lambda: [0, 0])
|
|
num_hwt_dict = defaultdict(lambda: [0, 0])
|
|
for k, v in bucket_sample_dict.items():
|
|
size = len(v)
|
|
num_batch = size // self.bucket.get_batch_size(k[:-1])
|
|
|
|
num_total_samples += size
|
|
num_total_batch += num_batch
|
|
|
|
if k[1] == 1:
|
|
num_total_img_samples += size
|
|
num_total_img_batch += num_batch
|
|
else:
|
|
if k[0] == "256px":
|
|
num_total_vid_batch_256 += num_batch
|
|
elif k[0] == "768px":
|
|
num_total_vid_batch_768 += num_batch
|
|
num_total_vid_samples += size
|
|
num_total_vid_batch += num_batch
|
|
|
|
num_aspect_dict[k[-1]][0] += size
|
|
num_aspect_dict[k[-1]][1] += num_batch
|
|
num_hwt_dict[k[:-1]][0] += size
|
|
num_hwt_dict[k[:-1]][1] += num_batch
|
|
|
|
# sort
|
|
num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0]))
|
|
num_hwt_dict = dict(
|
|
sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True)
|
|
)
|
|
num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1}
|
|
num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1}
|
|
|
|
# log
|
|
if dist.get_rank() == 0 and self.verbose:
|
|
log_message("Bucket Info:")
|
|
log_message("Bucket [#sample, #batch] by aspect ratio:")
|
|
for k, v in num_aspect_dict.items():
|
|
log_message("(%s): #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
|
|
log_message("===== Image Info =====")
|
|
log_message("Image Bucket by HxWxT:")
|
|
for k, v in num_hwt_img_dict.items():
|
|
log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
|
|
log_message("--------------------------------")
|
|
log_message(
|
|
"#image sample: %s, #image batch: %s",
|
|
format_numel_str(num_total_img_samples),
|
|
format_numel_str(num_total_img_batch),
|
|
)
|
|
log_message("===== Video Info =====")
|
|
log_message("Video Bucket by HxWxT:")
|
|
for k, v in num_hwt_vid_dict.items():
|
|
log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
|
|
log_message("--------------------------------")
|
|
log_message(
|
|
"#video sample: %s, #video batch: %s",
|
|
format_numel_str(num_total_vid_samples),
|
|
format_numel_str(num_total_vid_batch),
|
|
)
|
|
log_message("===== Summary =====")
|
|
log_message("#non-empty buckets: %s", len(bucket_sample_dict))
|
|
log_message(
|
|
"Img/Vid sample ratio: %.2f",
|
|
num_total_img_samples / num_total_vid_samples if num_total_vid_samples > 0 else 0,
|
|
)
|
|
log_message(
|
|
"Img/Vid batch ratio: %.2f", num_total_img_batch / num_total_vid_batch if num_total_vid_batch > 0 else 0
|
|
)
|
|
log_message(
|
|
"vid batch 256: %s, vid batch 768: %s", format_numel_str(num_total_vid_batch_256), format_numel_str(num_total_vid_batch_768)
|
|
)
|
|
log_message(
|
|
"Vid batch ratio (256px/768px): %.2f", num_total_vid_batch_256 / num_total_vid_batch_768 if num_total_vid_batch_768 > 0 else 0
|
|
)
|
|
log_message(
|
|
"#training sample: %s, #training batch: %s",
|
|
format_numel_str(num_total_samples),
|
|
format_numel_str(num_total_batch),
|
|
)
|
|
return num_total_batch
|
|
|
|
def reset(self):
|
|
self.last_micro_batch_access_index = 0
|
|
|
|
def set_step(self, start_step: int):
|
|
self.last_micro_batch_access_index = start_step * self.num_replicas
|
|
|
|
def state_dict(self, num_steps: int) -> dict:
|
|
# the last_micro_batch_access_index in the __iter__ is often
|
|
# not accurate during multi-workers and data prefetching
|
|
# thus, we need the user to pass the actual steps which have been executed
|
|
# to calculate the correct last_micro_batch_access_index
|
|
return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}
|
|
|
|
def load_state_dict(self, state_dict: dict) -> None:
|
|
self.__dict__.update(state_dict)
|
|
|
|
|
|
class BatchDistributedSampler(DistributedSampler):
    """
    Sampler for a buffered dataset (used with BatchDataset): every rank walks
    its own contiguous range of sample indices.

    Example with len_buffer == 5, num_buffers == 6, #GPUs == 3:

               | buffer {i}          | buffer {i+1}
        ------ | ------------------- | -------------------
        rank 0 |  0,  1,  2,  3,  4, |  5,  6,  7,  8,  9
        rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
        rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
    """

    def __init__(self, dataset: Dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        # Position inside this rank's range at which iteration starts.
        self.start_index = 0

    def __iter__(self):
        """Iterate this rank's indices, starting ``start_index`` positions into its range."""
        total_buffers = self.dataset.num_buffers
        buffer_length = self.dataset.len_buffer
        # Whole buffers per rank; leftover buffers (if any) are not assigned.
        buffers_per_rank = total_buffers // self.num_replicas
        samples_per_rank = buffer_length * buffers_per_rank

        rank_offset = self.rank * samples_per_rank
        local_positions = np.arange(self.start_index, samples_per_rank)
        return iter((local_positions + rank_offset).tolist())

    def reset(self):
        """Restart iteration from the beginning of this rank's range."""
        self.start_index = 0

    def state_dict(self, step) -> dict:
        """Return the resumable state; ``step`` is stored under ``"start_index"``."""
        return {"start_index": step}

    def load_state_dict(self, state_dict: dict):
        """Restore state; iteration resumes one position after the stored ``"start_index"``."""
        self.start_index = state_dict["start_index"] + 1
|