import collections
import functools
import os
import queue
import random
import threading

import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader, _utils
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import (
    IterDataPipe,
    MapDataPipe,
    _BaseDataLoaderIter,
    _MultiProcessingDataLoaderIter,
    _sharding_worker_init_fn,
    _SingleProcessDataLoaderIter,
)

from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.registry import DATASETS, build_module
from opensora.utils.config import parse_configs
from opensora.utils.logger import create_logger
from opensora.utils.misc import format_duration
from opensora.utils.train import setup_device

from .datasets import TextDataset, VideoTextDataset
from .pin_memory_cache import PinMemoryCache
from .sampler import DistributedSampler, VariableVideoBatchSampler


def _pin_memory_loop(
    in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                assert isinstance(data, dict)
                if pin_memory_key in data:
                    val = data[pin_memory_key]
                    pin_memory_value = pin_memory_cache.get(val)
                    pin_memory_value.copy_(val)
                    data[pin_memory_key] = pin_memory_value
            except Exception:
                data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
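    # Unlike torch's stock pin-memory loop, which pins every tensor in the
    # sample, this loop only copies data[pin_memory_key] into a pre-pinned
    # buffer obtained from the shared PinMemoryCache, so large video tensors
    # reuse pinned host memory instead of allocating it for every batch.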
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()


class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
    pin_memory_key: str = "video"

    def __init__(self, loader):
        _BaseDataLoaderIter.__init__(self, loader)
        self.pin_memory_cache = PinMemoryCache()

        self._prefetch_factor = loader.prefetch_factor

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Additional worker init function will take care of sharding in MP and Distributed
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            self._worker_init_fn = functools.partial(
                _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
            )

        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(
                    self._dataset_kind,
                    self._dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self._auto_collation,
                    self._collate_fn,
                    self._drop_last,
                    self._base_seed,
                    self._worker_init_fn,
                    i,
                    self._num_workers,
                    self._persistent_workers,
                    self._shared_seed,
                ),
            )
            w.daemon = True
            # NB: Process.start() actually takes some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
            if self._pin_memory_device == "xpu":
                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
                custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
                current_device = custom_device_mod.current_device()
            else:
                current_device = torch.cuda.current_device()  # choose cuda for default
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(
                    self._worker_result_queue,
                    self._data_queue,
                    current_device,
                    self._pin_memory_thread_done_event,
                    self._pin_memory_device,
                    self.pin_memory_cache,
                    self.pin_memory_key,
                ),
            )
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
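            # NOTE: the thread targets the module-level _pin_memory_loop defined
            # in this file rather than torch's default pin-memory loop, so pinned
            # buffers are served from self.pin_memory_cache and can later be
            # released through remove_cache().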
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue  # type: ignore[assignment]

        # In some rare cases, persistent workers (daemonic processes)
        # would be terminated before `__del__` of iterator is invoked
        # when main process exits
        # It would cause failure when pin_memory_thread tries to read
        # corrupted data from worker_result_queue
        # atexit is used to shutdown thread and child processes in the
        # right sequence before main process exits
        if self._persistent_workers and self._pin_memory:
            import atexit

            for w in self._workers:
                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def remove_cache(self, output_tensor: torch.Tensor):
        self.pin_memory_cache.remove(output_tensor)

    def get_cache_info(self) -> str:
        return str(self.pin_memory_cache)


class DataloaderForVideo(DataLoader):
    def _get_iterator(self) -> "_BaseDataLoaderIter":
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIterForVideo(self)


# Deterministic dataloader
def get_seed_worker(seed):
    def seed_worker(worker_id):
        worker_seed = seed
        if seed is not None:
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

    return seed_worker


def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: ProcessGroup | None = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    prefetch_factor=None,
    cache_pin_memory=False,
    num_groups=1,
    **kwargs,
):
    _kwargs = kwargs.copy()
    if isinstance(dataset, VideoTextDataset):
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
            num_groups=num_groups,
        )
        dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
        return (
            dl_cls(
                dataset,
                batch_sampler=batch_sampler,
                worker_init_fn=get_seed_worker(seed),
                pin_memory=pin_memory,
                num_workers=num_workers,
                collate_fn=collate_fn_default,
                prefetch_factor=prefetch_factor,
                **_kwargs,
            ),
            batch_sampler,
        )
    elif isinstance(dataset, TextDataset):
        if process_group is None:
            return (
                DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    worker_init_fn=get_seed_worker(seed),
                    drop_last=drop_last,
                    pin_memory=pin_memory,
                    num_workers=num_workers,
                    prefetch_factor=prefetch_factor,
                    **_kwargs,
                ),
                None,
            )
        else:
            sampler = DistributedSampler(
                dataset,
                num_replicas=process_group.size(),
                rank=process_group.rank(),
                shuffle=shuffle,
                seed=seed,
                drop_last=drop_last,
            )
            return (
                DataLoader(
                    dataset,
                    sampler=sampler,
                    worker_init_fn=get_seed_worker(seed),
                    pin_memory=pin_memory,
                    num_workers=num_workers,
                    collate_fn=collate_fn_default,
                    prefetch_factor=prefetch_factor,
                    **_kwargs,
                ),
                sampler,
            )
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
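
# Example (sketch, not executed): typical use of prepare_dataloader for a
# VideoTextDataset under an initialized data-parallel process group. Names such
# as `dataset` and `bucket_config` are placeholders for objects built from a
# training config.
#
#   dataloader, batch_sampler = prepare_dataloader(
#       dataset,
#       bucket_config=bucket_config,
#       shuffle=True,
#       drop_last=True,
#       pin_memory=True,
#       num_workers=4,
#       process_group=get_data_parallel_group(),
#       cache_pin_memory=True,  # route through DataloaderForVideo + PinMemoryCache
#   )
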
def collate_fn_default(batch):
    # filter out None
    batch = [x for x in batch if x is not None]
    assert len(batch) > 0, "batch is empty"

    # HACK: for loading text features
    use_mask = False
    if "mask" in batch[0] and isinstance(batch[0]["mask"], int):
        masks = [x.pop("mask") for x in batch]
        texts = [x.pop("text") for x in batch]
        texts = torch.cat(texts, dim=1)
        use_mask = True

    ret = torch.utils.data.default_collate(batch)

    if use_mask:
        ret["mask"] = masks
        ret["text"] = texts
    return ret


def collate_fn_batch(batch):
    """
    Used only with BatchDistributedSampler
    """
    # filter out None
    batch = [x for x in batch if x is not None]
    res = torch.utils.data.default_collate(batch)

    # squeeze the first dimension, which is due to torch.stack() in default_collate()
    if isinstance(res, collections.abc.Mapping):
        for k, v in res.items():
            if isinstance(v, torch.Tensor):
                res[k] = v.squeeze(0)
    elif isinstance(res, collections.abc.Sequence):
        res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
    elif isinstance(res, torch.Tensor):
        res = res.squeeze(0)
    else:
        raise TypeError(f"Unsupported batch type for collate_fn_batch: {type(res)}")
    return res


if __name__ == "__main__":
    # NUM_MACHINES: number of machines used for training (8 GPUs each)
    # TIME_PER_STEP: time per step in seconds
    # Example usage:
    # torchrun --nproc_per_node 1 -m opensora.datasets.dataloader configs/diffusion/train/video_cond.py
    cfg = parse_configs()
    setup_device()
    logger = create_logger()

    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)

    num_machines = int(os.environ.get("NUM_MACHINES", 28))
    num_gpu = num_machines * 8
    logger.info("Number of GPUs: %d", num_gpu)
    logger.info("Number of steps per epoch: %d", num_steps_per_epoch // num_gpu)
    time_per_step = int(os.environ.get("TIME_PER_STEP", 20))
    time_training = num_steps_per_epoch // num_gpu * time_per_step
    logger.info("Time per step: %s", format_duration(time_per_step))
    logger.info("Time for training: %s", format_duration(time_training))
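
# Example (sketch): consuming a dataloader built with cache_pin_memory=True and
# num_workers > 0. iter(dataloader) then returns a
# _MultiProcessingDataLoaderIterForVideo, and remove_cache() releases the cached
# pinned buffer backing batch["video"] once it has been copied to the device.
# Variable names below are illustrative only.
#
#   it = iter(dataloader)
#   batch = next(it)
#   video = batch["video"].to("cuda", non_blocking=True)
#   it.remove_cache(batch["video"])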