403 lines
14 KiB
Python
403 lines
14 KiB
Python
import collections
|
|
import functools
|
|
import os
|
|
import queue
|
|
import random
|
|
import threading
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.multiprocessing as multiprocessing
|
|
from torch._utils import ExceptionWrapper
|
|
from torch.distributed import ProcessGroup
|
|
from torch.utils.data import DataLoader, _utils
|
|
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
|
from torch.utils.data.dataloader import (
|
|
IterDataPipe,
|
|
MapDataPipe,
|
|
_BaseDataLoaderIter,
|
|
_MultiProcessingDataLoaderIter,
|
|
_sharding_worker_init_fn,
|
|
_SingleProcessDataLoaderIter,
|
|
)
|
|
|
|
from opensora.acceleration.parallel_states import get_data_parallel_group
|
|
from opensora.registry import DATASETS, build_module
|
|
from opensora.utils.config import parse_configs
|
|
from opensora.utils.logger import create_logger
|
|
from opensora.utils.misc import format_duration
|
|
from opensora.utils.train import setup_device
|
|
|
|
from .datasets import TextDataset, VideoTextDataset
|
|
from .pin_memory_cache import PinMemoryCache
|
|
from .sampler import DistributedSampler, VariableVideoBatchSampler
|
|
|
|
|
|
def _pin_memory_loop(
    in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
):
    """Thread target that forwards worker results while pinning one batch entry.

    Adaptation of ``torch.utils.data._utils.pin_memory._pin_memory_loop``:
    instead of pinning every tensor in the batch, only ``data[pin_memory_key]``
    is copied into a reusable pinned buffer obtained from ``pin_memory_cache``.

    Args:
        in_queue: queue fed by the worker processes with ``(idx, data)`` tuples.
        out_queue: queue consumed by the main process.
        device_id: device index to bind this thread to.
        done_event: event signalling this loop to exit.
        device: device type string ("cuda", "xpu", or the private backend name).
        pin_memory_cache: cache handing out reusable pinned host tensors.
        pin_memory_key: key in the batch dict whose tensor value gets pinned.
    """
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        # Move at most one item from in_queue to out_queue, pinning on the way.
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                # Batches are expected to be dicts (see collate_fn_default).
                assert isinstance(data, dict)
                if pin_memory_key in data:
                    val = data[pin_memory_key]
                    # Reuse a cached pinned buffer and copy into it, avoiding a
                    # fresh pinned allocation for every batch.
                    pin_memory_value = pin_memory_cache.get(val)
                    pin_memory_value.copy_(val)
                    data[pin_memory_key] = pin_memory_value
            except Exception:
                # Ship the failure to the main process instead of crashing the thread.
                data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
            r = (idx, data)
        # Retry the put with a timeout so a shutdown (done_event) is noticed
        # even if out_queue stays full.
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()
|
|
|
|
|
|
class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
    """Multi-worker dataloader iterator that pins batches through a reusable cache.

    Reproduces ``_MultiProcessingDataLoaderIter.__init__`` almost verbatim with
    one change: the pin-memory thread runs this module's ``_pin_memory_loop``,
    which copies only ``data[pin_memory_key]`` into a pinned buffer taken from a
    ``PinMemoryCache`` instead of pinning the whole batch. Consumers should hand
    buffers back via :meth:`remove_cache` once done so they can be reused.

    NOTE(review): this mirrors private PyTorch internals and must be kept in
    sync with the installed torch version's ``_MultiProcessingDataLoaderIter``.
    """

    # Key in each batch dict whose tensor value is copied into pinned memory.
    pin_memory_key: str = "video"

    def __init__(self, loader):
        """Spawn worker processes and (optionally) the custom pin-memory thread.

        Args:
            loader: the owning ``DataLoader``; must have ``num_workers > 0`` and
                a positive ``prefetch_factor``.
        """
        # Deliberately bypass _MultiProcessingDataLoaderIter.__init__ so the
        # stock pin-memory thread is replaced with our caching one below.
        _BaseDataLoaderIter.__init__(self, loader)
        self.pin_memory_cache = PinMemoryCache()

        self._prefetch_factor = loader.prefetch_factor

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        # Additional worker init function will take care of sharding in MP and Distributed
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            self._worker_init_fn = functools.partial(
                _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
            )

        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(
                    self._dataset_kind,
                    self._dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self._auto_collation,
                    self._collate_fn,
                    self._drop_last,
                    self._base_seed,
                    self._worker_init_fn,
                    i,
                    self._num_workers,
                    self._persistent_workers,
                    self._shared_seed,
                ),
            )
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            # start a process and pass the arguments over via a pipe.
            # Therefore, we only add a worker to self._workers list after
            # it started, so that we do not call .join() if program dies
            # before it starts, and __del__ tries to join but will get:
            # AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
            if self._pin_memory_device == "xpu":
                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
                custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
                current_device = custom_device_mod.current_device()
            else:
                current_device = torch.cuda.current_device()  # choose cuda for default
            # The only functional difference from upstream: use this module's
            # _pin_memory_loop, parameterized with the shared cache and key.
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(
                    self._worker_result_queue,
                    self._data_queue,
                    current_device,
                    self._pin_memory_thread_done_event,
                    self._pin_memory_device,
                    self.pin_memory_cache,
                    self.pin_memory_key,
                ),
            )
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue  # type: ignore[assignment]

        # In some rare cases, persistent workers (daemonic processes)
        # would be terminated before `__del__` of iterator is invoked
        # when main process exits
        # It would cause failure when pin_memory_thread tries to read
        # corrupted data from worker_result_queue
        # atexit is used to shutdown thread and child processes in the
        # right sequence before main process exits
        if self._persistent_workers and self._pin_memory:
            import atexit

            for w in self._workers:
                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def remove_cache(self, output_tensor: torch.Tensor):
        """Return *output_tensor*'s pinned buffer to the cache for reuse."""
        self.pin_memory_cache.remove(output_tensor)

    def get_cache_info(self) -> str:
        """Return a human-readable summary of the pin-memory cache state."""
        return str(self.pin_memory_cache)
|
|
|
|
|
|
class DataloaderForVideo(DataLoader):
|
|
def _get_iterator(self) -> "_BaseDataLoaderIter":
|
|
if self.num_workers == 0:
|
|
return _SingleProcessDataLoaderIter(self)
|
|
else:
|
|
self.check_worker_number_rationality()
|
|
return _MultiProcessingDataLoaderIterForVideo(self)
|
|
|
|
|
|
# Deterministic dataloader
|
|
def get_seed_worker(seed):
    """Build a ``worker_init_fn`` that seeds numpy, torch, and random.

    When *seed* is ``None`` the returned function is a no-op.

    NOTE(review): ``worker_id`` is intentionally ignored, so every worker gets
    the identical RNG state — confirm this is the desired "deterministic"
    behavior rather than ``seed + worker_id``.
    """

    def seed_worker(worker_id):
        if seed is None:
            return
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    return seed_worker
|
|
|
|
|
|
def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: ProcessGroup | None = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    prefetch_factor=None,
    cache_pin_memory=False,
    num_groups=1,
    **kwargs,
):
    """Construct the dataloader (and its sampler) appropriate for *dataset*.

    Dispatch:
      * ``VideoTextDataset`` — bucketed ``VariableVideoBatchSampler``; returns
        ``(loader, batch_sampler)``. Requires *process_group* to be set.
      * ``TextDataset`` without *process_group* — plain loader; returns
        ``(loader, None)``.
      * ``TextDataset`` with *process_group* — ``DistributedSampler``; returns
        ``(loader, sampler)``.

    Extra keyword arguments are forwarded to the dataloader constructor.

    Raises:
        ValueError: for any other dataset type.
    """
    extra_kwargs = kwargs.copy()

    if isinstance(dataset, VideoTextDataset):
        # Batch sizes are decided per bucket by the sampler, not batch_size.
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
            num_groups=num_groups,
        )
        loader_cls = DataloaderForVideo if cache_pin_memory else DataLoader
        loader = loader_cls(
            dataset,
            batch_sampler=batch_sampler,
            worker_init_fn=get_seed_worker(seed),
            pin_memory=pin_memory,
            num_workers=num_workers,
            collate_fn=collate_fn_default,
            prefetch_factor=prefetch_factor,
            **extra_kwargs,
        )
        return loader, batch_sampler

    if isinstance(dataset, TextDataset):
        if process_group is None:
            # Non-distributed path: let the DataLoader shuffle on its own.
            loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                worker_init_fn=get_seed_worker(seed),
                drop_last=drop_last,
                pin_memory=pin_memory,
                num_workers=num_workers,
                prefetch_factor=prefetch_factor,
                **extra_kwargs,
            )
            return loader, None

        sampler = DistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        loader = DataLoader(
            dataset,
            sampler=sampler,
            worker_init_fn=get_seed_worker(seed),
            pin_memory=pin_memory,
            num_workers=num_workers,
            collate_fn=collate_fn_default,
            prefetch_factor=prefetch_factor,
            **extra_kwargs,
        )
        return loader, sampler

    raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
|
|
|
|
|
def collate_fn_default(batch):
    """Collate a list of sample dicts, dropping failed (``None``) samples.

    Also supports pre-computed text features: when ``mask`` is stored as a
    plain int, ``mask``/``text`` are popped out of each sample before
    ``default_collate`` (whose stacking would fail on variable-length text
    features) and re-attached afterwards — ``mask`` as a list of ints, and
    ``text`` as the per-sample tensors concatenated along dim 1.

    Args:
        batch: list of per-sample dicts (or ``None`` for failed loads).

    Returns:
        A dict mapping keys to collated values.

    Raises:
        ValueError: if every sample in the batch is ``None``.
    """
    # filter out None (samples that failed to load)
    batch = [x for x in batch if x is not None]
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let an empty batch fall through to default_collate.
    if not batch:
        raise ValueError("batch is empty")

    # HACK: for loading text features
    use_mask = False
    if "mask" in batch[0] and isinstance(batch[0]["mask"], int):
        masks = [x.pop("mask") for x in batch]

        texts = [x.pop("text") for x in batch]
        # Concatenate along the token dimension; per-sample token counts may differ.
        texts = torch.cat(texts, dim=1)
        use_mask = True

    ret = torch.utils.data.default_collate(batch)

    if use_mask:
        ret["mask"] = masks
        ret["text"] = texts
    return ret
|
|
|
|
|
|
def collate_fn_batch(batch):
    """
    Collate for batches produced by BatchDistributedSampler, where each
    "sample" is already a full batch: drop failed (``None``) samples, run
    ``default_collate``, then squeeze out the extra leading dimension that
    ``torch.stack()`` inside ``default_collate`` adds.

    Args:
        batch: list of pre-batched samples (or ``None`` for failed loads).

    Returns:
        The collated structure with the spurious leading dim removed from
        every tensor.

    Raises:
        TypeError: if the collated result is not a mapping, sequence, or tensor.
    """
    # filter out None
    batch = [x for x in batch if x is not None]

    res = torch.utils.data.default_collate(batch)

    # squeeze the first dimension, which is due to torch.stack() in default_collate()
    if isinstance(res, collections.abc.Mapping):
        for k, v in res.items():
            if isinstance(v, torch.Tensor):
                res[k] = v.squeeze(0)
    elif isinstance(res, collections.abc.Sequence):
        res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
    elif isinstance(res, torch.Tensor):
        res = res.squeeze(0)
    else:
        # Bare `raise TypeError` gave no clue which type broke the contract.
        raise TypeError(f"Unsupported collate result type: {type(res)}")

    return res
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone utility: build the dataset + dataloader from a config file and
    # estimate steps per epoch and total training time.
    #
    # Environment variables:
    #   NUM_MACHINES: number of machines, 8 GPUs each assumed (default 28)
    #   TIME_PER_STEP: time per step in seconds (default 20)
    #
    # Example usage:
    # torchrun --nproc_per_node 1 -m opensora.datasets.dataloader configs/diffusion/train/video_cond.py
    cfg = parse_configs()
    setup_device()
    logger = create_logger()

    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)
    # NOTE(review): the estimate divides len(dataloader) by the total GPU count,
    # i.e. it assumes len(dataloader) here counts global (all-rank) steps under
    # a single-process run — confirm against the sampler's len() semantics.
    num_machines = int(os.environ.get("NUM_MACHINES", 28))
    num_gpu = num_machines * 8
    logger.info("Number of GPUs: %d", num_gpu)
    logger.info("Number of steps per epoch: %d", num_steps_per_epoch // num_gpu)
    time_per_step = int(os.environ.get("TIME_PER_STEP", 20))
    time_training = num_steps_per_epoch // num_gpu * time_per_step
    logger.info("Time per step: %s", format_duration(time_per_step))
    logger.info("Time for training: %s", format_duration(time_training))
|