# mysora/opensora/datasets/dataloader.py


import collections
import functools
import os
import queue
import random
import threading

import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader, _utils
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import (
    IterDataPipe,
    MapDataPipe,
    _BaseDataLoaderIter,
    _MultiProcessingDataLoaderIter,
    _sharding_worker_init_fn,
    _SingleProcessDataLoaderIter,
)

from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.registry import DATASETS, build_module
from opensora.utils.config import parse_configs
from opensora.utils.logger import create_logger
from opensora.utils.misc import format_duration
from opensora.utils.train import setup_device

from .datasets import TextDataset, VideoTextDataset
from .pin_memory_cache import PinMemoryCache
from .sampler import DistributedSampler, VariableVideoBatchSampler


def _pin_memory_loop(
    in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
):
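    """Pin-memory thread loop, adapted from
    `torch.utils.data._utils.pin_memory._pin_memory_loop`.

    Instead of pinning each batch wholesale, only `data[pin_memory_key]` is
    copied into a reusable pinned buffer obtained from `pin_memory_cache`,
    which avoids re-allocating pinned host memory for every (large) video
    tensor.
    """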
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                assert isinstance(data, dict)
                if pin_memory_key in data:
                    val = data[pin_memory_key]
                    pin_memory_value = pin_memory_cache.get(val)
                    pin_memory_value.copy_(val)
                    data[pin_memory_key] = pin_memory_value
            except Exception:
                data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()


class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
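    """Multi-process dataloader iterator that stages pinned memory through
    `PinMemoryCache`.

    `__init__` mirrors the upstream `_MultiProcessingDataLoaderIter`
    constructor, except that the pin-memory thread runs this module's
    `_pin_memory_loop`, which recycles pinned buffers for the tensor stored
    under `pin_memory_key` ("video" by default).
    """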

    pin_memory_key: str = "video"

    def __init__(self, loader):
        _BaseDataLoaderIter.__init__(self, loader)
        self.pin_memory_cache = PinMemoryCache()

        self._prefetch_factor = loader.prefetch_factor

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Additional worker init function will take care of sharding in MP and Distributed
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            self._worker_init_fn = functools.partial(
                _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
            )

        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(
                    self._dataset_kind,
                    self._dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self._auto_collation,
                    self._collate_fn,
                    self._drop_last,
                    self._base_seed,
                    self._worker_init_fn,
                    i,
                    self._num_workers,
                    self._persistent_workers,
                    self._shared_seed,
                ),
            )
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
            if self._pin_memory_device == "xpu":
                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
                custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
                current_device = custom_device_mod.current_device()
            else:
                current_device = torch.cuda.current_device()  # choose cuda for default
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(
                    self._worker_result_queue,
                    self._data_queue,
                    current_device,
                    self._pin_memory_thread_done_event,
                    self._pin_memory_device,
                    self.pin_memory_cache,
                    self.pin_memory_key,
                ),
            )
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue  # type: ignore[assignment]

        # In some rare cases, persistent workers (daemonic processes)
        # would be terminated before `__del__` of iterator is invoked
        # when main process exits
        # It would cause failure when pin_memory_thread tries to read
        # corrupted data from worker_result_queue
        # atexit is used to shutdown thread and child processes in the
        # right sequence before main process exits
        if self._persistent_workers and self._pin_memory:
            import atexit

            for w in self._workers:
                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def remove_cache(self, output_tensor: torch.Tensor):
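        """Release the pinned buffer backing `output_tensor` back to the cache.

        Intended to be called once the batch has been copied to the device, so
        `PinMemoryCache` can reuse the buffer for a later batch.
        """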
        self.pin_memory_cache.remove(output_tensor)

    def get_cache_info(self) -> str:
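        """Return a human-readable summary of the pin-memory cache state."""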
        return str(self.pin_memory_cache)


class DataloaderForVideo(DataLoader):
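    """`DataLoader` whose multi-process iterator is
    `_MultiProcessingDataLoaderIterForVideo`, i.e. the "video" tensor of each
    batch is staged through a reusable `PinMemoryCache`.

    A minimal usage sketch (the variable names and the `.to(...)` copy are
    illustrative, not part of this module)::

        dataloader, batch_sampler = prepare_dataloader(..., cache_pin_memory=True)
        it = iter(dataloader)
        batch = next(it)
        video = batch["video"].to("cuda", non_blocking=True)
        torch.cuda.synchronize()
        it.remove_cache(batch["video"])  # hand the pinned buffer back to the cache
    """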

    def _get_iterator(self) -> "_BaseDataLoaderIter":
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIterForVideo(self)


# Deterministic dataloader
def get_seed_worker(seed):
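    """Build a `worker_init_fn` that seeds `numpy`, `torch`, and `random`.

    Note that every worker receives the same seed (rather than, say,
    `seed + worker_id`), so all workers draw identical random sequences;
    this trades per-worker randomness diversity for strict reproducibility.
    """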
    def seed_worker(worker_id):
        worker_seed = seed
        if seed is not None:
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

    return seed_worker


def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: ProcessGroup | None = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    prefetch_factor=None,
    cache_pin_memory=False,
    num_groups=1,
    **kwargs,
):
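    """Build a `(dataloader, sampler)` pair for `dataset`.

    - `VideoTextDataset`: batching is delegated to `VariableVideoBatchSampler`
      driven by `bucket_config`; the bucket sampler is returned alongside the
      loader. With `cache_pin_memory=True` the loader is `DataloaderForVideo`.
    - `TextDataset`: a plain `DataLoader` (with `None` as the sampler), or a
      `DistributedSampler`-backed one when `process_group` is given.
    Any other dataset type raises `ValueError`.
    """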
    _kwargs = kwargs.copy()
    if isinstance(dataset, VideoTextDataset):
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
            num_groups=num_groups,
        )
        dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
        return (
            dl_cls(
                dataset,
                batch_sampler=batch_sampler,
                worker_init_fn=get_seed_worker(seed),
                pin_memory=pin_memory,
                num_workers=num_workers,
                collate_fn=collate_fn_default,
                prefetch_factor=prefetch_factor,
                **_kwargs,
            ),
            batch_sampler,
        )
    elif isinstance(dataset, TextDataset):
        if process_group is None:
            return (
                DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    worker_init_fn=get_seed_worker(seed),
                    drop_last=drop_last,
                    pin_memory=pin_memory,
                    num_workers=num_workers,
                    prefetch_factor=prefetch_factor,
                    **_kwargs,
                ),
                None,
            )
        else:
            sampler = DistributedSampler(
                dataset,
                num_replicas=process_group.size(),
                rank=process_group.rank(),
                shuffle=shuffle,
                seed=seed,
                drop_last=drop_last,
            )
            return (
                DataLoader(
                    dataset,
                    sampler=sampler,
                    worker_init_fn=get_seed_worker(seed),
                    pin_memory=pin_memory,
                    num_workers=num_workers,
                    collate_fn=collate_fn_default,
                    prefetch_factor=prefetch_factor,
                    **_kwargs,
                ),
                sampler,
            )
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")


def collate_fn_default(batch):
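    """`default_collate` with two adjustments: `None` samples are dropped, and
    pre-encoded text features (signalled by an integer "mask" field) are
    concatenated along dim 1 rather than stacked, with masks kept as a list.
    """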
    # filter out None
    batch = [x for x in batch if x is not None]
    assert len(batch) > 0, "batch is empty"

    # HACK: for loading text features
    use_mask = False
    if "mask" in batch[0] and isinstance(batch[0]["mask"], int):
        masks = [x.pop("mask") for x in batch]
        texts = [x.pop("text") for x in batch]
        texts = torch.cat(texts, dim=1)
        use_mask = True
    ret = torch.utils.data.default_collate(batch)

    if use_mask:
        ret["mask"] = masks
        ret["text"] = texts
    return ret


def collate_fn_batch(batch):
    """
    Used only with BatchDistributedSampler
    """
    # filter out None
    batch = [x for x in batch if x is not None]

    res = torch.utils.data.default_collate(batch)

    # squeeze the first dimension, which is due to torch.stack() in default_collate()
    if isinstance(res, collections.abc.Mapping):
        for k, v in res.items():
            if isinstance(v, torch.Tensor):
                res[k] = v.squeeze(0)
    elif isinstance(res, collections.abc.Sequence):
        res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
    elif isinstance(res, torch.Tensor):
        res = res.squeeze(0)
    else:
        raise TypeError(f"Unsupported batch type: {type(res)}")
    return res


if __name__ == "__main__":
    # Estimate per-epoch step counts and training time for a config.
    # Environment variables:
    #   NUM_MACHINES: number of 8-GPU machines used for training (default 28)
    #   TIME_PER_STEP: time per step in seconds (default 20)
    # Example usage:
    #   torchrun --nproc_per_node 1 -m opensora.datasets.dataloader configs/diffusion/train/video_cond.py
    cfg = parse_configs()
    setup_device()
    logger = create_logger()

    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )

    num_steps_per_epoch = len(dataloader)
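    # Back-of-envelope estimate: len(dataloader) is the step count for a single
    # rank (this script runs with one process), so dividing by the total GPU
    # count approximates per-GPU steps under data parallelism, assuming steps
    # divide evenly across GPUs.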
    num_machines = int(os.environ.get("NUM_MACHINES", 28))
    num_gpu = num_machines * 8
    logger.info("Number of GPUs: %d", num_gpu)
    logger.info("Number of steps per epoch: %d", num_steps_per_epoch // num_gpu)

    time_per_step = int(os.environ.get("TIME_PER_STEP", 20))
    time_training = num_steps_per_epoch // num_gpu * time_per_step
    logger.info("Time per step: %s", format_duration(time_per_step))
    logger.info("Time for training: %s", format_duration(time_training))