feat: add opensora/datasets module and tools/datasets
- Add opensora/datasets (aspect, bucket, dataloader, datasets, parallel, pin_memory_cache, read_video, sampler, utils, video_transforms) - Add tools/datasets pipeline scripts - Fix .gitignore: scope /datasets to root-level only, whitelist opensora/datasets/ Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
916ee2126d
commit
bdeb2870d4
|
|
@ -195,4 +195,5 @@ package.json
|
|||
exps
|
||||
ckpts
|
||||
flash-attention
|
||||
datasets
|
||||
/datasets
|
||||
!opensora/datasets/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
from .datasets import TextDataset, VideoTextDataset
|
||||
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
import math
|
||||
import os
|
||||
|
||||
ASPECT_RATIO_LD_LIST = [ # width:height
|
||||
"2.39:1", # cinemascope, 2.39
|
||||
"2:1", # rare, 2
|
||||
"16:9", # rare, 1.89
|
||||
"1.85:1", # american widescreen, 1.85
|
||||
"9:16", # popular, 1.78
|
||||
"5:8", # rare, 1.6
|
||||
"3:2", # rare, 1.5
|
||||
"4:3", # classic, 1.33
|
||||
"1:1", # square
|
||||
]
|
||||
|
||||
|
||||
def get_ratio(name: str) -> float:
|
||||
width, height = map(float, name.split(":"))
|
||||
return height / width
|
||||
|
||||
|
||||
def get_aspect_ratios_dict(
|
||||
total_pixels: int = 256 * 256, training: bool = True
|
||||
) -> dict[str, tuple[int, int]]:
|
||||
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
|
||||
aspect_ratios_dict = {}
|
||||
aspect_ratios_vertical_dict = {}
|
||||
for ratio in ASPECT_RATIO_LD_LIST:
|
||||
width_ratio, height_ratio = map(float, ratio.split(":"))
|
||||
width = int(math.sqrt(total_pixels * (width_ratio / height_ratio)) // D) * D
|
||||
height = int((total_pixels / width) // D) * D
|
||||
|
||||
if training:
|
||||
# adjust aspect ratio to match total pixels
|
||||
diff = abs(height * width - total_pixels)
|
||||
candidate = [
|
||||
(height - D, width),
|
||||
(height + D, width),
|
||||
(height, width - D),
|
||||
(height, width + D),
|
||||
]
|
||||
for h, w in candidate:
|
||||
if abs(h * w - total_pixels) < diff:
|
||||
height, width = h, w
|
||||
diff = abs(h * w - total_pixels)
|
||||
|
||||
# remove duplicated aspect ratio
|
||||
if (height, width) not in aspect_ratios_dict.values() or not training:
|
||||
aspect_ratios_dict[ratio] = (height, width)
|
||||
vertial_ratios = ":".join(ratio.split(":")[::-1])
|
||||
aspect_ratios_vertical_dict[vertial_ratios] = (width, height)
|
||||
|
||||
aspect_ratios_dict.update(aspect_ratios_vertical_dict)
|
||||
|
||||
return aspect_ratios_dict
|
||||
|
||||
|
||||
def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
|
||||
return {ratio: h * w for ratio, (h, w) in aspect_ratios_dict.items()}
|
||||
|
||||
|
||||
def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
|
||||
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
|
||||
return {ratio: h * w // D // D for ratio, (h, w) in aspect_ratios_dict.items()}
|
||||
|
||||
|
||||
def get_num_pexels_from_name(resolution: str) -> int:
|
||||
resolution = resolution.split("_")[0]
|
||||
if resolution.endswith("px"):
|
||||
size = int(resolution[:-2])
|
||||
num_pexels = size * size
|
||||
elif resolution.endswith("p"):
|
||||
size = int(resolution[:-1])
|
||||
num_pexels = int(size * size / 9 * 16)
|
||||
else:
|
||||
raise ValueError(f"Invalid resolution {resolution}")
|
||||
return num_pexels
|
||||
|
||||
|
||||
def get_resolution_with_aspect_ratio(
|
||||
resolution: str,
|
||||
) -> tuple[int, dict[str, tuple[int, int]]]:
|
||||
"""Get resolution with aspect ratio
|
||||
|
||||
Args:
|
||||
resolution (str): resolution name. The format is name only or "{name}_{setting}".
|
||||
name supports "256px" or "360p". setting supports "ar1:1" or "max".
|
||||
|
||||
Returns:
|
||||
tuple[int, dict[str, tuple[int, int]]]: resolution with aspect ratio
|
||||
"""
|
||||
keys = resolution.split("_")
|
||||
if len(keys) == 1:
|
||||
resolution = keys[0]
|
||||
setting = ""
|
||||
else:
|
||||
resolution, setting = keys
|
||||
assert setting == "max" or setting.startswith(
|
||||
"ar"
|
||||
), f"Invalid setting {setting}"
|
||||
|
||||
# get resolution
|
||||
num_pexels = get_num_pexels_from_name(resolution)
|
||||
|
||||
# get aspect ratio
|
||||
aspect_ratio_dict = get_aspect_ratios_dict(num_pexels)
|
||||
|
||||
# handle setting
|
||||
if setting == "max":
|
||||
aspect_ratio = max(
|
||||
aspect_ratio_dict,
|
||||
key=lambda x: aspect_ratio_dict[x][0] * aspect_ratio_dict[x][1],
|
||||
)
|
||||
aspect_ratio_dict = {aspect_ratio: aspect_ratio_dict[aspect_ratio]}
|
||||
elif setting.startswith("ar"):
|
||||
aspect_ratio = setting[2:]
|
||||
assert (
|
||||
aspect_ratio in aspect_ratio_dict
|
||||
), f"Aspect ratio {aspect_ratio} not found"
|
||||
aspect_ratio_dict = {aspect_ratio: aspect_ratio_dict[aspect_ratio]}
|
||||
|
||||
return num_pexels, aspect_ratio_dict
|
||||
|
||||
|
||||
def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
|
||||
aspect_ratio = height / width
|
||||
closest_ratio = min(
|
||||
ratios.keys(), key=lambda ratio: abs(aspect_ratio - get_ratio(ratio))
|
||||
)
|
||||
return closest_ratio
|
||||
|
||||
|
||||
def get_image_size(
|
||||
resolution: str, ar_ratio: str, training: bool = True
|
||||
) -> tuple[int, int]:
|
||||
num_pexels = get_num_pexels_from_name(resolution)
|
||||
ar_dict = get_aspect_ratios_dict(num_pexels, training)
|
||||
assert ar_ratio in ar_dict, f"Aspect ratio {ar_ratio} not found"
|
||||
return ar_dict[ar_ratio]
|
||||
|
||||
|
||||
def bucket_to_shapes(bucket_config, batch_size=None):
|
||||
shapes = []
|
||||
for resolution, infos in bucket_config.items():
|
||||
for num_frames, (_, bs) in infos.items():
|
||||
aspect_ratios = get_aspect_ratios_dict(get_num_pexels_from_name(resolution))
|
||||
for ar, (height, width) in aspect_ratios.items():
|
||||
if batch_size is not None:
|
||||
bs = batch_size
|
||||
shapes.append((bs, 3, num_frames, height, width))
|
||||
return shapes
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from opensora.utils.logger import log_message
|
||||
|
||||
from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio
|
||||
from .utils import map_target_fps
|
||||
|
||||
|
||||
class Bucket:
|
||||
def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]):
|
||||
"""
|
||||
Args:
|
||||
bucket_config (dict): A dictionary containing the bucket configuration.
|
||||
The dictionary should be in the following format:
|
||||
{
|
||||
"bucket_name": {
|
||||
"time": (probability, batch_size),
|
||||
"time": (probability, batch_size),
|
||||
...
|
||||
},
|
||||
...
|
||||
}
|
||||
|
||||
Or in the following format:
|
||||
{
|
||||
"bucket_name": {
|
||||
"time": ((probability, next_probability), batch_size),
|
||||
"time": ((probability, next_probability), batch_size),
|
||||
...
|
||||
},
|
||||
...
|
||||
}
|
||||
The bucket_name should be the name of the bucket, and the time should be the number of frames in the video.
|
||||
The probability should be a float between 0 and 1, and the batch_size should be an integer.
|
||||
If the probability is a tuple, the second value should be the probability to skip to the next time.
|
||||
"""
|
||||
|
||||
aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()}
|
||||
bucket_probs = OrderedDict()
|
||||
bucket_bs = OrderedDict()
|
||||
bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True)
|
||||
|
||||
for key in bucket_names:
|
||||
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
|
||||
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
|
||||
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
|
||||
|
||||
self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names}
|
||||
self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names}
|
||||
self.ar_criteria = {
|
||||
k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()}
|
||||
for k1 in bucket_names
|
||||
}
|
||||
|
||||
bucket_id_cnt = num_bucket = 0
|
||||
bucket_id = dict()
|
||||
for k1, v1 in bucket_probs.items():
|
||||
bucket_id[k1] = dict()
|
||||
for k2, _ in v1.items():
|
||||
bucket_id[k1][k2] = bucket_id_cnt
|
||||
bucket_id_cnt += 1
|
||||
num_bucket += len(aspect_ratios[k1][1])
|
||||
|
||||
self.bucket_probs = bucket_probs
|
||||
self.bucket_bs = bucket_bs
|
||||
self.bucket_id = bucket_id
|
||||
self.num_bucket = num_bucket
|
||||
|
||||
log_message("Number of buckets: %s", num_bucket)
|
||||
|
||||
def get_bucket_id(
|
||||
self,
|
||||
T: int,
|
||||
H: int,
|
||||
W: int,
|
||||
fps: float,
|
||||
path: str | None = None,
|
||||
seed: int | None = None,
|
||||
fps_max: int = 16,
|
||||
) -> tuple[str, int, int] | None:
|
||||
approx = 0.8
|
||||
_, sampling_interval = map_target_fps(fps, fps_max)
|
||||
T = T // sampling_interval
|
||||
resolution = H * W
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
# Reference to probabilities and criteria for faster access
|
||||
bucket_probs = self.bucket_probs
|
||||
hw_criteria = self.hw_criteria
|
||||
ar_criteria = self.ar_criteria
|
||||
|
||||
# Start searching for the appropriate bucket
|
||||
for hw_id, t_criteria in bucket_probs.items():
|
||||
# if resolution is too low, skip
|
||||
if resolution < hw_criteria[hw_id] * approx:
|
||||
continue
|
||||
|
||||
# if sample is an image
|
||||
if T == 1:
|
||||
if 1 in t_criteria:
|
||||
if rng.random() < t_criteria[1]:
|
||||
return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1])
|
||||
continue
|
||||
|
||||
# Look for suitable t_id for video
|
||||
for t_id, prob in t_criteria.items():
|
||||
if T >= t_id and t_id != 1:
|
||||
# if prob is a tuple, use the second value as the threshold to skip
|
||||
# to the next t_id
|
||||
if isinstance(prob, tuple):
|
||||
next_hw_prob, next_t_prob = prob
|
||||
if next_t_prob >= 1 or rng.random() <= next_t_prob:
|
||||
continue
|
||||
else:
|
||||
next_hw_prob = prob
|
||||
if next_hw_prob >= 1 or rng.random() <= next_hw_prob:
|
||||
ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id])
|
||||
return hw_id, t_id, ar_id
|
||||
else:
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
def get_thw(self, bucket_idx: tuple[str, int, int]) -> tuple[int, int, int]:
|
||||
assert len(bucket_idx) == 3
|
||||
T = self.t_criteria[bucket_idx[0]][bucket_idx[1]]
|
||||
H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]]
|
||||
return T, H, W
|
||||
|
||||
def get_prob(self, bucket_idx: tuple[str, int]) -> float:
|
||||
return self.bucket_probs[bucket_idx[0]][bucket_idx[1]]
|
||||
|
||||
def get_batch_size(self, bucket_idx: tuple[str, int]) -> int:
|
||||
return self.bucket_bs[bucket_idx[0]][bucket_idx[1]]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_bucket
|
||||
|
|
@ -0,0 +1,402 @@
|
|||
import collections
|
||||
import functools
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as multiprocessing
|
||||
from torch._utils import ExceptionWrapper
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.utils.data import DataLoader, _utils
|
||||
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
||||
from torch.utils.data.dataloader import (
|
||||
IterDataPipe,
|
||||
MapDataPipe,
|
||||
_BaseDataLoaderIter,
|
||||
_MultiProcessingDataLoaderIter,
|
||||
_sharding_worker_init_fn,
|
||||
_SingleProcessDataLoaderIter,
|
||||
)
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.registry import DATASETS, build_module
|
||||
from opensora.utils.config import parse_configs
|
||||
from opensora.utils.logger import create_logger
|
||||
from opensora.utils.misc import format_duration
|
||||
from opensora.utils.train import setup_device
|
||||
|
||||
from .datasets import TextDataset, VideoTextDataset
|
||||
from .pin_memory_cache import PinMemoryCache
|
||||
from .sampler import DistributedSampler, VariableVideoBatchSampler
|
||||
|
||||
|
||||
def _pin_memory_loop(
|
||||
in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
|
||||
):
|
||||
# This setting is thread local, and prevents the copy in pin_memory from
|
||||
# consuming all CPU cores.
|
||||
torch.set_num_threads(1)
|
||||
|
||||
if device == "cuda":
|
||||
torch.cuda.set_device(device_id)
|
||||
elif device == "xpu":
|
||||
torch.xpu.set_device(device_id) # type: ignore[attr-defined]
|
||||
elif device == torch._C._get_privateuse1_backend_name():
|
||||
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
|
||||
custom_device_mod.set_device(device_id)
|
||||
|
||||
def do_one_step():
|
||||
try:
|
||||
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
||||
except queue.Empty:
|
||||
return
|
||||
idx, data = r
|
||||
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
|
||||
try:
|
||||
assert isinstance(data, dict)
|
||||
if pin_memory_key in data:
|
||||
val = data[pin_memory_key]
|
||||
pin_memory_value = pin_memory_cache.get(val)
|
||||
pin_memory_value.copy_(val)
|
||||
data[pin_memory_key] = pin_memory_value
|
||||
except Exception:
|
||||
data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
|
||||
r = (idx, data)
|
||||
while not done_event.is_set():
|
||||
try:
|
||||
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
|
||||
break
|
||||
except queue.Full:
|
||||
continue
|
||||
|
||||
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
|
||||
# logic of this function.
|
||||
while not done_event.is_set():
|
||||
# Make sure that we don't preserve any object from one iteration
|
||||
# to the next
|
||||
do_one_step()
|
||||
|
||||
|
||||
class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
|
||||
pin_memory_key: str = "video"
|
||||
|
||||
def __init__(self, loader):
|
||||
_BaseDataLoaderIter.__init__(self, loader)
|
||||
self.pin_memory_cache = PinMemoryCache()
|
||||
|
||||
self._prefetch_factor = loader.prefetch_factor
|
||||
|
||||
assert self._num_workers > 0
|
||||
assert self._prefetch_factor > 0
|
||||
|
||||
if loader.multiprocessing_context is None:
|
||||
multiprocessing_context = multiprocessing
|
||||
else:
|
||||
multiprocessing_context = loader.multiprocessing_context
|
||||
|
||||
self._worker_init_fn = loader.worker_init_fn
|
||||
|
||||
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
||||
# Additional worker init function will take care of sharding in MP and Distributed
|
||||
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
|
||||
self._worker_init_fn = functools.partial(
|
||||
_sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
|
||||
)
|
||||
|
||||
# No certainty which module multiprocessing_context is
|
||||
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
|
||||
self._worker_pids_set = False
|
||||
self._shutdown = False
|
||||
self._workers_done_event = multiprocessing_context.Event()
|
||||
|
||||
self._index_queues = []
|
||||
self._workers = []
|
||||
for i in range(self._num_workers):
|
||||
# No certainty which module multiprocessing_context is
|
||||
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
|
||||
# Need to `cancel_join_thread` here!
|
||||
# See sections (2) and (3b) above.
|
||||
index_queue.cancel_join_thread()
|
||||
w = multiprocessing_context.Process(
|
||||
target=_utils.worker._worker_loop,
|
||||
args=(
|
||||
self._dataset_kind,
|
||||
self._dataset,
|
||||
index_queue,
|
||||
self._worker_result_queue,
|
||||
self._workers_done_event,
|
||||
self._auto_collation,
|
||||
self._collate_fn,
|
||||
self._drop_last,
|
||||
self._base_seed,
|
||||
self._worker_init_fn,
|
||||
i,
|
||||
self._num_workers,
|
||||
self._persistent_workers,
|
||||
self._shared_seed,
|
||||
),
|
||||
)
|
||||
w.daemon = True
|
||||
# NB: Process.start() actually take some time as it needs to
|
||||
# start a process and pass the arguments over via a pipe.
|
||||
# Therefore, we only add a worker to self._workers list after
|
||||
# it started, so that we do not call .join() if program dies
|
||||
# before it starts, and __del__ tries to join but will get:
|
||||
# AssertionError: can only join a started process.
|
||||
w.start()
|
||||
self._index_queues.append(index_queue)
|
||||
self._workers.append(w)
|
||||
|
||||
if self._pin_memory:
|
||||
self._pin_memory_thread_done_event = threading.Event()
|
||||
|
||||
# Queue is not type-annotated
|
||||
self._data_queue = queue.Queue() # type: ignore[var-annotated]
|
||||
if self._pin_memory_device == "xpu":
|
||||
current_device = torch.xpu.current_device() # type: ignore[attr-defined]
|
||||
elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
|
||||
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
|
||||
current_device = custom_device_mod.current_device()
|
||||
else:
|
||||
current_device = torch.cuda.current_device() # choose cuda for default
|
||||
pin_memory_thread = threading.Thread(
|
||||
target=_pin_memory_loop,
|
||||
args=(
|
||||
self._worker_result_queue,
|
||||
self._data_queue,
|
||||
current_device,
|
||||
self._pin_memory_thread_done_event,
|
||||
self._pin_memory_device,
|
||||
self.pin_memory_cache,
|
||||
self.pin_memory_key,
|
||||
),
|
||||
)
|
||||
pin_memory_thread.daemon = True
|
||||
pin_memory_thread.start()
|
||||
# Similar to workers (see comment above), we only register
|
||||
# pin_memory_thread once it is started.
|
||||
self._pin_memory_thread = pin_memory_thread
|
||||
else:
|
||||
self._data_queue = self._worker_result_queue # type: ignore[assignment]
|
||||
|
||||
# In some rare cases, persistent workers (daemonic processes)
|
||||
# would be terminated before `__del__` of iterator is invoked
|
||||
# when main process exits
|
||||
# It would cause failure when pin_memory_thread tries to read
|
||||
# corrupted data from worker_result_queue
|
||||
# atexit is used to shutdown thread and child processes in the
|
||||
# right sequence before main process exits
|
||||
if self._persistent_workers and self._pin_memory:
|
||||
import atexit
|
||||
|
||||
for w in self._workers:
|
||||
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
|
||||
|
||||
# .pid can be None only before process is spawned (not the case, so ignore)
|
||||
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
|
||||
_utils.signal_handling._set_SIGCHLD_handler()
|
||||
self._worker_pids_set = True
|
||||
self._reset(loader, first_iter=True)
|
||||
|
||||
def remove_cache(self, output_tensor: torch.Tensor):
|
||||
self.pin_memory_cache.remove(output_tensor)
|
||||
|
||||
def get_cache_info(self) -> str:
|
||||
return str(self.pin_memory_cache)
|
||||
|
||||
|
||||
class DataloaderForVideo(DataLoader):
|
||||
def _get_iterator(self) -> "_BaseDataLoaderIter":
|
||||
if self.num_workers == 0:
|
||||
return _SingleProcessDataLoaderIter(self)
|
||||
else:
|
||||
self.check_worker_number_rationality()
|
||||
return _MultiProcessingDataLoaderIterForVideo(self)
|
||||
|
||||
|
||||
# Deterministic dataloader
|
||||
def get_seed_worker(seed):
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
if seed is not None:
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return seed_worker
|
||||
|
||||
|
||||
def prepare_dataloader(
|
||||
dataset,
|
||||
batch_size=None,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group: ProcessGroup | None = None,
|
||||
bucket_config=None,
|
||||
num_bucket_build_workers=1,
|
||||
prefetch_factor=None,
|
||||
cache_pin_memory=False,
|
||||
num_groups=1,
|
||||
**kwargs,
|
||||
):
|
||||
_kwargs = kwargs.copy()
|
||||
if isinstance(dataset, VideoTextDataset):
|
||||
batch_sampler = VariableVideoBatchSampler(
|
||||
dataset,
|
||||
bucket_config,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
verbose=True,
|
||||
num_bucket_build_workers=num_bucket_build_workers,
|
||||
num_groups=num_groups,
|
||||
)
|
||||
dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
|
||||
return (
|
||||
dl_cls(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
worker_init_fn=get_seed_worker(seed),
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_default,
|
||||
prefetch_factor=prefetch_factor,
|
||||
**_kwargs,
|
||||
),
|
||||
batch_sampler,
|
||||
)
|
||||
elif isinstance(dataset, TextDataset):
|
||||
if process_group is None:
|
||||
return (
|
||||
DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
worker_init_fn=get_seed_worker(seed),
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
**_kwargs,
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
sampler = DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
return (
|
||||
DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
worker_init_fn=get_seed_worker(seed),
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_default,
|
||||
prefetch_factor=prefetch_factor,
|
||||
**_kwargs,
|
||||
),
|
||||
sampler,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
||||
|
||||
|
||||
def collate_fn_default(batch):
|
||||
# filter out None
|
||||
batch = [x for x in batch if x is not None]
|
||||
assert len(batch) > 0, "batch is empty"
|
||||
|
||||
# HACK: for loading text features
|
||||
use_mask = False
|
||||
if "mask" in batch[0] and isinstance(batch[0]["mask"], int):
|
||||
masks = [x.pop("mask") for x in batch]
|
||||
|
||||
texts = [x.pop("text") for x in batch]
|
||||
texts = torch.cat(texts, dim=1)
|
||||
use_mask = True
|
||||
|
||||
ret = torch.utils.data.default_collate(batch)
|
||||
|
||||
if use_mask:
|
||||
ret["mask"] = masks
|
||||
ret["text"] = texts
|
||||
return ret
|
||||
|
||||
|
||||
def collate_fn_batch(batch):
|
||||
"""
|
||||
Used only with BatchDistributedSampler
|
||||
"""
|
||||
# filter out None
|
||||
batch = [x for x in batch if x is not None]
|
||||
|
||||
res = torch.utils.data.default_collate(batch)
|
||||
|
||||
# squeeze the first dimension, which is due to torch.stack() in default_collate()
|
||||
if isinstance(res, collections.abc.Mapping):
|
||||
for k, v in res.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
res[k] = v.squeeze(0)
|
||||
elif isinstance(res, collections.abc.Sequence):
|
||||
res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
|
||||
elif isinstance(res, torch.Tensor):
|
||||
res = res.squeeze(0)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NUM_GPU: number of GPUs for training
|
||||
# TIME_PER_STEP: time per step in seconds
|
||||
|
||||
# Example usage:
|
||||
# torchrun --nproc_per_node 1 -m opensora.datasets.dataloader configs/diffusion/train/video_cond.py
|
||||
cfg = parse_configs()
|
||||
setup_device()
|
||||
logger = create_logger()
|
||||
|
||||
# == build dataset ==
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
|
||||
# == build dataloader ==
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.get("batch_size", None),
|
||||
num_workers=cfg.get("num_workers", 4),
|
||||
seed=cfg.get("seed", 1024),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
prefetch_factor=cfg.get("prefetch_factor", None),
|
||||
)
|
||||
dataloader, sampler = prepare_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
||||
**dataloader_args,
|
||||
)
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
num_machines = int(os.environ.get("NUM_MACHINES", 28))
|
||||
num_gpu = num_machines * 8
|
||||
logger.info("Number of GPUs: %d", num_gpu)
|
||||
logger.info("Number of steps per epoch: %d", num_steps_per_epoch // num_gpu)
|
||||
time_per_step = int(os.environ.get("TIME_PER_STEP", 20))
|
||||
time_training = num_steps_per_epoch // num_gpu * time_per_step
|
||||
logger.info("Time per step: %s", format_duration(time_per_step))
|
||||
logger.info("Time for training: %s", format_duration(time_training))
|
||||
|
|
@ -0,0 +1,315 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from PIL import ImageFile
|
||||
from torchvision.datasets.folder import pil_loader
|
||||
|
||||
from opensora.registry import DATASETS
|
||||
|
||||
from .read_video import read_video
|
||||
from .utils import get_transforms_image, get_transforms_video, is_img, map_target_fps, read_file, temporal_random_crop
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
VALID_KEYS = ("neg", "path")
|
||||
K = 10000
|
||||
|
||||
|
||||
class Iloc:
|
||||
def __init__(self, data, sharded_folder, sharded_folders, rows_per_shard):
|
||||
self.data = data
|
||||
self.sharded_folder = sharded_folder
|
||||
self.sharded_folders = sharded_folders
|
||||
self.rows_per_shard = rows_per_shard
|
||||
|
||||
def __getitem__(self, index):
|
||||
return Item(
|
||||
index,
|
||||
self.data,
|
||||
self.sharded_folder,
|
||||
self.sharded_folders,
|
||||
self.rows_per_shard,
|
||||
)
|
||||
|
||||
|
||||
class Item:
|
||||
def __init__(self, index, data, sharded_folder, sharded_folders, rows_per_shard):
|
||||
self.index = index
|
||||
self.data = data
|
||||
self.sharded_folder = sharded_folder
|
||||
self.sharded_folders = sharded_folders
|
||||
self.rows_per_shard = rows_per_shard
|
||||
|
||||
def __getitem__(self, key):
|
||||
index = self.index
|
||||
if key in self.data.columns:
|
||||
return self.data[key].iloc[index]
|
||||
else:
|
||||
shard_idx = index // self.rows_per_shard
|
||||
idx = index % self.rows_per_shard
|
||||
shard_parquet = os.path.join(self.sharded_folder, self.sharded_folders[shard_idx])
|
||||
try:
|
||||
text_parquet = pd.read_parquet(shard_parquet, engine="fastparquet")
|
||||
path = text_parquet["path"].iloc[idx]
|
||||
assert path == self.data["path"].iloc[index]
|
||||
except Exception as e:
|
||||
print(f"Error reading {shard_parquet}: {e}")
|
||||
raise
|
||||
return text_parquet[key].iloc[idx]
|
||||
|
||||
def to_dict(self):
|
||||
index = self.index
|
||||
ret = {}
|
||||
ret.update(self.data.iloc[index].to_dict())
|
||||
shard_idx = index // self.rows_per_shard
|
||||
idx = index % self.rows_per_shard
|
||||
shard_parquet = os.path.join(self.sharded_folder, self.sharded_folders[shard_idx])
|
||||
try:
|
||||
text_parquet = pd.read_parquet(shard_parquet, engine="fastparquet")
|
||||
path = text_parquet["path"].iloc[idx]
|
||||
assert path == self.data["path"].iloc[index]
|
||||
ret.update(text_parquet.iloc[idx].to_dict())
|
||||
except Exception as e:
|
||||
print(f"Error reading {shard_parquet}: {e}")
|
||||
ret.update({"text": ""})
|
||||
return ret
|
||||
|
||||
|
||||
class EfficientParquet:
|
||||
def __init__(self, df, sharded_folder):
|
||||
self.data = df
|
||||
self.total_rows = len(df)
|
||||
self.rows_per_shard = (self.total_rows + K - 1) // K
|
||||
self.sharded_folder = sharded_folder
|
||||
assert os.path.exists(sharded_folder), f"Sharded folder {sharded_folder} does not exist."
|
||||
self.sharded_folders = os.listdir(sharded_folder)
|
||||
self.sharded_folders = sorted(self.sharded_folders)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_rows
|
||||
|
||||
@property
|
||||
def iloc(self):
|
||||
return Iloc(self.data, self.sharded_folder, self.sharded_folders, self.rows_per_shard)
|
||||
|
||||
|
||||
@DATASETS.register_module("text")
|
||||
class TextDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset for text data
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str = None,
|
||||
tokenize_fn: callable = None,
|
||||
fps_max: int = 16,
|
||||
vmaf: bool = False,
|
||||
memory_efficient: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.data_path = data_path
|
||||
self.data = read_file(data_path, memory_efficient=memory_efficient)
|
||||
self.memory_efficient = memory_efficient
|
||||
self.tokenize_fn = tokenize_fn
|
||||
self.vmaf = vmaf
|
||||
|
||||
if fps_max is not None:
|
||||
self.fps_max = fps_max
|
||||
else:
|
||||
self.fps_max = 999999999
|
||||
|
||||
def to_efficient(self):
|
||||
if self.memory_efficient:
|
||||
addition_data_path = self.data_path.split(".")[0]
|
||||
self._data = self.data
|
||||
self.data = EfficientParquet(self._data, addition_data_path)
|
||||
|
||||
def getitem(self, index: int) -> dict:
|
||||
ret = dict()
|
||||
sample = self.data.iloc[index].to_dict()
|
||||
sample_fps = sample.get("fps", np.nan)
|
||||
new_fps, sampling_interval = map_target_fps(sample_fps, self.fps_max)
|
||||
ret.update({"sampling_interval": sampling_interval})
|
||||
|
||||
if "text" in sample:
|
||||
ret["text"] = sample.pop("text")
|
||||
postfixs = []
|
||||
if new_fps != 0 and self.fps_max < 999:
|
||||
postfixs.append(f"{new_fps} FPS")
|
||||
if self.vmaf and "score_vmafmotion" in sample and not np.isnan(sample["score_vmafmotion"]):
|
||||
postfixs.append(f"{int(sample['score_vmafmotion'] + 0.5)} motion score")
|
||||
postfix = " " + ", ".join(postfixs) + "." if postfixs else ""
|
||||
ret["text"] = ret["text"] + postfix
|
||||
if self.tokenize_fn is not None:
|
||||
ret.update({k: v.squeeze(0) for k, v in self.tokenize_fn(ret["text"]).items()})
|
||||
|
||||
if "ref" in sample: # i2v & v2v reference
|
||||
ret["ref"] = sample.pop("ref")
|
||||
|
||||
# name of the generated sample
|
||||
if "name" in sample: # sample name (`dataset_idx`)
|
||||
ret["name"] = sample.pop("name")
|
||||
else:
|
||||
ret["index"] = index # use index for name
|
||||
valid_sample = {k: v for k, v in sample.items() if k in VALID_KEYS}
|
||||
ret.update(valid_sample)
|
||||
return ret
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.getitem(index)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
@DATASETS.register_module("video_text")
|
||||
class VideoTextDataset(TextDataset):
|
||||
def __init__(
|
||||
self,
|
||||
transform_name: str = None,
|
||||
bucket_class: str = "Bucket",
|
||||
rand_sample_interval: int = None, # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.transform_name = transform_name
|
||||
self.bucket_class = bucket_class
|
||||
self.rand_sample_interval = rand_sample_interval
|
||||
|
||||
def get_image(self, index: int, height: int, width: int) -> dict:
|
||||
sample = self.data.iloc[index]
|
||||
path = sample["path"]
|
||||
# loading
|
||||
image = pil_loader(path)
|
||||
|
||||
# transform
|
||||
transform = get_transforms_image(self.transform_name, (height, width))
|
||||
image = transform(image)
|
||||
|
||||
# CHW -> CTHW
|
||||
video = image.unsqueeze(1)
|
||||
|
||||
return {"video": video}
|
||||
|
||||
def get_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
|
||||
sample = self.data.iloc[index]
|
||||
path = sample["path"]
|
||||
|
||||
# loading
|
||||
vframes, vinfo = read_video(path, backend="av")
|
||||
|
||||
if self.rand_sample_interval is not None:
|
||||
# randomly sample from 1 - self.rand_sample_interval
|
||||
video_allowed_max = min(len(vframes) // num_frames, self.rand_sample_interval)
|
||||
sampling_interval = random.randint(1, video_allowed_max)
|
||||
|
||||
# Sampling video frames
|
||||
video = temporal_random_crop(vframes, num_frames, sampling_interval)
|
||||
|
||||
video = video.clone()
|
||||
del vframes
|
||||
|
||||
# transform
|
||||
transform = get_transforms_video(self.transform_name, (height, width))
|
||||
video = transform(video) # T C H W
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
|
||||
ret = {"video": video}
|
||||
|
||||
return ret
|
||||
|
||||
def get_image_or_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
|
||||
sample = self.data.iloc[index]
|
||||
path = sample["path"]
|
||||
|
||||
if is_img(path):
|
||||
return self.get_image(index, height, width)
|
||||
return self.get_video(index, num_frames, height, width, sampling_interval)
|
||||
|
||||
def getitem(self, index: str) -> dict:
|
||||
# a hack to pass in the (time, height, width) info from sampler
|
||||
index, num_frames, height, width = [int(val) for val in index.split("-")]
|
||||
ret = dict()
|
||||
ret.update(super().getitem(index))
|
||||
try:
|
||||
ret.update(self.get_image_or_video(index, num_frames, height, width, ret["sampling_interval"]))
|
||||
except Exception as e:
|
||||
path = self.data.iloc[index]["path"]
|
||||
print(f"video {path}: {e}")
|
||||
return None
|
||||
return ret
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.getitem(index)
|
||||
|
||||
|
||||
@DATASETS.register_module("cached_video_text")
|
||||
class CachedVideoTextDataset(VideoTextDataset):
|
||||
def __init__(
|
||||
self,
|
||||
transform_name: str = None,
|
||||
bucket_class: str = "Bucket",
|
||||
rand_sample_interval: int = None, # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
|
||||
cached_video: bool = False,
|
||||
cached_text: bool = False,
|
||||
return_latents_path: bool = False,
|
||||
load_original_video: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.transform_name = transform_name
|
||||
self.bucket_class = bucket_class
|
||||
self.rand_sample_interval = rand_sample_interval
|
||||
self.cached_video = cached_video
|
||||
self.cached_text = cached_text
|
||||
self.return_latents_path = return_latents_path
|
||||
self.load_original_video = load_original_video
|
||||
|
||||
def get_latents(self, path):
|
||||
try:
|
||||
latents = torch.load(path, map_location=torch.device("cpu"))
|
||||
except Exception as e:
|
||||
print(f"Error loading latents from {path}: {e}")
|
||||
return torch.zeros_like(torch.randn(1, 1, 1, 1))
|
||||
return latents
|
||||
|
||||
def get_conditioning_latents(self, index: int) -> dict:
|
||||
sample = self.data.iloc[index]
|
||||
latents_path = sample["latents_path"]
|
||||
text_t5_path = sample["text_t5_path"]
|
||||
text_clip_path = sample["text_clip_path"]
|
||||
res = dict()
|
||||
if self.cached_video:
|
||||
latents = self.get_latents(latents_path)
|
||||
res["video_latents"] = latents
|
||||
if self.cached_text:
|
||||
text_t5 = self.get_latents(text_t5_path)
|
||||
text_clip = self.get_latents(text_clip_path)
|
||||
res["text_t5"] = text_t5
|
||||
res["text_clip"] = text_clip
|
||||
if self.return_latents_path:
|
||||
res["latents_path"] = latents_path
|
||||
res["text_t5_path"] = text_t5_path
|
||||
res["text_clip_path"] = text_clip_path
|
||||
return res
|
||||
|
||||
def getitem(self, index: str) -> dict:
|
||||
# a hack to pass in the (time, height, width) info from sampler
|
||||
real_index, num_frames, height, width = [int(val) for val in index.split("-")]
|
||||
ret = dict()
|
||||
if self.load_original_video:
|
||||
ret.update(super().getitem(index))
|
||||
try:
|
||||
ret.update(self.get_conditioning_latents(real_index))
|
||||
except Exception as e:
|
||||
path = self.data.iloc[real_index]["path"]
|
||||
print(f"video {path}: {e}")
|
||||
return None
|
||||
return ret
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.getitem(index)
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
import multiprocessing
|
||||
from itertools import count
|
||||
from multiprocessing.managers import SyncManager
|
||||
from typing import Any, Callable, Dict, Tuple, Type, cast
|
||||
|
||||
import dill
|
||||
import pandarallel
|
||||
import pandas as pd
|
||||
from pandarallel.data_types import DataType
|
||||
from pandarallel.progress_bars import ProgressBarsType, get_progress_bars, progress_wrapper
|
||||
from pandarallel.utils import WorkerStatus
|
||||
|
||||
CONTEXT = multiprocessing.get_context("fork")
|
||||
TMP = []
|
||||
|
||||
|
||||
class WrapWorkFunctionForPipe:
|
||||
def __init__(
|
||||
self,
|
||||
work_function: Callable[
|
||||
[
|
||||
Any,
|
||||
Callable,
|
||||
tuple,
|
||||
Dict[str, Any],
|
||||
Dict[str, Any],
|
||||
],
|
||||
Any,
|
||||
],
|
||||
) -> None:
|
||||
self.work_function = work_function
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
progress_bars_type: ProgressBarsType,
|
||||
worker_index: int,
|
||||
master_workers_queue: multiprocessing.Queue,
|
||||
dilled_user_defined_function: bytes,
|
||||
user_defined_function_args: tuple,
|
||||
user_defined_function_kwargs: Dict[str, Any],
|
||||
extra: Dict[str, Any],
|
||||
) -> Any:
|
||||
try:
|
||||
data = TMP[worker_index]
|
||||
data_size = len(data)
|
||||
user_defined_function: Callable = dill.loads(dilled_user_defined_function)
|
||||
|
||||
progress_wrapped_user_defined_function = progress_wrapper(
|
||||
user_defined_function, master_workers_queue, worker_index, data_size
|
||||
)
|
||||
|
||||
used_user_defined_function = (
|
||||
progress_wrapped_user_defined_function
|
||||
if progress_bars_type
|
||||
in (
|
||||
ProgressBarsType.InUserDefinedFunction,
|
||||
ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns,
|
||||
)
|
||||
else user_defined_function
|
||||
)
|
||||
|
||||
results = self.work_function(
|
||||
data,
|
||||
used_user_defined_function,
|
||||
user_defined_function_args,
|
||||
user_defined_function_kwargs,
|
||||
extra,
|
||||
)
|
||||
|
||||
master_workers_queue.put((worker_index, WorkerStatus.Success, None))
|
||||
|
||||
return results
|
||||
|
||||
except:
|
||||
master_workers_queue.put((worker_index, WorkerStatus.Error, None))
|
||||
raise
|
||||
|
||||
|
||||
def parallelize_with_pipe(
|
||||
nb_requested_workers: int,
|
||||
data_type: Type[DataType],
|
||||
progress_bars_type: ProgressBarsType,
|
||||
):
|
||||
def closure(
|
||||
data: Any,
|
||||
user_defined_function: Callable,
|
||||
*user_defined_function_args: tuple,
|
||||
**user_defined_function_kwargs: Dict[str, Any],
|
||||
):
|
||||
wrapped_work_function = WrapWorkFunctionForPipe(data_type.work)
|
||||
dilled_user_defined_function = dill.dumps(user_defined_function)
|
||||
manager: SyncManager = CONTEXT.Manager()
|
||||
master_workers_queue = manager.Queue()
|
||||
|
||||
chunks = list(
|
||||
data_type.get_chunks(
|
||||
nb_requested_workers,
|
||||
data,
|
||||
user_defined_function_kwargs=user_defined_function_kwargs,
|
||||
)
|
||||
)
|
||||
TMP.extend(chunks)
|
||||
|
||||
nb_workers = len(chunks)
|
||||
|
||||
multiplicator_factor = (
|
||||
len(cast(pd.DataFrame, data).columns)
|
||||
if progress_bars_type == ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns
|
||||
else 1
|
||||
)
|
||||
|
||||
progresses_length = [len(chunk_) * multiplicator_factor for chunk_ in chunks]
|
||||
|
||||
work_extra = data_type.get_work_extra(data)
|
||||
reduce_extra = data_type.get_reduce_extra(data, user_defined_function_kwargs)
|
||||
|
||||
show_progress_bars = progress_bars_type != ProgressBarsType.No
|
||||
|
||||
progress_bars = get_progress_bars(progresses_length, show_progress_bars)
|
||||
progresses = [0] * nb_workers
|
||||
workers_status = [WorkerStatus.Running] * nb_workers
|
||||
|
||||
work_args_list = [
|
||||
(
|
||||
progress_bars_type,
|
||||
worker_index,
|
||||
master_workers_queue,
|
||||
dilled_user_defined_function,
|
||||
user_defined_function_args,
|
||||
user_defined_function_kwargs,
|
||||
{
|
||||
**work_extra,
|
||||
**{
|
||||
"master_workers_queue": master_workers_queue,
|
||||
"show_progress_bars": show_progress_bars,
|
||||
"worker_index": worker_index,
|
||||
},
|
||||
},
|
||||
)
|
||||
for worker_index in range(nb_workers)
|
||||
]
|
||||
|
||||
pool = CONTEXT.Pool(nb_workers)
|
||||
results_promise = pool.starmap_async(wrapped_work_function, work_args_list)
|
||||
pool.close()
|
||||
|
||||
generation = count()
|
||||
|
||||
while any((worker_status == WorkerStatus.Running for worker_status in workers_status)):
|
||||
message: Tuple[int, WorkerStatus, Any] = master_workers_queue.get()
|
||||
worker_index, worker_status, payload = message
|
||||
workers_status[worker_index] = worker_status
|
||||
|
||||
if worker_status == WorkerStatus.Success:
|
||||
progresses[worker_index] = progresses_length[worker_index]
|
||||
progress_bars.update(progresses)
|
||||
elif worker_status == WorkerStatus.Running:
|
||||
progress = cast(int, payload)
|
||||
progresses[worker_index] = progress
|
||||
|
||||
if next(generation) % nb_workers == 0:
|
||||
progress_bars.update(progresses)
|
||||
elif worker_status == WorkerStatus.Error:
|
||||
progress_bars.set_error(worker_index)
|
||||
|
||||
results = results_promise.get()
|
||||
TMP.clear()
|
||||
|
||||
return data_type.reduce(results, reduce_extra)
|
||||
|
||||
return closure
|
||||
|
||||
|
||||
pandarallel.core.WrapWorkFunctionForPipe = WrapWorkFunctionForPipe
|
||||
pandarallel.core.parallelize_with_pipe = parallelize_with_pipe
|
||||
pandarallel = pandarallel.pandarallel
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PinMemoryCache:
|
||||
force_dtype: Optional[torch.dtype] = None
|
||||
min_cache_numel: int = 0
|
||||
pre_alloc_numels: List[int] = []
|
||||
|
||||
def __init__(self):
|
||||
self.cache: Dict[int, torch.Tensor] = {}
|
||||
self.output_to_cache: Dict[int, int] = {}
|
||||
self.cache_to_output: Dict[int, int] = {}
|
||||
self.lock = threading.Lock()
|
||||
self.total_cnt = 0
|
||||
self.hit_cnt = 0
|
||||
|
||||
if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None:
|
||||
for n in self.pre_alloc_numels:
|
||||
cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
|
||||
with self.lock:
|
||||
self.cache[id(cache_tensor)] = cache_tensor
|
||||
|
||||
def get(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Receive a cpu tensor and return the corresponding pinned tensor. Note that this only manage memory allocation, doesn't copy content.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be pinned.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The pinned tensor.
|
||||
"""
|
||||
self.total_cnt += 1
|
||||
with self.lock:
|
||||
# find free cache
|
||||
for cache_id, cache_tensor in self.cache.items():
|
||||
if cache_id not in self.cache_to_output and cache_tensor.numel() >= tensor.numel():
|
||||
target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
|
||||
out_id = id(target_cache_tensor)
|
||||
self.output_to_cache[out_id] = cache_id
|
||||
self.cache_to_output[cache_id] = out_id
|
||||
self.hit_cnt += 1
|
||||
return target_cache_tensor
|
||||
# no free cache, create a new one
|
||||
dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype
|
||||
cache_numel = max(tensor.numel(), self.min_cache_numel)
|
||||
cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True)
|
||||
target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
|
||||
out_id = id(target_cache_tensor)
|
||||
with self.lock:
|
||||
self.cache[id(cache_tensor)] = cache_tensor
|
||||
self.output_to_cache[out_id] = id(cache_tensor)
|
||||
self.cache_to_output[id(cache_tensor)] = out_id
|
||||
return target_cache_tensor
|
||||
|
||||
def remove(self, output_tensor: torch.Tensor) -> None:
|
||||
"""Release corresponding cache tensor.
|
||||
|
||||
Args:
|
||||
output_tensor (torch.Tensor): The tensor to be released.
|
||||
"""
|
||||
out_id = id(output_tensor)
|
||||
with self.lock:
|
||||
if out_id not in self.output_to_cache:
|
||||
raise ValueError("Tensor not found in cache.")
|
||||
cache_id = self.output_to_cache.pop(out_id)
|
||||
del self.cache_to_output[cache_id]
|
||||
|
||||
def __str__(self):
|
||||
with self.lock:
|
||||
num_cached = len(self.cache)
|
||||
num_used = len(self.output_to_cache)
|
||||
total_cache_size = sum([v.numel() * v.element_size() for v in self.cache.values()])
|
||||
return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={self.hit_cnt / self.total_cnt:.2f})"
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
import gc
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import get_video_backend
|
||||
from torchvision.io.video import _check_av_available
|
||||
|
||||
MAX_NUM_FRAMES = 2500
|
||||
|
||||
|
||||
def read_video_av(
|
||||
filename: str,
|
||||
start_pts: float | Fraction = 0,
|
||||
end_pts: float | Fraction | None = None,
|
||||
pts_unit: str = "pts",
|
||||
output_format: str = "THWC",
|
||||
) -> tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
"""
|
||||
Reads a video from a file, returning both the video frames and the audio frames
|
||||
|
||||
This method is modified from torchvision.io.video.read_video, with the following changes:
|
||||
|
||||
1. will not extract audio frames and return empty for aframes
|
||||
2. remove checks and only support pyav
|
||||
3. add container.close() and gc.collect() to avoid thread leakage
|
||||
4. try our best to avoid memory leak
|
||||
|
||||
Args:
|
||||
filename (str): path to the video file
|
||||
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||
The start presentation time of the video
|
||||
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||
The end presentation time
|
||||
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
|
||||
either 'pts' or 'sec'. Defaults to 'pts'.
|
||||
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
|
||||
|
||||
Returns:
|
||||
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
|
||||
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
|
||||
info (dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
|
||||
"""
|
||||
# format
|
||||
output_format = output_format.upper()
|
||||
if output_format not in ("THWC", "TCHW"):
|
||||
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
|
||||
# file existence
|
||||
if not os.path.exists(filename):
|
||||
raise RuntimeError(f"File not found: {filename}")
|
||||
# backend check
|
||||
assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
|
||||
_check_av_available()
|
||||
# end_pts check
|
||||
if end_pts is None:
|
||||
end_pts = float("inf")
|
||||
if end_pts < start_pts:
|
||||
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
|
||||
|
||||
# == get video info ==
|
||||
info = {}
|
||||
# TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU)
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
# fps
|
||||
video_fps = container.streams.video[0].average_rate
|
||||
# guard against potentially corrupted files
|
||||
if video_fps is not None:
|
||||
info["video_fps"] = float(video_fps)
|
||||
iter_video = container.decode(**{"video": 0})
|
||||
frame = next(iter_video).to_rgb().to_ndarray()
|
||||
height, width = frame.shape[:2]
|
||||
total_frames = container.streams.video[0].frames
|
||||
if total_frames == 0:
|
||||
total_frames = MAX_NUM_FRAMES
|
||||
warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
|
||||
container.close()
|
||||
del container
|
||||
|
||||
# HACK: must create before iterating stream
|
||||
# use np.zeros will not actually allocate memory
|
||||
# use np.ones will lead to a little memory leak
|
||||
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
|
||||
|
||||
# == read ==
|
||||
try:
|
||||
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
assert container.streams.video is not None
|
||||
video_frames = _read_from_stream(
|
||||
video_frames,
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.video[0],
|
||||
{"video": 0},
|
||||
filename=filename,
|
||||
)
|
||||
except av.AVError as e:
|
||||
print(f"[Warning] Error while reading video {filename}: {e}")
|
||||
|
||||
vframes = torch.from_numpy(video_frames).clone()
|
||||
del video_frames
|
||||
if output_format == "TCHW":
|
||||
# [T,H,W,C] --> [T,C,H,W]
|
||||
vframes = vframes.permute(0, 3, 1, 2)
|
||||
|
||||
aframes = torch.empty((1, 0), dtype=torch.float32)
|
||||
return vframes, aframes, info
|
||||
|
||||
|
||||
def _read_from_stream(
|
||||
video_frames,
|
||||
container: "av.container.Container",
|
||||
start_offset: float,
|
||||
end_offset: float,
|
||||
pts_unit: str,
|
||||
stream: "av.stream.Stream",
|
||||
stream_name: dict[str, int | tuple[int, ...] | list[int] | None],
|
||||
filename: str | None = None,
|
||||
) -> list["av.frame.Frame"]:
|
||||
if pts_unit == "sec":
|
||||
# TODO: we should change all of this from ground up to simply take
|
||||
# sec and convert to MS in C++
|
||||
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
|
||||
if end_offset != float("inf"):
|
||||
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
|
||||
else:
|
||||
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
|
||||
|
||||
should_buffer = True
|
||||
max_buffer_size = 5
|
||||
if stream.type == "video":
|
||||
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
|
||||
# so need to buffer some extra frames to sort everything
|
||||
# properly
|
||||
extradata = stream.codec_context.extradata
|
||||
# overly complicated way of finding if `divx_packed` is set, following
|
||||
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
|
||||
if extradata and b"DivX" in extradata:
|
||||
# can't use regex directly because of some weird characters sometimes...
|
||||
pos = extradata.find(b"DivX")
|
||||
d = extradata[pos:]
|
||||
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
|
||||
if o is None:
|
||||
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
|
||||
if o is not None:
|
||||
should_buffer = o.group(3) == b"p"
|
||||
seek_offset = start_offset
|
||||
# some files don't seek to the right location, so better be safe here
|
||||
seek_offset = max(seek_offset - 1, 0)
|
||||
if should_buffer:
|
||||
# FIXME this is kind of a hack, but we will jump to the previous keyframe
|
||||
# so this will be safe
|
||||
seek_offset = max(seek_offset - max_buffer_size, 0)
|
||||
try:
|
||||
# TODO check if stream needs to always be the video stream here or not
|
||||
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
|
||||
except av.AVError as e:
|
||||
print(f"[Warning] Error while seeking video {filename}: {e}")
|
||||
return []
|
||||
|
||||
# == main ==
|
||||
buffer_count = 0
|
||||
frames_pts = []
|
||||
cnt = 0
|
||||
try:
|
||||
for _idx, frame in enumerate(container.decode(**stream_name)):
|
||||
frames_pts.append(frame.pts)
|
||||
video_frames[cnt] = frame.to_rgb().to_ndarray()
|
||||
cnt += 1
|
||||
if cnt >= len(video_frames):
|
||||
break
|
||||
if frame.pts >= end_offset:
|
||||
if should_buffer and buffer_count < max_buffer_size:
|
||||
buffer_count += 1
|
||||
continue
|
||||
break
|
||||
except av.AVError as e:
|
||||
print(f"[Warning] Error while reading video {filename}: {e}")
|
||||
|
||||
# garbage collection for thread leakage
|
||||
container.close()
|
||||
del container
|
||||
# NOTE: manually garbage collect to close pyav threads
|
||||
gc.collect()
|
||||
|
||||
# ensure that the results are sorted wrt the pts
|
||||
# NOTE: here we assert frames_pts is sorted
|
||||
start_ptr = 0
|
||||
end_ptr = cnt
|
||||
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
|
||||
start_ptr += 1
|
||||
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
|
||||
end_ptr -= 1
|
||||
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
|
||||
# if there is no frame that exactly matches the pts of start_offset
|
||||
# add the last frame smaller than start_offset, to guarantee that
|
||||
# we will have all the necessary data. This is most useful for audio
|
||||
if start_ptr > 0:
|
||||
start_ptr -= 1
|
||||
result = video_frames[start_ptr:end_ptr].copy()
|
||||
return result
|
||||
|
||||
|
||||
def read_video_cv2(video_path):
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
if not cap.isOpened():
|
||||
# print("Error: Unable to open video")
|
||||
raise ValueError
|
||||
else:
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
vinfo = {
|
||||
"video_fps": fps,
|
||||
}
|
||||
|
||||
frames = []
|
||||
while True:
|
||||
# Read a frame from the video
|
||||
ret, frame = cap.read()
|
||||
|
||||
# If frame is not read correctly, break the loop
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frames.append(frame[:, :, ::-1]) # BGR to RGB
|
||||
|
||||
# Exit if 'q' is pressed
|
||||
if cv2.waitKey(25) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
# Release the video capture object and close all windows
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
frames = np.stack(frames)
|
||||
frames = torch.from_numpy(frames) # [T, H, W, C=3]
|
||||
frames = frames.permute(0, 3, 1, 2)
|
||||
return frames, vinfo
|
||||
|
||||
|
||||
def read_video(video_path, backend="av"):
|
||||
if backend == "cv2":
|
||||
vframes, vinfo = read_video_cv2(video_path)
|
||||
elif backend == "av":
|
||||
vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return vframes, vinfo
|
||||
|
|
@ -0,0 +1,393 @@
|
|||
from collections import OrderedDict, defaultdict
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset, DistributedSampler
|
||||
|
||||
from opensora.utils.logger import log_message
|
||||
from opensora.utils.misc import format_numel_str
|
||||
|
||||
from .aspect import get_num_pexels_from_name
|
||||
from .bucket import Bucket
|
||||
from .datasets import VideoTextDataset
|
||||
from .parallel import pandarallel
|
||||
from .utils import sync_object_across_devices
|
||||
|
||||
|
||||
# use pandarallel to accelerate bucket processing
|
||||
# NOTE: pandarallel should only access local variables
|
||||
def apply(data, method=None, seed=None, num_bucket=None, fps_max=16):
|
||||
return method(
|
||||
data["num_frames"],
|
||||
data["height"],
|
||||
data["width"],
|
||||
data["fps"],
|
||||
data["path"],
|
||||
seed + data["id"] * num_bucket,
|
||||
fps_max,
|
||||
)
|
||||
|
||||
|
||||
class StatefulDistributedSampler(DistributedSampler):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
num_replicas: int | None = None,
|
||||
rank: int | None = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
) -> None:
|
||||
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
||||
self.start_index: int = 0
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
iterator = super().__iter__()
|
||||
indices = list(iterator)
|
||||
indices = indices[self.start_index :]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples - self.start_index
|
||||
|
||||
def reset(self) -> None:
|
||||
self.start_index = 0
|
||||
|
||||
def state_dict(self, step) -> dict:
|
||||
return {"start_index": step}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
|
||||
class VariableVideoBatchSampler(DistributedSampler):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: VideoTextDataset,
|
||||
bucket_config: dict,
|
||||
num_replicas: int | None = None,
|
||||
rank: int | None = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
verbose: bool = False,
|
||||
num_bucket_build_workers: int = 1,
|
||||
num_groups: int = 1,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
|
||||
)
|
||||
self.dataset = dataset
|
||||
assert dataset.bucket_class == "Bucket", "Only support Bucket class for now"
|
||||
self.bucket = Bucket(bucket_config)
|
||||
self.verbose = verbose
|
||||
self.last_micro_batch_access_index = 0
|
||||
self.num_bucket_build_workers = num_bucket_build_workers
|
||||
self._cached_bucket_sample_dict = None
|
||||
self._cached_num_total_batch = None
|
||||
self.num_groups = num_groups
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
pandarallel.initialize(
|
||||
nb_workers=self.num_bucket_build_workers,
|
||||
progress_bar=False,
|
||||
verbose=0,
|
||||
use_memory_fs=False,
|
||||
)
|
||||
|
||||
def __iter__(self) -> Iterator[list[int]]:
|
||||
bucket_sample_dict, _ = self.group_by_bucket()
|
||||
self.clear_cache()
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
bucket_micro_batch_count = OrderedDict()
|
||||
bucket_last_consumed = OrderedDict()
|
||||
|
||||
# process the samples
|
||||
for bucket_id, data_list in bucket_sample_dict.items():
|
||||
# handle droplast
|
||||
bs_per_gpu = self.bucket.get_batch_size(bucket_id)
|
||||
remainder = len(data_list) % bs_per_gpu
|
||||
|
||||
if remainder > 0:
|
||||
if not self.drop_last:
|
||||
# if there is remainder, we pad to make it divisible
|
||||
data_list += data_list[: bs_per_gpu - remainder]
|
||||
else:
|
||||
# we just drop the remainder to make it divisible
|
||||
data_list = data_list[:-remainder]
|
||||
bucket_sample_dict[bucket_id] = data_list
|
||||
|
||||
# handle shuffle
|
||||
if self.shuffle:
|
||||
data_indices = torch.randperm(len(data_list), generator=g).tolist()
|
||||
data_list = [data_list[i] for i in data_indices]
|
||||
bucket_sample_dict[bucket_id] = data_list
|
||||
|
||||
# compute how many micro-batches each bucket has
|
||||
num_micro_batches = len(data_list) // bs_per_gpu
|
||||
bucket_micro_batch_count[bucket_id] = num_micro_batches
|
||||
|
||||
# compute the bucket access order
|
||||
# each bucket may have more than one batch of data
|
||||
# thus bucket_id may appear more than 1 time
|
||||
bucket_id_access_order = []
|
||||
for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
|
||||
bucket_id_access_order.extend([bucket_id] * num_micro_batch)
|
||||
|
||||
# randomize the access order
|
||||
if self.shuffle:
|
||||
bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
|
||||
bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices]
|
||||
|
||||
# make the number of bucket accesses divisible by dp size
|
||||
remainder = len(bucket_id_access_order) % self.num_replicas
|
||||
if remainder > 0:
|
||||
if self.drop_last:
|
||||
bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
|
||||
else:
|
||||
bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]
|
||||
|
||||
# prepare each batch from its bucket
|
||||
# according to the predefined bucket access order
|
||||
num_iters = len(bucket_id_access_order) // self.num_replicas
|
||||
start_iter_idx = self.last_micro_batch_access_index // self.num_replicas
|
||||
|
||||
# re-compute the micro-batch consumption
|
||||
# this is useful when resuming from a state dict with a different number of GPUs
|
||||
self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
|
||||
for i in range(self.last_micro_batch_access_index):
|
||||
bucket_id = bucket_id_access_order[i]
|
||||
bucket_bs = self.bucket.get_batch_size(bucket_id)
|
||||
if bucket_id in bucket_last_consumed:
|
||||
bucket_last_consumed[bucket_id] += bucket_bs
|
||||
else:
|
||||
bucket_last_consumed[bucket_id] = bucket_bs
|
||||
|
||||
for i in range(start_iter_idx, num_iters):
|
||||
bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
|
||||
self.last_micro_batch_access_index += self.num_replicas
|
||||
|
||||
# compute the data samples consumed by each access
|
||||
bucket_access_boundaries = []
|
||||
for bucket_id in bucket_access_list:
|
||||
bucket_bs = self.bucket.get_batch_size(bucket_id)
|
||||
last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
|
||||
bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs])
|
||||
|
||||
# update consumption
|
||||
if bucket_id in bucket_last_consumed:
|
||||
bucket_last_consumed[bucket_id] += bucket_bs
|
||||
else:
|
||||
bucket_last_consumed[bucket_id] = bucket_bs
|
||||
|
||||
# compute the range of data accessed by each GPU
|
||||
bucket_id = bucket_access_list[self.rank]
|
||||
boundary = bucket_access_boundaries[self.rank]
|
||||
cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]]
|
||||
|
||||
# encode t, h, w into the sample index
|
||||
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
|
||||
cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
|
||||
yield cur_micro_batch
|
||||
|
||||
self.reset()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.get_num_batch() // self.num_groups
|
||||
|
||||
def get_num_batch(self) -> int:
|
||||
_, num_total_batch = self.group_by_bucket()
|
||||
return num_total_batch
|
||||
|
||||
def clear_cache(self):
|
||||
self._cached_bucket_sample_dict = None
|
||||
self._cached_num_total_batch = 0
|
||||
|
||||
def group_by_bucket(self) -> dict:
|
||||
"""
|
||||
Group the dataset samples into buckets.
|
||||
This method will set `self._cached_bucket_sample_dict` to the bucket sample dict.
|
||||
|
||||
Returns:
|
||||
dict: a dictionary with bucket id as key and a list of sample indices as value
|
||||
"""
|
||||
if self._cached_bucket_sample_dict is not None:
|
||||
return self._cached_bucket_sample_dict, self._cached_num_total_batch
|
||||
|
||||
# use pandarallel to accelerate bucket processing
|
||||
log_message("Building buckets using %d workers...", self.num_bucket_build_workers)
|
||||
bucket_ids = None
|
||||
if dist.get_rank() == 0:
|
||||
data = self.dataset.data.copy(deep=True)
|
||||
data["id"] = data.index
|
||||
bucket_ids = data.parallel_apply(
|
||||
apply,
|
||||
axis=1,
|
||||
method=self.bucket.get_bucket_id,
|
||||
seed=self.seed + self.epoch,
|
||||
num_bucket=self.bucket.num_bucket,
|
||||
fps_max=self.dataset.fps_max,
|
||||
)
|
||||
dist.barrier()
|
||||
bucket_ids = sync_object_across_devices(bucket_ids)
|
||||
dist.barrier()
|
||||
|
||||
# group by bucket
|
||||
# each data sample is put into a bucket with a similar image/video size
|
||||
bucket_sample_dict = defaultdict(list)
|
||||
bucket_ids_np = np.array(bucket_ids)
|
||||
valid_indices = np.where(bucket_ids_np != None)[0]
|
||||
for i in valid_indices:
|
||||
bucket_sample_dict[bucket_ids_np[i]].append(i)
|
||||
|
||||
# cache the bucket sample dict
|
||||
self._cached_bucket_sample_dict = bucket_sample_dict
|
||||
|
||||
# num total batch
|
||||
num_total_batch = self.print_bucket_info(bucket_sample_dict)
|
||||
self._cached_num_total_batch = num_total_batch
|
||||
|
||||
return bucket_sample_dict, num_total_batch
|
||||
|
||||
def print_bucket_info(self, bucket_sample_dict: dict) -> int:
|
||||
# collect statistics
|
||||
num_total_samples = num_total_batch = 0
|
||||
num_total_img_samples = num_total_vid_samples = 0
|
||||
num_total_img_batch = num_total_vid_batch = 0
|
||||
num_total_vid_batch_256 = num_total_vid_batch_768 = 0
|
||||
num_aspect_dict = defaultdict(lambda: [0, 0])
|
||||
num_hwt_dict = defaultdict(lambda: [0, 0])
|
||||
for k, v in bucket_sample_dict.items():
|
||||
size = len(v)
|
||||
num_batch = size // self.bucket.get_batch_size(k[:-1])
|
||||
|
||||
num_total_samples += size
|
||||
num_total_batch += num_batch
|
||||
|
||||
if k[1] == 1:
|
||||
num_total_img_samples += size
|
||||
num_total_img_batch += num_batch
|
||||
else:
|
||||
if k[0] == "256px":
|
||||
num_total_vid_batch_256 += num_batch
|
||||
elif k[0] == "768px":
|
||||
num_total_vid_batch_768 += num_batch
|
||||
num_total_vid_samples += size
|
||||
num_total_vid_batch += num_batch
|
||||
|
||||
num_aspect_dict[k[-1]][0] += size
|
||||
num_aspect_dict[k[-1]][1] += num_batch
|
||||
num_hwt_dict[k[:-1]][0] += size
|
||||
num_hwt_dict[k[:-1]][1] += num_batch
|
||||
|
||||
# sort
|
||||
num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0]))
|
||||
num_hwt_dict = dict(
|
||||
sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True)
|
||||
)
|
||||
num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1}
|
||||
num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1}
|
||||
|
||||
# log
|
||||
if dist.get_rank() == 0 and self.verbose:
|
||||
log_message("Bucket Info:")
|
||||
log_message("Bucket [#sample, #batch] by aspect ratio:")
|
||||
for k, v in num_aspect_dict.items():
|
||||
log_message("(%s): #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
|
||||
log_message("===== Image Info =====")
|
||||
log_message("Image Bucket by HxWxT:")
|
||||
for k, v in num_hwt_img_dict.items():
|
||||
log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
|
||||
log_message("--------------------------------")
|
||||
log_message(
|
||||
"#image sample: %s, #image batch: %s",
|
||||
format_numel_str(num_total_img_samples),
|
||||
format_numel_str(num_total_img_batch),
|
||||
)
|
||||
log_message("===== Video Info =====")
|
||||
log_message("Video Bucket by HxWxT:")
|
||||
for k, v in num_hwt_vid_dict.items():
|
||||
log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
|
||||
log_message("--------------------------------")
|
||||
log_message(
|
||||
"#video sample: %s, #video batch: %s",
|
||||
format_numel_str(num_total_vid_samples),
|
||||
format_numel_str(num_total_vid_batch),
|
||||
)
|
||||
log_message("===== Summary =====")
|
||||
log_message("#non-empty buckets: %s", len(bucket_sample_dict))
|
||||
log_message(
|
||||
"Img/Vid sample ratio: %.2f",
|
||||
num_total_img_samples / num_total_vid_samples if num_total_vid_samples > 0 else 0,
|
||||
)
|
||||
log_message(
|
||||
"Img/Vid batch ratio: %.2f", num_total_img_batch / num_total_vid_batch if num_total_vid_batch > 0 else 0
|
||||
)
|
||||
log_message(
|
||||
"vid batch 256: %s, vid batch 768: %s", format_numel_str(num_total_vid_batch_256), format_numel_str(num_total_vid_batch_768)
|
||||
)
|
||||
log_message(
|
||||
"Vid batch ratio (256px/768px): %.2f", num_total_vid_batch_256 / num_total_vid_batch_768 if num_total_vid_batch_768 > 0 else 0
|
||||
)
|
||||
log_message(
|
||||
"#training sample: %s, #training batch: %s",
|
||||
format_numel_str(num_total_samples),
|
||||
format_numel_str(num_total_batch),
|
||||
)
|
||||
return num_total_batch
|
||||
|
||||
def reset(self):
|
||||
self.last_micro_batch_access_index = 0
|
||||
|
||||
def set_step(self, start_step: int):
|
||||
self.last_micro_batch_access_index = start_step * self.num_replicas
|
||||
|
||||
def state_dict(self, num_steps: int) -> dict:
|
||||
# the last_micro_batch_access_index in the __iter__ is often
|
||||
# not accurate during multi-workers and data prefetching
|
||||
# thus, we need the user to pass the actual steps which have been executed
|
||||
# to calculate the correct last_micro_batch_access_index
|
||||
return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
|
||||
class BatchDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Used with BatchDataset;
|
||||
Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
|
||||
| buffer {i} | buffer {i+1}
|
||||
------ | ------------------- | -------------------
|
||||
rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9
|
||||
rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
|
||||
rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset, **kwargs):
|
||||
super().__init__(dataset, **kwargs)
|
||||
self.start_index = 0
|
||||
|
||||
def __iter__(self):
|
||||
num_buffers = self.dataset.num_buffers
|
||||
len_buffer = self.dataset.len_buffer
|
||||
num_buffers_i = num_buffers // self.num_replicas
|
||||
num_samples_i = len_buffer * num_buffers_i
|
||||
|
||||
indices_i = np.arange(self.start_index, num_samples_i) + self.rank * num_samples_i
|
||||
indices_i = indices_i.tolist()
|
||||
|
||||
return iter(indices_i)
|
||||
|
||||
def reset(self):
|
||||
self.start_index = 0
|
||||
|
||||
def state_dict(self, step) -> dict:
|
||||
return {"start_index": step}
|
||||
|
||||
def load_state_dict(self, state_dict: dict):
|
||||
self.start_index = state_dict["start_index"] + 1
|
||||
|
|
@ -0,0 +1,419 @@
|
|||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
||||
from torchvision.io import write_video
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from . import video_transforms
|
||||
from .read_video import read_video
|
||||
|
||||
try:
|
||||
import dask.dataframe as dd
|
||||
|
||||
SUPPORT_DASK = True
|
||||
except:
|
||||
SUPPORT_DASK = False
|
||||
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
regex = re.compile(
|
||||
r"^(?:http|ftp)s?://" # http:// or https://
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_img(path):
|
||||
ext = os.path.splitext(path)[-1].lower()
|
||||
return ext in IMG_EXTENSIONS
|
||||
|
||||
|
||||
def is_vid(path):
|
||||
ext = os.path.splitext(path)[-1].lower()
|
||||
return ext in VID_EXTENSIONS
|
||||
|
||||
|
||||
def is_url(url):
|
||||
return re.match(regex, url) is not None
|
||||
|
||||
|
||||
def read_file(input_path, memory_efficient=False):
|
||||
if input_path.endswith(".csv"):
|
||||
assert not memory_efficient, "Memory efficient mode is not supported for CSV files"
|
||||
return pd.read_csv(input_path)
|
||||
elif input_path.endswith(".parquet"):
|
||||
columns = None
|
||||
if memory_efficient:
|
||||
columns = ["path", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution"]
|
||||
if SUPPORT_DASK:
|
||||
ret = dd.read_parquet(input_path, columns=columns).compute()
|
||||
else:
|
||||
ret = pd.read_parquet(input_path, columns=columns)
|
||||
return ret
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported file format: {input_path}")
|
||||
|
||||
|
||||
def download_url(input_path):
|
||||
output_dir = "cache"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
base_name = os.path.basename(input_path)
|
||||
output_path = os.path.join(output_dir, base_name)
|
||||
img_data = requests.get(input_path).content
|
||||
with open(output_path, "wb", encoding="utf-8") as handler:
|
||||
handler.write(img_data)
|
||||
print(f"URL {input_path} downloaded to {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def temporal_random_crop(
|
||||
vframes: torch.Tensor, num_frames: int, frame_interval: int, return_frame_indices: bool = False
|
||||
) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
|
||||
temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
|
||||
total_frames = len(vframes)
|
||||
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
|
||||
|
||||
assert (
|
||||
end_frame_ind - start_frame_ind >= num_frames
|
||||
), f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}"
|
||||
|
||||
frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
|
||||
video = vframes[frame_indices]
|
||||
if return_frame_indices:
|
||||
return video, frame_indices
|
||||
else:
|
||||
return video
|
||||
|
||||
|
||||
def get_transforms_video(name="center", image_size=(256, 256)):
|
||||
if name is None:
|
||||
return None
|
||||
elif name == "center":
|
||||
assert image_size[0] == image_size[1], "image_size must be square for center crop"
|
||||
transform_video = transforms.Compose(
|
||||
[
|
||||
video_transforms.ToTensorVideo(), # TCHW
|
||||
# video_transforms.RandomHorizontalFlipVideo(),
|
||||
video_transforms.UCFCenterCropVideo(image_size[0]),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
elif name == "resize_crop":
|
||||
transform_video = transforms.Compose(
|
||||
[
|
||||
video_transforms.ToTensorVideo(), # TCHW
|
||||
video_transforms.ResizeCrop(image_size),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
elif name == "rand_size_crop":
|
||||
transform_video = transforms.Compose(
|
||||
[
|
||||
video_transforms.ToTensorVideo(), # TCHW
|
||||
video_transforms.RandomSizedCrop(image_size),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Transform {name} not implemented")
|
||||
return transform_video
|
||||
|
||||
|
||||
def get_transforms_image(name="center", image_size=(256, 256)):
|
||||
if name is None:
|
||||
return None
|
||||
elif name == "center":
|
||||
assert image_size[0] == image_size[1], "Image size must be square for center crop"
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
|
||||
# transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
elif name == "resize_crop":
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
elif name == "rand_size_crop":
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda pil_image: rand_size_crop_arr(pil_image, image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Transform {name} not implemented")
|
||||
return transform
|
||||
|
||||
|
||||
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
|
||||
image = pil_loader(path)
|
||||
if transform is None:
|
||||
transform = get_transforms_image(image_size=image_size, name=transform_name)
|
||||
image = transform(image)
|
||||
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return video
|
||||
|
||||
|
||||
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
|
||||
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
if transform is None:
|
||||
transform = get_transforms_video(image_size=image_size, name=transform_name)
|
||||
video = transform(vframes) # T C H W
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return video
|
||||
|
||||
|
||||
def read_from_path(path, image_size, transform_name="center"):
|
||||
if is_url(path):
|
||||
path = download_url(path)
|
||||
ext = os.path.splitext(path)[-1].lower()
|
||||
if ext.lower() in VID_EXTENSIONS:
|
||||
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
|
||||
else:
|
||||
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
|
||||
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
|
||||
|
||||
|
||||
def save_sample(
|
||||
x,
|
||||
save_path=None,
|
||||
fps=8,
|
||||
normalize=True,
|
||||
value_range=(-1, 1),
|
||||
force_video=False,
|
||||
verbose=True,
|
||||
crf=23,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): shape [C, T, H, W]
|
||||
"""
|
||||
assert x.ndim == 4
|
||||
|
||||
if not force_video and x.shape[1] == 1: # T = 1: save as image
|
||||
save_path += ".png"
|
||||
x = x.squeeze(1)
|
||||
save_image([x], save_path, normalize=normalize, value_range=value_range)
|
||||
else:
|
||||
save_path += ".mp4"
|
||||
if normalize:
|
||||
low, high = value_range
|
||||
x.clamp_(min=low, max=high)
|
||||
x.sub_(low).div_(max(high - low, 1e-5))
|
||||
|
||||
x = x.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
|
||||
|
||||
write_video(save_path, x, fps=fps, video_codec="h264", options={"crf": str(crf)})
|
||||
if verbose:
|
||||
print(f"Saved to {save_path}")
|
||||
return save_path
|
||||
|
||||
|
||||
def center_crop_arr(pil_image, image_size):
|
||||
"""
|
||||
Center cropping implementation from ADM.
|
||||
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
||||
"""
|
||||
while min(*pil_image.size) >= 2 * image_size:
|
||||
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
|
||||
|
||||
scale = image_size / min(*pil_image.size)
|
||||
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
|
||||
|
||||
arr = np.array(pil_image)
|
||||
crop_y = (arr.shape[0] - image_size) // 2
|
||||
crop_x = (arr.shape[1] - image_size) // 2
|
||||
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
||||
|
||||
|
||||
def rand_size_crop_arr(pil_image, image_size):
|
||||
"""
|
||||
Randomly crop image for height and width, ranging from image_size[0] to image_size[1]
|
||||
"""
|
||||
arr = np.array(pil_image)
|
||||
|
||||
# get random target h w
|
||||
height = random.randint(image_size[0], image_size[1])
|
||||
width = random.randint(image_size[0], image_size[1])
|
||||
# ensure that h w are factors of 8
|
||||
height = height - height % 8
|
||||
width = width - width % 8
|
||||
|
||||
# get random start pos
|
||||
h_start = random.randint(0, max(len(arr) - height, 0))
|
||||
w_start = random.randint(0, max(len(arr[0]) - height, 0))
|
||||
|
||||
# crop
|
||||
return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width])
|
||||
|
||||
|
||||
def resize_crop_to_fill(pil_image, image_size):
|
||||
w, h = pil_image.size # PIL is (W, H)
|
||||
th, tw = image_size
|
||||
rh, rw = th / h, tw / w
|
||||
if rh > rw:
|
||||
sh, sw = th, round(w * rh)
|
||||
image = pil_image.resize((sw, sh), Image.BICUBIC)
|
||||
i = 0
|
||||
j = int(round((sw - tw) / 2.0))
|
||||
else:
|
||||
sh, sw = round(h * rw), tw
|
||||
image = pil_image.resize((sw, sh), Image.BICUBIC)
|
||||
i = int(round((sh - th) / 2.0))
|
||||
j = 0
|
||||
arr = np.array(image)
|
||||
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
|
||||
return Image.fromarray(arr[i : i + th, j : j + tw])
|
||||
|
||||
|
||||
def map_target_fps(
|
||||
fps: float,
|
||||
max_fps: float,
|
||||
) -> tuple[float, int]:
|
||||
"""
|
||||
Map fps to a new fps that is less than max_fps.
|
||||
|
||||
Args:
|
||||
fps (float): Original fps.
|
||||
max_fps (float): Maximum fps.
|
||||
|
||||
Returns:
|
||||
tuple[float, int]: New fps and sampling interval.
|
||||
"""
|
||||
if math.isnan(fps):
|
||||
return 0, 1
|
||||
if fps < max_fps:
|
||||
return fps, 1
|
||||
sampling_interval = math.ceil(fps / max_fps)
|
||||
new_fps = math.floor(fps / sampling_interval)
|
||||
return new_fps, sampling_interval
|
||||
|
||||
|
||||
def sync_object_across_devices(obj: Any, rank: int = 0):
|
||||
"""
|
||||
Synchronizes any picklable object across devices in a PyTorch distributed setting
|
||||
using `broadcast_object_list` with CUDA support.
|
||||
|
||||
Parameters:
|
||||
obj (Any): The object to synchronize. Can be any picklable object (e.g., list, dict, custom class).
|
||||
rank (int): The rank of the device from which to broadcast the object state. Default is 0.
|
||||
|
||||
Note: Ensure torch.distributed is initialized before using this function and CUDA is available.
|
||||
"""
|
||||
|
||||
# Move the object to a list for broadcasting
|
||||
object_list = [obj]
|
||||
|
||||
# Broadcast the object list from the source rank to all other ranks
|
||||
dist.broadcast_object_list(object_list, src=rank, device="cuda")
|
||||
|
||||
# Retrieve the synchronized object
|
||||
obj = object_list[0]
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def rescale_image_by_path(path: str, height: int, width: int):
|
||||
"""
|
||||
Rescales an image to the specified height and width and saves it back to the original path.
|
||||
|
||||
Args:
|
||||
path (str): The file path of the image.
|
||||
height (int): The target height of the image.
|
||||
width (int): The target width of the image.
|
||||
"""
|
||||
try:
|
||||
# read image
|
||||
image = Image.open(path)
|
||||
|
||||
# check if image is valid
|
||||
if image is None:
|
||||
raise ValueError("The image is invalid or empty.")
|
||||
|
||||
# resize image
|
||||
resize_transform = transforms.Resize((width, height))
|
||||
resized_image = resize_transform(image)
|
||||
|
||||
# save resized image back to the original path
|
||||
resized_image.save(path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error rescaling image: {e}")
|
||||
|
||||
|
||||
def rescale_video_by_path(path: str, height: int, width: int):
|
||||
"""
|
||||
Rescales an MP4 video (without audio) to the specified height and width.
|
||||
|
||||
Args:
|
||||
path (str): The file path of the video.
|
||||
height (int): The target height of the video.
|
||||
width (int): The target width of the video.
|
||||
"""
|
||||
try:
|
||||
# Read video and metadata
|
||||
video, info = read_video(path, backend="av")
|
||||
|
||||
# Check if video is valid
|
||||
if video is None or video.size(0) == 0:
|
||||
raise ValueError("The video is invalid or empty.")
|
||||
|
||||
# Resize video frames using a performant method
|
||||
resize_transform = transforms.Compose([transforms.Resize((width, height))])
|
||||
resized_video = torch.stack([resize_transform(frame) for frame in video])
|
||||
|
||||
# Save resized video back to the original path
|
||||
resized_video = resized_video.permute(0, 2, 3, 1)
|
||||
write_video(path, resized_video, fps=int(info["video_fps"]), video_codec="h264")
|
||||
except Exception as e:
|
||||
print(f"Error rescaling video: {e}")
|
||||
|
||||
|
||||
def save_tensor_to_disk(tensor, path, exist_handling="overwrite"):
|
||||
if os.path.exists(path):
|
||||
if exist_handling == "ignore":
|
||||
return
|
||||
elif exist_handling == "raise":
|
||||
raise UserWarning(f"File {path} already exists, rewriting!")
|
||||
torch.save(tensor, path)
|
||||
|
||||
|
||||
def save_tensor_to_internet(tensor, path):
|
||||
raise NotImplementedError("save_tensor_to_internet is not implemented yet!")
|
||||
|
||||
|
||||
def save_latent(latent, path, exist_handling="overwrite"):
|
||||
if path.startswith(("http://", "https://", "ftp://", "sftp://")):
|
||||
save_tensor_to_internet(latent, path)
|
||||
else:
|
||||
save_tensor_to_disk(latent, path, exist_handling=exist_handling)
|
||||
|
||||
|
||||
def cache_latents(latents, path, exist_handling="overwrite"):
|
||||
for i in range(latents.shape[0]):
|
||||
save_latent(latents[i], path[i], exist_handling=exist_handling)
|
||||
|
|
@ -0,0 +1,595 @@
|
|||
# Copyright 2024 Vchitect/Latte
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.# Modified from Latte
|
||||
|
||||
import numbers
|
||||
|
||||
# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def _is_tensor_video_clip(clip):
|
||||
if not torch.is_tensor(clip):
|
||||
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
||||
|
||||
if not clip.ndimension() == 4:
|
||||
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def crop(clip, i, j, h, w):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
"""
|
||||
if len(clip.size()) != 4:
|
||||
raise ValueError("clip should be a 4D tensor")
|
||||
return clip[..., i : i + h, j : j + w]
|
||||
|
||||
|
||||
def resize(clip, target_size, interpolation_mode):
|
||||
if len(target_size) != 2:
|
||||
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
||||
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
|
||||
|
||||
|
||||
def resize_scale(clip, target_size, interpolation_mode):
|
||||
if len(target_size) != 2:
|
||||
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
||||
H, W = clip.size(-2), clip.size(-1)
|
||||
scale_ = target_size[0] / min(H, W)
|
||||
th, tw = int(round(H * scale_)), int(round(W * scale_))
|
||||
return torch.nn.functional.interpolate(clip, size=(th, tw), mode=interpolation_mode, align_corners=False)
|
||||
|
||||
|
||||
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
||||
"""
|
||||
Do spatial cropping and resizing to the video clip
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
||||
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
||||
h (int): Height of the cropped region.
|
||||
w (int): Width of the cropped region.
|
||||
size (tuple(int, int)): height and width of resized clip
|
||||
Returns:
|
||||
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
|
||||
"""
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
clip = crop(clip, i, j, h, w)
|
||||
clip = resize(clip, size, interpolation_mode)
|
||||
return clip
|
||||
|
||||
|
||||
def center_crop(clip, crop_size):
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
h, w = clip.size(-2), clip.size(-1)
|
||||
th, tw = crop_size
|
||||
if h < th or w < tw:
|
||||
raise ValueError("height and width must be no smaller than crop_size")
|
||||
|
||||
i = int(round((h - th) / 2.0))
|
||||
j = int(round((w - tw) / 2.0))
|
||||
return crop(clip, i, j, th, tw)
|
||||
|
||||
|
||||
def center_crop_using_short_edge(clip):
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
h, w = clip.size(-2), clip.size(-1)
|
||||
if h < w:
|
||||
th, tw = h, h
|
||||
i = 0
|
||||
j = int(round((w - tw) / 2.0))
|
||||
else:
|
||||
th, tw = w, w
|
||||
i = int(round((h - th) / 2.0))
|
||||
j = 0
|
||||
return crop(clip, i, j, th, tw)
|
||||
|
||||
|
||||
def resize_crop_to_fill(clip, target_size):
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
h, w = clip.size(-2), clip.size(-1)
|
||||
th, tw = target_size[0], target_size[1]
|
||||
rh, rw = th / h, tw / w
|
||||
if rh > rw:
|
||||
sh, sw = th, round(w * rh)
|
||||
clip = resize(clip, (sh, sw), "bilinear")
|
||||
i = 0
|
||||
j = int(round(sw - tw) / 2.0)
|
||||
else:
|
||||
sh, sw = round(h * rw), tw
|
||||
clip = resize(clip, (sh, sw), "bilinear")
|
||||
i = int(round(sh - th) / 2.0)
|
||||
j = 0
|
||||
assert i + th <= clip.size(-2) and j + tw <= clip.size(-1)
|
||||
return crop(clip, i, j, th, tw)
|
||||
|
||||
|
||||
# def rand_crop_h_w(clip, target_size_range, multiples_of: int = 8):
|
||||
# # NOTE: for some reason, if don't re-import, gives same randint results
|
||||
# import sys
|
||||
|
||||
# del sys.modules["random"]
|
||||
# import random
|
||||
|
||||
# if not _is_tensor_video_clip(clip):
|
||||
# raise ValueError("clip should be a 4D torch.tensor")
|
||||
# h, w = clip.size(-2), clip.size(-1)
|
||||
|
||||
# # get random target h w
|
||||
# th = random.randint(target_size_range[0], target_size_range[1])
|
||||
# tw = random.randint(target_size_range[0], target_size_range[1])
|
||||
|
||||
# # ensure that h w are factors of 8
|
||||
# th = th - th % multiples_of
|
||||
# tw = tw - tw % multiples_of
|
||||
|
||||
# # get random start pos
|
||||
# i = random.randint(0, h-th) if h > th else 0
|
||||
# j = random.randint(0, w-tw) if w > tw else 0
|
||||
|
||||
# th = th if th < h else h
|
||||
# tw = tw if tw < w else w
|
||||
|
||||
# # print("target size range:",target_size_range)
|
||||
# # print("original size:", h, w)
|
||||
# # print("crop size:", th, tw)
|
||||
# # print(f"crop:{i}-{i+th}, {j}-{j+tw}")
|
||||
|
||||
# return (crop(clip, i, j, th, tw), th, tw)
|
||||
|
||||
|
||||
def random_shift_crop(clip):
|
||||
"""
|
||||
Slide along the long edge, with the short edge as crop size
|
||||
"""
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
h, w = clip.size(-2), clip.size(-1)
|
||||
|
||||
if h <= w:
|
||||
short_edge = h
|
||||
else:
|
||||
short_edge = w
|
||||
|
||||
th, tw = short_edge, short_edge
|
||||
|
||||
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
||||
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
||||
return crop(clip, i, j, th, tw)
|
||||
|
||||
|
||||
def to_tensor(clip):
|
||||
"""
|
||||
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
||||
permute the dimensions of clip tensor
|
||||
Args:
|
||||
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
||||
Return:
|
||||
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
||||
"""
|
||||
_is_tensor_video_clip(clip)
|
||||
if not clip.dtype == torch.uint8:
|
||||
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
||||
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
||||
return clip.float() / 255.0
|
||||
|
||||
|
||||
def normalize(clip, mean, std, inplace=False):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
||||
mean (tuple): pixel RGB mean. Size is (3)
|
||||
std (tuple): pixel standard deviation. Size is (3)
|
||||
Returns:
|
||||
normalized clip (torch.tensor): Size is (T, C, H, W)
|
||||
"""
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
if not inplace:
|
||||
clip = clip.clone()
|
||||
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
||||
# print(mean)
|
||||
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
||||
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
||||
return clip
|
||||
|
||||
|
||||
def hflip(clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
||||
Returns:
|
||||
flipped clip (torch.tensor): Size is (T, C, H, W)
|
||||
"""
|
||||
if not _is_tensor_video_clip(clip):
|
||||
raise ValueError("clip should be a 4D torch.tensor")
|
||||
return clip.flip(-1)
|
||||
|
||||
|
||||
class ResizeCrop:
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, clip):
|
||||
clip = resize_crop_to_fill(clip, self.size)
|
||||
return clip
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size})"
|
||||
|
||||
|
||||
class RandomSizedCrop:
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, clip):
|
||||
i, j, h, w = self.get_params(clip)
|
||||
# self.size = (h, w)
|
||||
return crop(clip, i, j, h, w)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size})"
|
||||
|
||||
def get_params(self, clip, multiples_of=8):
|
||||
h, w = clip.shape[-2:]
|
||||
|
||||
# get random target h w
|
||||
th = random.randint(self.size[0], self.size[1])
|
||||
tw = random.randint(self.size[0], self.size[1])
|
||||
# ensure that h w are factors of 8
|
||||
th = th - th % multiples_of
|
||||
tw = tw - tw % multiples_of
|
||||
|
||||
if h < th:
|
||||
th = h - h % multiples_of
|
||||
if w < tw:
|
||||
tw = w - w % multiples_of
|
||||
|
||||
if w == tw and h == th:
|
||||
return 0, 0, h, w
|
||||
|
||||
else:
|
||||
# get random start pos
|
||||
i = random.randint(0, h - th)
|
||||
j = random.randint(0, w - tw)
|
||||
|
||||
return i, j, th, tw
|
||||
|
||||
|
||||
class RandomCropVideo:
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
self.size = (int(size), int(size))
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
Returns:
|
||||
torch.tensor: randomly cropped video clip.
|
||||
size is (T, C, OH, OW)
|
||||
"""
|
||||
i, j, h, w = self.get_params(clip)
|
||||
return crop(clip, i, j, h, w)
|
||||
|
||||
def get_params(self, clip):
|
||||
h, w = clip.shape[-2:]
|
||||
th, tw = self.size
|
||||
|
||||
if h < th or w < tw:
|
||||
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
||||
|
||||
if w == tw and h == th:
|
||||
return 0, 0, h, w
|
||||
|
||||
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
||||
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
||||
|
||||
return i, j, th, tw
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size})"
|
||||
|
||||
|
||||
class CenterCropResizeVideo:
|
||||
"""
|
||||
First use the short side for cropping length,
|
||||
center crop video, then resize to the specified size
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
interpolation_mode="bilinear",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
if len(size) != 2:
|
||||
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
Returns:
|
||||
torch.tensor: scale resized / center cropped video clip.
|
||||
size is (T, C, crop_size, crop_size)
|
||||
"""
|
||||
clip_center_crop = center_crop_using_short_edge(clip)
|
||||
clip_center_crop_resize = resize(
|
||||
clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
|
||||
)
|
||||
return clip_center_crop_resize
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
||||
|
||||
|
||||
class UCFCenterCropVideo:
|
||||
"""
|
||||
First scale to the specified size in equal proportion to the short edge,
|
||||
then center cropping
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
interpolation_mode="bilinear",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
if len(size) != 2:
|
||||
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
Returns:
|
||||
torch.tensor: scale resized / center cropped video clip.
|
||||
size is (T, C, crop_size, crop_size)
|
||||
"""
|
||||
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
||||
clip_center_crop = center_crop(clip_resize, self.size)
|
||||
return clip_center_crop
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
||||
|
||||
|
||||
class KineticsRandomCropResizeVideo:
|
||||
"""
|
||||
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
interpolation_mode="bilinear",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
if len(size) != 2:
|
||||
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
def __call__(self, clip):
|
||||
clip_random_crop = random_shift_crop(clip)
|
||||
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
|
||||
return clip_resize
|
||||
|
||||
|
||||
class CenterCropVideo:
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
interpolation_mode="bilinear",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
if len(size) != 2:
|
||||
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
||||
Returns:
|
||||
torch.tensor: center cropped video clip.
|
||||
size is (T, C, crop_size, crop_size)
|
||||
"""
|
||||
clip_center_crop = center_crop(clip, self.size)
|
||||
return clip_center_crop
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
||||
|
||||
|
||||
class NormalizeVideo:
|
||||
"""
|
||||
Normalize the video clip by mean subtraction and division by standard deviation
|
||||
Args:
|
||||
mean (3-tuple): pixel RGB mean
|
||||
std (3-tuple): pixel RGB standard deviation
|
||||
inplace (boolean): whether do in-place normalization
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
|
||||
"""
|
||||
return normalize(clip, self.mean, self.std, self.inplace)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
||||
|
||||
|
||||
class ToTensorVideo:
|
||||
"""
|
||||
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
||||
permute the dimensions of clip tensor
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
||||
Return:
|
||||
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
||||
"""
|
||||
return to_tensor(clip)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class RandomHorizontalFlipVideo:
|
||||
"""
|
||||
Flip the video clip along the horizontal direction with a given probability
|
||||
Args:
|
||||
p (float): probability of the clip being flipped. Default value is 0.5
|
||||
"""
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, clip):
|
||||
"""
|
||||
Args:
|
||||
clip (torch.tensor): Size is (T, C, H, W)
|
||||
Return:
|
||||
clip (torch.tensor): Size is (T, C, H, W)
|
||||
"""
|
||||
if random.random() < self.p:
|
||||
clip = hflip(clip)
|
||||
return clip
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(p={self.p})"
|
||||
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# --------------------- Sampling ---------------------------
|
||||
# ------------------------------------------------------------
|
||||
class TemporalRandomCrop(object):
|
||||
"""Temporally crop the given frame indices at a random location.
|
||||
|
||||
Args:
|
||||
size (int): Desired length of frames will be seen in the model.
|
||||
"""
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, total_frames):
|
||||
rand_end = max(0, total_frames - self.size - 1)
|
||||
begin_index = random.randint(0, rand_end)
|
||||
end_index = min(begin_index + self.size, total_frames)
|
||||
return begin_index, end_index
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torchvision.io as io
|
||||
from torchvision import transforms
|
||||
from torchvision.utils import save_image
|
||||
|
||||
vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
|
||||
|
||||
trans = transforms.Compose(
|
||||
[
|
||||
ToTensorVideo(),
|
||||
RandomHorizontalFlipVideo(),
|
||||
UCFCenterCropVideo(512),
|
||||
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
)
|
||||
|
||||
target_video_len = 32
|
||||
frame_interval = 1
|
||||
total_frames = len(vframes)
|
||||
print(total_frames)
|
||||
|
||||
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
|
||||
|
||||
# Sampling video frames
|
||||
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
|
||||
# print(start_frame_ind)
|
||||
# print(end_frame_ind)
|
||||
assert end_frame_ind - start_frame_ind >= target_video_len
|
||||
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
|
||||
print(frame_indice)
|
||||
|
||||
select_vframes = vframes[frame_indice]
|
||||
print(select_vframes.shape)
|
||||
print(select_vframes.dtype)
|
||||
|
||||
select_vframes_trans = trans(select_vframes)
|
||||
print(select_vframes_trans.shape)
|
||||
print(select_vframes_trans.dtype)
|
||||
|
||||
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
|
||||
print(select_vframes_trans_int.dtype)
|
||||
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
|
||||
|
||||
io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
|
||||
|
||||
for i in range(target_video_len):
|
||||
save_image(
|
||||
select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
|
||||
)
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
# Dataset Management
|
||||
|
||||
- [Dataset Management](#dataset-management)
|
||||
- [Dataset Format](#dataset-format)
|
||||
- [Dataset to CSV](#dataset-to-csv)
|
||||
- [Manage datasets](#manage-datasets)
|
||||
- [Requirement](#requirement)
|
||||
- [Basic Usage](#basic-usage)
|
||||
- [Score filtering](#score-filtering)
|
||||
- [Documentation](#documentation)
|
||||
- [Transform datasets](#transform-datasets)
|
||||
- [Resize](#resize)
|
||||
- [Frame extraction](#frame-extraction)
|
||||
- [Crop Midjourney 4 grid](#crop-midjourney-4-grid)
|
||||
- [Analyze datasets](#analyze-datasets)
|
||||
- [Data Process Pipeline](#data-process-pipeline)
|
||||
|
||||
After preparing the raw dataset according to the [instructions](/docs/datasets.md), you can use the following commands to manage the dataset.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
All dataset should be provided in a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should follow the words below:
|
||||
|
||||
- `path`: the relative/absolute path or url to the image or video file. Required.
|
||||
- `text`: the caption or description of the image or video. Required for training.
|
||||
- `num_frames`: the number of frames in the video. Required for training.
|
||||
- `width`: the width of the video frame. Required for dynamic bucket.
|
||||
- `height`: the height of the video frame. Required for dynamic bucket.
|
||||
- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket.
|
||||
- `resolution`: height x width. For analysis.
|
||||
- `text_len`: the number of tokens in the text. For analysis.
|
||||
- `aes`: aesthetic score calculated by [asethetic scorer](/tools/aesthetic/README.md). For filtering.
|
||||
- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering.
|
||||
- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering.
|
||||
- `fps`: the frame rate of the video. Optional.
|
||||
- `cmotion`: the camera motion.
|
||||
|
||||
An example ready for training:
|
||||
|
||||
```csv
|
||||
path, text, num_frames, width, height, aspect_ratio
|
||||
/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625
|
||||
/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625
|
||||
/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1
|
||||
```
|
||||
|
||||
We use pandas to manage the `.csv` or `.parquet` files. The following code is for reading and writing files:
|
||||
|
||||
```python
|
||||
df = pd.read_csv(input_path)
|
||||
df = df.to_csv(output_path, index=False)
|
||||
# or use parquet, which is smaller
|
||||
df = pd.read_parquet(input_path)
|
||||
df = df.to_parquet(output_path, index=False)
|
||||
```
|
||||
|
||||
## Dataset to CSV
|
||||
|
||||
As a start point, `convert.py` is used to convert the dataset to a CSV file. You can use the following commands to convert the dataset to a CSV file:
|
||||
|
||||
```bash
|
||||
python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER
|
||||
|
||||
# general video folder
|
||||
python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv
|
||||
# general image folder
|
||||
python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv
|
||||
# imagenet
|
||||
python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train
|
||||
# ucf101
|
||||
python -m tools.datasets.convert ucf101 UCF101_FOLDER --split videos
|
||||
# vidprom
|
||||
python -m tools.datasets.convert vidprom VIDPROM_FOLDER --info VidProM_semantic_unique.csv
|
||||
```
|
||||
|
||||
## Manage datasets
|
||||
|
||||
Use `datautil` to manage the dataset.
|
||||
|
||||
### Requirement
|
||||
|
||||
Follow our [installation guide](../../docs/installation.md)'s "Data Dependencies" and "Datasets" section to install the required packages.
|
||||
<!-- To accelerate processing speed, you can install [pandarallel](https://github.com/nalepae/pandarallel):
|
||||
|
||||
```bash
|
||||
pip install pandarallel
|
||||
``` -->
|
||||
|
||||
<!-- To get image and video information, you need to install [opencv-python](https://github.com/opencv/opencv-python): -->
|
||||
|
||||
<!-- ```bash
|
||||
pip install opencv-python
|
||||
# If your videos are in av1 codec instead of h264, you need to
|
||||
# - install ffmpeg first
|
||||
# - install via conda to support av1 codec
|
||||
conda install -c conda-forge opencv
|
||||
``` -->
|
||||
|
||||
<!-- Or to get video information, you can install ffmpeg and ffmpeg-python:
|
||||
|
||||
```bash
|
||||
pip install ffmpeg-python
|
||||
``` -->
|
||||
|
||||
<!-- To filter a specific language, you need to install [lingua](https://github.com/pemistahl/lingua-py):
|
||||
|
||||
```bash
|
||||
pip install lingua-language-detector
|
||||
``` -->
|
||||
|
||||
### Basic Usage
|
||||
|
||||
You can use the following commands to process the `csv` or `parquet` files. The output file will be saved in the same directory as the input, with different suffixes indicating the processed method.
|
||||
|
||||
```bash
|
||||
# datautil takes multiple CSV files as input and merge them into one CSV file
|
||||
# output: DATA1+DATA2.csv
|
||||
python -m tools.datasets.datautil DATA1.csv DATA2.csv
|
||||
|
||||
# shard CSV files into multiple CSV files
|
||||
# output: DATA1_0.csv, DATA1_1.csv, ...
|
||||
python -m tools.datasets.datautil DATA1.csv --shard 10
|
||||
|
||||
# filter frames between 128 and 256, with captions
|
||||
# output: DATA1_fmin_128_fmax_256.csv
|
||||
python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256
|
||||
|
||||
# Disable parallel processing
|
||||
python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 --disable-parallel
|
||||
|
||||
# Compute num_frames, height, width, fps, aspect_ratio for videos or images
|
||||
# output: IMG_DATA+VID_DATA_vinfo.csv
|
||||
python -m tools.datasets.datautil IMG_DATA.csv VID_DATA.csv --video-info
|
||||
|
||||
# You can run multiple operations at the same time.
|
||||
python -m tools.datasets.datautil DATA.csv --video-info --remove-empty-caption --remove-url --lang en
|
||||
```
|
||||
|
||||
### Score filtering
|
||||
|
||||
To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands:
|
||||
|
||||
```bash
|
||||
# sort the dataset by aesthetic score
|
||||
# output: DATA_sort.csv
|
||||
python -m tools.datasets.datautil DATA.csv --sort aesthetic_score
|
||||
# View examples of high aesthetic score
|
||||
head -n 10 DATA_sort.csv
|
||||
# View examples of low aesthetic score
|
||||
tail -n 10 DATA_sort.csv
|
||||
|
||||
# sort the dataset by clip score
|
||||
# output: DATA_sort.csv
|
||||
python -m tools.datasets.datautil DATA.csv --sort clip_score
|
||||
|
||||
# filter the dataset by aesthetic score
|
||||
# output: DATA_aesmin_0.5.csv
|
||||
python -m tools.datasets.datautil DATA.csv --aesmin 0.5
|
||||
# filter the dataset by clip score
|
||||
# output: DATA_matchmin_0.5.csv
|
||||
python -m tools.datasets.datautil DATA.csv --matchmin 0.5
|
||||
```
|
||||
|
||||
### Documentation
|
||||
|
||||
You can also use `python -m tools.datasets.datautil --help` to see usage.
|
||||
|
||||
| Args | File suffix | Description |
|
||||
| --------------------------- | -------------- | ------------------------------------------------------------- |
|
||||
| `--output OUTPUT` | | Output path |
|
||||
| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) |
|
||||
| `--disable-parallel` | | Disable `pandarallel` |
|
||||
| `--seed SEED` | | Random seed |
|
||||
| `--shard SHARD` | `_0`,`_1`, ... | Shard the dataset |
|
||||
| `--sort KEY` | `_sort` | Sort the dataset by KEY |
|
||||
| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order |
|
||||
| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset |
|
||||
| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns |
|
||||
| `--info` | `_info` | Get the basic information of each video and image (cv2) |
|
||||
| `--ext` | `_ext` | Remove rows if the file does not exist |
|
||||
| `--relpath` | `_relpath` | Modify the path to relative path by root given |
|
||||
| `--abspath` | `_abspath` | Modify the path to absolute path by root given |
|
||||
| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption |
|
||||
| `--remove-url` | `_nourl` | Remove rows with url in caption |
|
||||
| `--lang LANG` | `_lang` | Remove rows with other language |
|
||||
| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path |
|
||||
| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption |
|
||||
| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM |
|
||||
| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training |
|
||||
| `--unescape` | `_unescape` | Unescape the caption |
|
||||
| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption |
|
||||
| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption |
|
||||
| `--load-caption EXT` | `_load` | Load the caption from the file |
|
||||
| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames |
|
||||
| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum number of frames |
|
||||
| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width |
|
||||
| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score |
|
||||
| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score |
|
||||
| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score |
|
||||
|
||||
## Transform datasets
|
||||
|
||||
The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows:
|
||||
|
||||
```bash
|
||||
python -m tools.datasets.transform TRANSFORM_TYPE META.csv ORIGINAL_DATA_FOLDER DATA_FOLDER_TO_SAVE_RESULTS --additional-args
|
||||
```
|
||||
|
||||
### Resize
|
||||
|
||||
Sometimes you may need to resize the images or videos to a specific resolution. You can use the following commands to resize the dataset:
|
||||
|
||||
```bash
|
||||
python -m tools.datasets.transform meta.csv /path/to/raw/data /path/to/new/data --length 2160
|
||||
```
|
||||
|
||||
### Frame extraction
|
||||
|
||||
To extract frames from videos, you can use the following commands:
|
||||
|
||||
```bash
|
||||
python -m tools.datasets.transform vid_frame_extract meta.csv /path/to/raw/data /path/to/new/data --points 0.1 0.5 0.9
|
||||
```
|
||||
|
||||
### Crop Midjourney 4 grid
|
||||
|
||||
Randomly select one of the 4 images in the 4 grid generated by Midjourney.
|
||||
|
||||
```bash
|
||||
python -m tools.datasets.transform img_rand_crop meta.csv /path/to/raw/data /path/to/new/data
|
||||
```
|
||||
|
||||
## Analyze datasets
|
||||
|
||||
You can easily get basic information about a `.csv` dataset by using the following commands:
|
||||
|
||||
```bash
|
||||
# examine the first 10 rows of the CSV file
|
||||
head -n 10 DATA1.csv
|
||||
# count the number of data in the CSV file (approximately)
|
||||
wc -l DATA1.csv
|
||||
```
|
||||
|
||||
For the dataset provided in a `.csv` or `.parquet` file, you can easily analyze the dataset using the following commands. Plots will be automatically saved.
|
||||
|
||||
```python
|
||||
pyhton -m tools.datasets.analyze DATA_info.csv
|
||||
```
|
||||
|
||||
## Data Process Pipeline
|
||||
|
||||
```bash
|
||||
# Suppose videos and images under ~/dataset/
|
||||
# 1. Convert dataset to CSV
|
||||
python -m tools.datasets.convert video ~/dataset --output meta.csv
|
||||
|
||||
# 2. Get video information
|
||||
python -m tools.datasets.datautil meta.csv --info --fmin 1
|
||||
|
||||
# 3. Get caption
|
||||
# 3.1. generate caption
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava meta_info_fmin1.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video
|
||||
# merge generated results
|
||||
python -m tools.datasets.datautil meta_info_fmin1_caption_part*.csv --output meta_caption.csv
|
||||
# merge caption and info
|
||||
python -m tools.datasets.datautil meta_info_fmin1.csv --intersection meta_caption.csv --output meta_caption_info.csv
|
||||
# clean caption
|
||||
python -m tools.datasets.datautil meta_caption_info.csv --clean-caption --refine-llm-caption --remove-empty-caption --output meta_caption_processed.csv
|
||||
# 3.2. extract caption
|
||||
python -m tools.datasets.datautil meta_info_fmin1.csv --load-caption json --remove-empty-caption --clean-caption
|
||||
|
||||
# 4. Scoring
|
||||
# aesthetic scoring
|
||||
torchrun --standalone --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta_caption_processed.csv
|
||||
python -m tools.datasets.datautil meta_caption_processed_part*.csv --output meta_caption_processed_aes.csv
|
||||
# optical flow scoring
|
||||
torchrun --standalone --nproc_per_node 8 -m tools.scoring.optical_flow.inference meta_caption_processed.csv
|
||||
# matching scoring
|
||||
torchrun --standalone --nproc_per_node 8 -m tools.scoring.matching.inference meta_caption_processed.csv
|
||||
# camera motion
|
||||
python -m tools.caption.camera_motion_detect meta_caption_processed.csv
|
||||
```
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def read_file(input_path):
|
||||
if input_path.endswith(".csv"):
|
||||
return pd.read_csv(input_path)
|
||||
elif input_path.endswith(".parquet"):
|
||||
return pd.read_parquet(input_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported file format: {input_path}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="Path to the input dataset")
|
||||
parser.add_argument("--save-img", type=str, default="samples/infos/", help="Path to save the image")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def plot_data(data, column, bins, name):
|
||||
plt.clf()
|
||||
data.hist(column=column, bins=bins)
|
||||
os.makedirs(os.path.dirname(name), exist_ok=True)
|
||||
plt.savefig(name)
|
||||
print(f"Saved {name}")
|
||||
|
||||
|
||||
def plot_categorical_data(data, column, name):
|
||||
plt.clf()
|
||||
data[column].value_counts().plot(kind="bar")
|
||||
os.makedirs(os.path.dirname(name), exist_ok=True)
|
||||
plt.savefig(name)
|
||||
print(f"Saved {name}")
|
||||
|
||||
|
||||
COLUMNS = {
|
||||
"num_frames": 100,
|
||||
"resolution": 100,
|
||||
"text_len": 100,
|
||||
"aes": 100,
|
||||
"match": 100,
|
||||
"flow": 100,
|
||||
"cmotion": None,
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
data = read_file(args.input)
|
||||
|
||||
# === Image Data Info ===
|
||||
image_index = data["num_frames"] == 1
|
||||
if image_index.sum() > 0:
|
||||
print("=== Image Data Info ===")
|
||||
img_data = data[image_index]
|
||||
print(f"Number of images: {len(img_data)}")
|
||||
print(img_data.head())
|
||||
print(img_data.describe())
|
||||
if args.save_img:
|
||||
for column in COLUMNS:
|
||||
if column in img_data.columns and column not in ["num_frames", "cmotion"]:
|
||||
if COLUMNS[column] is None:
|
||||
plot_categorical_data(img_data, column, os.path.join(args.save_img, f"image_{column}.png"))
|
||||
else:
|
||||
plot_data(img_data, column, COLUMNS[column], os.path.join(args.save_img, f"image_{column}.png"))
|
||||
|
||||
# === Video Data Info ===
|
||||
if not image_index.all():
|
||||
print("=== Video Data Info ===")
|
||||
video_data = data[~image_index]
|
||||
print(f"Number of videos: {len(video_data)}")
|
||||
if "num_frames" in video_data.columns:
|
||||
total_num_frames = video_data["num_frames"].sum()
|
||||
print(f"Number of frames: {total_num_frames}")
|
||||
DEFAULT_FPS = 30
|
||||
total_hours = total_num_frames / DEFAULT_FPS / 3600
|
||||
print(f"Total hours (30 FPS): {int(total_hours)}")
|
||||
print(video_data.head())
|
||||
print(video_data.describe())
|
||||
if args.save_img:
|
||||
for column in COLUMNS:
|
||||
if column in video_data.columns:
|
||||
if COLUMNS[column] is None:
|
||||
plot_categorical_data(video_data, column, os.path.join(args.save_img, f"video_{column}.png"))
|
||||
else:
|
||||
plot_data(
|
||||
video_data, column, COLUMNS[column], os.path.join(args.save_img, f"video_{column}.png")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
import argparse
|
||||
import subprocess
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
try:
|
||||
from pandarallel import pandarallel
|
||||
|
||||
PANDA_USE_PARALLEL = True
|
||||
except ImportError:
|
||||
PANDA_USE_PARALLEL = False
|
||||
|
||||
import shutil
|
||||
|
||||
if not shutil.which("ffmpeg"):
|
||||
raise ImportError("FFmpeg is not installed")
|
||||
|
||||
|
||||
def apply(df, func, **kwargs):
|
||||
if PANDA_USE_PARALLEL:
|
||||
return df.parallel_apply(func, **kwargs)
|
||||
return df.progress_apply(func, **kwargs)
|
||||
|
||||
|
||||
def check_video_integrity(video_path):
|
||||
# try:
|
||||
can_open_result = subprocess.run(
|
||||
["ffmpeg", "-v", "error", "-i", video_path, "-t", "0", "-f", "null", "-"], # open video and capture 0 seconds
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
fast_scan_result = subprocess.run(
|
||||
["ffmpeg", "-v", "error", "-analyzeduration", "10M", "-probesize", "10M", "-i", video_path, "-f", "null", "-"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
if can_open_result.stderr == "" and fast_scan_result.stderr == "":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
# except Exception as e:
|
||||
# return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="path to the input dataset")
|
||||
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
|
||||
parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.disable_parallel:
|
||||
PANDA_USE_PARALLEL = False
|
||||
if PANDA_USE_PARALLEL:
|
||||
if args.num_workers is not None:
|
||||
pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
|
||||
else:
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
|
||||
data = pd.read_csv(args.input)
|
||||
assert "path" in data.columns
|
||||
data["integrity"] = apply(data["path"], check_video_integrity)
|
||||
|
||||
integrity_file_path = args.input.replace(".csv", "_intact.csv")
|
||||
broken_file_path = args.input.replace(".csv", "_broken.csv")
|
||||
|
||||
intact_data = data[data["integrity"] == True].drop(columns=["integrity"])
|
||||
intact_data.to_csv(integrity_file_path, index=False)
|
||||
broken_data = data[data["integrity"] == False].drop(columns=["integrity"])
|
||||
broken_data.to_csv(broken_file_path, index=False)
|
||||
|
||||
print(
|
||||
f"Integrity check completed. Intact videos saved to: {integrity_file_path}, broken videos saved to {broken_file_path}."
|
||||
)
|
||||
|
|
@ -0,0 +1,144 @@
|
|||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".m2ts")
|
||||
|
||||
|
||||
def scan_recursively(root):
|
||||
num = 0
|
||||
for entry in os.scandir(root):
|
||||
if entry.is_file():
|
||||
yield entry
|
||||
elif entry.is_dir():
|
||||
num += 1
|
||||
if num % 100 == 0:
|
||||
print(f"Scanned {num} directories.")
|
||||
yield from scan_recursively(entry.path)
|
||||
|
||||
|
||||
def get_filelist(file_path, exts=None):
|
||||
filelist = []
|
||||
time_start = time.time()
|
||||
|
||||
# == OS Walk ==
|
||||
# for home, dirs, files in os.walk(file_path):
|
||||
# for filename in files:
|
||||
# ext = os.path.splitext(filename)[-1].lower()
|
||||
# if exts is None or ext in exts:
|
||||
# filelist.append(os.path.join(home, filename))
|
||||
|
||||
# == Scandir ==
|
||||
obj = scan_recursively(file_path)
|
||||
for entry in obj:
|
||||
if entry.is_file():
|
||||
ext = os.path.splitext(entry.name)[-1].lower()
|
||||
if exts is None or ext in exts:
|
||||
filelist.append(entry.path)
|
||||
|
||||
time_end = time.time()
|
||||
print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.")
|
||||
return filelist
|
||||
|
||||
|
||||
def split_by_capital(name):
|
||||
# BoxingPunchingBag -> Boxing Punching Bag
|
||||
new_name = ""
|
||||
for i in range(len(name)):
|
||||
if name[i].isupper() and i != 0:
|
||||
new_name += " "
|
||||
new_name += name[i]
|
||||
return new_name
|
||||
|
||||
|
||||
def process_imagenet(root, split):
|
||||
root = os.path.expanduser(root)
|
||||
data = ImageNet(root, split=split)
|
||||
samples = [(path, data.classes[label][0]) for path, label in data.samples]
|
||||
output = f"imagenet_{split}.csv"
|
||||
|
||||
df = pd.DataFrame(samples, columns=["path", "text"])
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(samples)} samples to {output}.")
|
||||
|
||||
|
||||
def process_ucf101(root, split):
|
||||
root = os.path.expanduser(root)
|
||||
video_lists = get_filelist(os.path.join(root, split))
|
||||
classes = [x.split("/")[-2] for x in video_lists]
|
||||
classes = [split_by_capital(x) for x in classes]
|
||||
samples = list(zip(video_lists, classes))
|
||||
output = f"ucf101_{split}.csv"
|
||||
|
||||
df = pd.DataFrame(samples, columns=["path", "text"])
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(samples)} samples to {output}.")
|
||||
|
||||
|
||||
def process_vidprom(root, info):
|
||||
root = os.path.expanduser(root)
|
||||
video_lists = get_filelist(root)
|
||||
video_set = set(video_lists)
|
||||
# read info csv
|
||||
infos = pd.read_csv(info)
|
||||
abs_path = infos["uuid"].apply(lambda x: os.path.join(root, f"pika-{x}.mp4"))
|
||||
is_exist = abs_path.apply(lambda x: x in video_set)
|
||||
df = pd.DataFrame(dict(path=abs_path[is_exist], text=infos["prompt"][is_exist]))
|
||||
df.to_csv("vidprom.csv", index=False)
|
||||
print(f"Saved {len(df)} samples to vidprom.csv.")
|
||||
|
||||
|
||||
def process_general_images(root, output):
|
||||
root = os.path.expanduser(root)
|
||||
if not os.path.exists(root):
|
||||
return
|
||||
path_list = get_filelist(root, IMG_EXTENSIONS)
|
||||
fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
|
||||
relpath_list = [os.path.relpath(x, root) for x in path_list]
|
||||
df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))
|
||||
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(df)} samples to {output}.")
|
||||
|
||||
|
||||
def process_general_videos(root, output):
|
||||
root = os.path.expanduser(root)
|
||||
if not os.path.exists(root):
|
||||
return
|
||||
path_list = get_filelist(root, VID_EXTENSIONS)
|
||||
path_list = list(set(path_list)) # remove duplicates
|
||||
fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
|
||||
relpath_list = [os.path.relpath(x, root) for x in path_list]
|
||||
df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))
|
||||
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(df)} samples to {output}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
|
||||
parser.add_argument("root", type=str)
|
||||
parser.add_argument("--split", type=str, default="train")
|
||||
parser.add_argument("--info", type=str, default=None)
|
||||
parser.add_argument("--output", type=str, default=None, required=True, help="Output path")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dataset == "imagenet":
|
||||
process_imagenet(args.root, args.split)
|
||||
elif args.dataset == "ucf101":
|
||||
process_ucf101(args.root, args.split)
|
||||
elif args.dataset == "vidprom":
|
||||
process_vidprom(args.root, args.info)
|
||||
elif args.dataset == "image":
|
||||
process_general_images(args.root, args.output)
|
||||
elif args.dataset == "video":
|
||||
process_general_videos(args.root, args.output)
|
||||
else:
|
||||
raise ValueError("Invalid dataset")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert CSV file to txt file")
|
||||
parser.add_argument("csv_file", type=str, help="CSV file to convert")
|
||||
parser.add_argument("txt_file", type=str, help="TXT file to save")
|
||||
args = parser.parse_args()
|
||||
|
||||
data = pd.read_csv(args.csv_file)
|
||||
text = data["text"].to_list()
|
||||
text = "\n".join(text)
|
||||
with open(args.txt_file, "w") as f:
|
||||
f.write(text)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,262 @@
|
|||
# TODO: remove this file before releasing
|
||||
|
||||
import argparse
|
||||
import html
|
||||
import os
|
||||
import re
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
try:
|
||||
from pandarallel import pandarallel
|
||||
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
pandas_has_parallel = True
|
||||
except ImportError:
|
||||
pandas_has_parallel = False
|
||||
|
||||
|
||||
def apply(df, func, **kwargs):
|
||||
if pandas_has_parallel:
|
||||
return df.parallel_apply(func, **kwargs)
|
||||
return df.progress_apply(func, **kwargs)
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
import ftfy
|
||||
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
BAD_PUNCT_REGEX = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
|
||||
def clean_caption(caption):
|
||||
import urllib.parse as ul
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = basic_clean(caption)
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
|
||||
def get_10m_set():
|
||||
meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv"
|
||||
meta_10m = pd.read_csv(meta_path_10m)
|
||||
|
||||
def process_single_caption(row):
|
||||
text_list = eval(row["caption"])
|
||||
clean_list = [clean_caption(x) for x in text_list]
|
||||
return str(clean_list)
|
||||
|
||||
ret = apply(meta_10m, process_single_caption, axis=1)
|
||||
# ret = meta_10m.progress_apply(process_single_caption, axis=1)
|
||||
print("==> text processed.")
|
||||
|
||||
text_list = []
|
||||
for x in ret:
|
||||
text_list += eval(x)
|
||||
# text_set = text_set.union(set(eval(x)))
|
||||
text_set = set(text_list)
|
||||
# meta_10m['caption_new'] = ret
|
||||
# meta_10m.to_csv('/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m_new-cap.csv')
|
||||
|
||||
# video_id_set = set(meta_10m['videoID'])
|
||||
# id2t = {}
|
||||
# for idx, row in tqdm(meta_10m.iterrows(), total=len(meta_10m)):
|
||||
# video_id = row['videoID']
|
||||
# text_list = eval(row['caption'])
|
||||
# id2t[video_id] = set(text_list)
|
||||
|
||||
print(f"==> Loaded meta_10m from '{meta_path_10m}'")
|
||||
return text_set
|
||||
|
||||
|
||||
def filter_panda10m_text(meta_path, text_set):
|
||||
def process_single_row(row):
|
||||
# path = row['path']
|
||||
t = row["text"]
|
||||
# fname = os.path.basename(path)
|
||||
# video_id = fname[:fname.rindex('_')]
|
||||
if t not in text_set:
|
||||
return False
|
||||
return True
|
||||
|
||||
meta = pd.read_csv(meta_path)
|
||||
ret = apply(meta, process_single_row, axis=1)
|
||||
# ret = meta.progress_apply(process_single_row, axis=1)
|
||||
|
||||
meta = meta[ret]
|
||||
wo_ext, ext = os.path.splitext(meta_path)
|
||||
out_path = f"{wo_ext}_filter-10m{ext}"
|
||||
meta.to_csv(out_path, index=False)
|
||||
print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
|
||||
|
||||
|
||||
def filter_panda10m_timestamp(meta_path):
|
||||
meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv"
|
||||
meta_10m = pd.read_csv(meta_path_10m)
|
||||
|
||||
id2t = {}
|
||||
for idx, row in tqdm(meta_10m.iterrows(), total=len(meta_10m)):
|
||||
video_id = row["videoID"]
|
||||
timestamp = eval(row["timestamp"])
|
||||
timestamp = [str(tuple(x)) for x in timestamp]
|
||||
id2t[video_id] = timestamp
|
||||
|
||||
# video_id_set_10m = set(meta_10m['videoID'])
|
||||
print(f"==> Loaded meta_10m from '{meta_path_10m}'")
|
||||
|
||||
def process_single_row(row):
|
||||
path = row["path"]
|
||||
t = row["timestamp"]
|
||||
fname = os.path.basename(path)
|
||||
video_id = fname[: fname.rindex("_")]
|
||||
if video_id not in id2t:
|
||||
return False
|
||||
if t not in id2t[video_id]:
|
||||
return False
|
||||
return True
|
||||
# return video_id in video_id_set_10m
|
||||
|
||||
meta = pd.read_csv(meta_path)
|
||||
ret = apply(meta, process_single_row, axis=1)
|
||||
|
||||
meta = meta[ret]
|
||||
wo_ext, ext = os.path.splitext(meta_path)
|
||||
out_path = f"{wo_ext}_filter-10m{ext}"
|
||||
meta.to_csv(out_path, index=False)
|
||||
print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--meta_path", type=str, nargs="+")
|
||||
parser.add_argument("--num_workers", default=5, type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
text_set = get_10m_set()
|
||||
for x in args.meta_path:
|
||||
filter_panda10m_text(x, text_set)
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
try:
|
||||
from pandarallel import pandarallel
|
||||
|
||||
PANDA_USE_PARALLEL = True
|
||||
except ImportError:
|
||||
PANDA_USE_PARALLEL = False
|
||||
|
||||
|
||||
def save_first_frame(video_path, img_dir):
|
||||
if not os.path.exists(video_path):
|
||||
print(f"Video not found: {video_path}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
success, frame = cap.read()
|
||||
if success:
|
||||
video_name = os.path.basename(video_path)
|
||||
image_name = os.path.splitext(video_name)[0] + "_first_frame.jpg"
|
||||
image_path = os.path.join(img_dir, image_name)
|
||||
|
||||
cv2.imwrite(image_path, frame)
|
||||
else:
|
||||
raise ValueError("Video broken.")
|
||||
cap.release()
|
||||
return image_path
|
||||
except Exception as e:
|
||||
print(f"Save first frame of `{video_path}` failed. {e}")
|
||||
return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="path to the input csv dataset")
|
||||
parser.add_argument("--img-dir", type=str, help="path to save first frame image")
|
||||
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
|
||||
parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.disable_parallel:
|
||||
PANDA_USE_PARALLEL = False
|
||||
if PANDA_USE_PARALLEL:
|
||||
if args.num_workers is not None:
|
||||
pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
|
||||
else:
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
|
||||
if not os.path.exists(args.img_dir):
|
||||
os.makedirs(args.img_dir)
|
||||
|
||||
data = pd.read_csv(args.input)
|
||||
|
||||
data["first_frame_path"] = data["path"].parallel_apply(save_first_frame, img_dir=args.img_dir)
|
||||
data_filtered = data.loc[data["first_frame_path"] != ""]
|
||||
output_csv_path = args.input.replace(".csv", "_first-frame.csv")
|
||||
data_filtered.to_csv(output_csv_path, index=False)
|
||||
print(f"First frame csv saved to: {output_csv_path}, first frame images saved to {args.img_dir}.")
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
import argparse
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from mmengine.config import Config
|
||||
|
||||
from opensora.datasets.bucket import Bucket
|
||||
|
||||
|
||||
def split_by_bucket(
|
||||
bucket: Bucket,
|
||||
input_files: List[str],
|
||||
output_path: str,
|
||||
limit: int,
|
||||
frame_interval: int,
|
||||
):
|
||||
print(f"Split {len(input_files)} files into {len(bucket)} buckets")
|
||||
total_limit = len(bucket) * limit
|
||||
bucket_cnt = {}
|
||||
# get all bucket id
|
||||
for hw_id, d in bucket.ar_criteria.items():
|
||||
for t_id, v in d.items():
|
||||
for ar_id in v.keys():
|
||||
bucket_id = (hw_id, t_id, ar_id)
|
||||
bucket_cnt[bucket_id] = 0
|
||||
output_df = None
|
||||
# split files
|
||||
for path in input_files:
|
||||
df = pd.read_csv(path)
|
||||
if output_df is None:
|
||||
output_df = pd.DataFrame(columns=df.columns)
|
||||
for i in range(len(df)):
|
||||
row = df.iloc[i]
|
||||
t, h, w = row["num_frames"], row["height"], row["width"]
|
||||
bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
|
||||
if bucket_id is None:
|
||||
continue
|
||||
if bucket_cnt[bucket_id] < limit:
|
||||
bucket_cnt[bucket_id] += 1
|
||||
output_df = pd.concat([output_df, pd.DataFrame([row])], ignore_index=True)
|
||||
if len(output_df) >= total_limit:
|
||||
break
|
||||
if len(output_df) >= total_limit:
|
||||
break
|
||||
assert len(output_df) <= total_limit
|
||||
if len(output_df) == total_limit:
|
||||
print(f"All buckets are full ({total_limit} samples)")
|
||||
else:
|
||||
print(f"Only {len(output_df)} files are used")
|
||||
output_df.to_csv(output_path, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, nargs="+")
|
||||
parser.add_argument("-o", "--output", required=True)
|
||||
parser.add_argument("-c", "--config", required=True)
|
||||
parser.add_argument("-l", "--limit", default=200, type=int)
|
||||
args = parser.parse_args()
|
||||
assert args.limit > 0
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
bucket_config = cfg.bucket_config
|
||||
# rewrite bucket_config
|
||||
for ar, d in bucket_config.items():
|
||||
for frames, t in d.items():
|
||||
p, bs = t
|
||||
if p > 0.0:
|
||||
p = 1.0
|
||||
d[frames] = (p, bs)
|
||||
bucket = Bucket(bucket_config)
|
||||
split_by_bucket(bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval)
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
import argparse
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
import cv2
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandarallel import pandarallel
|
||||
from tqdm import tqdm
|
||||
|
||||
from .utils import IMG_EXTENSIONS, extract_frames
|
||||
|
||||
tqdm.pandas()
|
||||
USE_PANDARALLEL = True
|
||||
|
||||
|
||||
def apply(df, func, **kwargs):
|
||||
if USE_PANDARALLEL:
|
||||
return df.parallel_apply(func, **kwargs)
|
||||
return df.progress_apply(func, **kwargs)
|
||||
|
||||
|
||||
def get_new_path(path, input_dir, output):
|
||||
path_new = os.path.join(output, os.path.relpath(path, input_dir))
|
||||
os.makedirs(os.path.dirname(path_new), exist_ok=True)
|
||||
return path_new
|
||||
|
||||
|
||||
def resize_longer(path, length, input_dir, output_dir):
|
||||
path_new = get_new_path(path, input_dir, output_dir)
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
assert ext in IMG_EXTENSIONS
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
if min(h, w) > length:
|
||||
if h > w:
|
||||
new_h = length
|
||||
new_w = int(w / h * length)
|
||||
else:
|
||||
new_w = length
|
||||
new_h = int(h / w * length)
|
||||
img = cv2.resize(img, (new_w, new_h))
|
||||
cv2.imwrite(path_new, img)
|
||||
else:
|
||||
path_new = ""
|
||||
return path_new
|
||||
|
||||
|
||||
def resize_shorter(path, length, input_dir, output_dir):
|
||||
path_new = get_new_path(path, input_dir, output_dir)
|
||||
if os.path.exists(path_new):
|
||||
return path_new
|
||||
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
assert ext in IMG_EXTENSIONS
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
if min(h, w) > length:
|
||||
if h > w:
|
||||
new_w = length
|
||||
new_h = int(h / w * length)
|
||||
else:
|
||||
new_h = length
|
||||
new_w = int(w / h * length)
|
||||
img = cv2.resize(img, (new_w, new_h))
|
||||
cv2.imwrite(path_new, img)
|
||||
else:
|
||||
path_new = ""
|
||||
return path_new
|
||||
|
||||
|
||||
def rand_crop(path, input_dir, output):
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
path_new = get_new_path(path, input_dir, output)
|
||||
assert ext in IMG_EXTENSIONS
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
width, height, _ = img.shape
|
||||
pos = random.randint(0, 3)
|
||||
if pos == 0:
|
||||
img_cropped = img[: width // 2, : height // 2]
|
||||
elif pos == 1:
|
||||
img_cropped = img[width // 2 :, : height // 2]
|
||||
elif pos == 2:
|
||||
img_cropped = img[: width // 2, height // 2 :]
|
||||
else:
|
||||
img_cropped = img[width // 2 :, height // 2 :]
|
||||
cv2.imwrite(path_new, img_cropped)
|
||||
else:
|
||||
path_new = ""
|
||||
return path_new
|
||||
|
||||
|
||||
def m2ts_to_mp4(row, output_dir):
|
||||
input_path = row["path"]
|
||||
output_name = os.path.basename(input_path).replace(".m2ts", ".mp4")
|
||||
output_path = os.path.join(output_dir, output_name)
|
||||
# create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
try:
|
||||
ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run(
|
||||
capture_stdout=True
|
||||
)
|
||||
row["path"] = output_path
|
||||
row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
|
||||
except Exception as e:
|
||||
print(f"Error converting {input_path} to mp4: {e}")
|
||||
row["path"] = ""
|
||||
row["relpath"] = ""
|
||||
return row
|
||||
return row
|
||||
|
||||
|
||||
def mkv_to_mp4(row, output_dir):
|
||||
# str_to_replace and str_to_replace_with account for the different directory structure
|
||||
input_path = row["path"]
|
||||
output_name = os.path.basename(input_path).replace(".mkv", ".mp4")
|
||||
output_path = os.path.join(output_dir, output_name)
|
||||
|
||||
# create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
try:
|
||||
ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run(
|
||||
capture_stdout=True
|
||||
)
|
||||
row["path"] = output_path
|
||||
row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
|
||||
except Exception as e:
|
||||
print(f"Error converting {input_path} to mp4: {e}")
|
||||
row["path"] = ""
|
||||
row["relpath"] = ""
|
||||
return row
|
||||
return row
|
||||
|
||||
|
||||
def mp4_to_mp4(row, output_dir):
|
||||
# str_to_replace and str_to_replace_with account for the different directory structure
|
||||
input_path = row["path"]
|
||||
|
||||
# 检查输入文件是否为.mp4文件
|
||||
if not input_path.lower().endswith(".mp4"):
|
||||
print(f"Error: {input_path} is not an .mp4 file.")
|
||||
row["path"] = ""
|
||||
row["relpath"] = ""
|
||||
return row
|
||||
output_name = os.path.basename(input_path)
|
||||
output_path = os.path.join(output_dir, output_name)
|
||||
|
||||
# create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
try:
|
||||
shutil.copy2(input_path, output_path) # 使用shutil复制文件
|
||||
row["path"] = output_path
|
||||
row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
|
||||
except Exception as e:
|
||||
print(f"Error coy {input_path} to mp4: {e}")
|
||||
row["path"] = ""
|
||||
row["relpath"] = ""
|
||||
return row
|
||||
return row
|
||||
|
||||
|
||||
def crop_to_square(input_path, output_path):
|
||||
cmd = (
|
||||
f"ffmpeg -i {input_path} "
|
||||
f"-vf \"crop='min(in_w,in_h)':'min(in_w,in_h)':'(in_w-min(in_w,in_h))/2':'(in_h-min(in_w,in_h))/2'\" "
|
||||
f"-c:v libx264 -an "
|
||||
f"-map 0:v {output_path}"
|
||||
)
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
|
||||
stdout, stderr = proc.communicate()
|
||||
|
||||
|
||||
def vid_crop_center(row, input_dir, output_dir):
|
||||
input_path = row["path"]
|
||||
relpath = os.path.relpath(input_path, input_dir)
|
||||
assert not relpath.startswith("..")
|
||||
output_path = os.path.join(output_dir, relpath)
|
||||
|
||||
# create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
try:
|
||||
crop_to_square(input_path, output_path)
|
||||
size = min(row["height"], row["width"])
|
||||
row["path"] = output_path
|
||||
row["height"] = size
|
||||
row["width"] = size
|
||||
row["aspect_ratio"] = 1.0
|
||||
row["resolution"] = size**2
|
||||
except Exception as e:
|
||||
print(f"Error cropping {input_path} to center: {e}")
|
||||
row["path"] = ""
|
||||
return row
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
global USE_PANDARALLEL
|
||||
|
||||
assert args.num_workers is None or not args.disable_parallel
|
||||
if args.disable_parallel:
|
||||
USE_PANDARALLEL = False
|
||||
if args.num_workers is not None:
|
||||
pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers)
|
||||
else:
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
|
||||
random.seed(args.seed)
|
||||
data = pd.read_csv(args.meta_path)
|
||||
if args.task == "img_rand_crop":
|
||||
data["path"] = apply(data["path"], lambda x: rand_crop(x, args.input_dir, args.output_dir))
|
||||
output_csv = args.meta_path.replace(".csv", "_rand_crop.csv")
|
||||
elif args.task == "img_resize_longer":
|
||||
data["path"] = apply(data["path"], lambda x: resize_longer(x, args.length, args.input_dir, args.output_dir))
|
||||
output_csv = args.meta_path.replace(".csv", f"_resize-longer-{args.length}.csv")
|
||||
elif args.task == "img_resize_shorter":
|
||||
data["path"] = apply(data["path"], lambda x: resize_shorter(x, args.length, args.input_dir, args.output_dir))
|
||||
output_csv = args.meta_path.replace(".csv", f"_resize-shorter-{args.length}.csv")
|
||||
elif args.task == "vid_frame_extract":
|
||||
points = args.points if args.points is not None else args.points_index
|
||||
data = pd.DataFrame(np.repeat(data.values, 3, axis=0), columns=data.columns)
|
||||
num_points = len(points)
|
||||
data["point"] = np.nan
|
||||
for i, point in enumerate(points):
|
||||
if isinstance(point, int):
|
||||
data.loc[i::num_points, "point"] = point
|
||||
else:
|
||||
data.loc[i::num_points, "point"] = data.loc[i::num_points, "num_frames"] * point
|
||||
data["path"] = apply(
|
||||
data, lambda x: extract_frames(x["path"], args.input_dir, args.output_dir, x["point"]), axis=1
|
||||
)
|
||||
output_csv = args.meta_path.replace(".csv", "_vid_frame_extract.csv")
|
||||
elif args.task == "m2ts_to_mp4":
|
||||
print(f"m2ts_to_mp4作业开始:{args.output_dir}")
|
||||
assert args.meta_path.endswith("_m2ts.csv"), "Input file must end with '_m2ts.csv'"
|
||||
m2ts_to_mp4_partial = lambda x: m2ts_to_mp4(x, args.output_dir)
|
||||
data = apply(data, m2ts_to_mp4_partial, axis=1)
|
||||
data = data[data["path"] != ""]
|
||||
output_csv = args.meta_path.replace("_m2ts.csv", ".csv")
|
||||
elif args.task == "mkv_to_mp4":
|
||||
print(f"mkv_to_mp4作业开始:{args.output_dir}")
|
||||
assert args.meta_path.endswith("_mkv.csv"), "Input file must end with '_mkv.csv'"
|
||||
mkv_to_mp4_partial = lambda x: mkv_to_mp4(x, args.output_dir)
|
||||
data = apply(data, mkv_to_mp4_partial, axis=1)
|
||||
data = data[data["path"] != ""]
|
||||
output_csv = args.meta_path.replace("_mkv.csv", ".csv")
|
||||
elif args.task == "mp4_to_mp4":
|
||||
# assert args.meta_path.endswith("meta.csv"), "Input file must end with '_mkv.csv'"
|
||||
print(f"MP4复制作业开始:{args.output_dir}")
|
||||
mkv_to_mp4_partial = lambda x: mp4_to_mp4(x, args.output_dir)
|
||||
data = apply(data, mkv_to_mp4_partial, axis=1)
|
||||
data = data[data["path"] != ""]
|
||||
output_csv = args.meta_path
|
||||
elif args.task == "vid_crop_center":
|
||||
vid_crop_center_partial = lambda x: vid_crop_center(x, args.input_dir, args.output_dir)
|
||||
data = apply(data, vid_crop_center_partial, axis=1)
|
||||
data = data[data["path"] != ""]
|
||||
output_csv = args.meta_path.replace(".csv", "_center-crop.csv")
|
||||
else:
|
||||
raise ValueError
|
||||
data.to_csv(output_csv, index=False)
|
||||
print(f"Saved to {output_csv}")
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"img_resize_longer",
|
||||
"img_resize_shorter",
|
||||
"img_rand_crop",
|
||||
"vid_frame_extract",
|
||||
"m2ts_to_mp4",
|
||||
"mkv_to_mp4",
|
||||
"mp4_to_mp4",
|
||||
"vid_crop_center",
|
||||
],
|
||||
)
|
||||
parser.add_argument("--meta_path", type=str, required=True)
|
||||
parser.add_argument("--input_dir", type=str)
|
||||
parser.add_argument("--output_dir", type=str)
|
||||
parser.add_argument("--length", type=int, default=1080)
|
||||
parser.add_argument("--disable-parallel", action="store_true")
|
||||
parser.add_argument("--num_workers", type=int, default=None)
|
||||
parser.add_argument("--seed", type=int, default=42, help="seed for random")
|
||||
parser.add_argument("--points", nargs="+", type=float, default=None)
|
||||
parser.add_argument("--points_index", nargs="+", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
||||
def is_video(filename):
|
||||
ext = os.path.splitext(filename)[-1].lower()
|
||||
return ext in VID_EXTENSIONS
|
||||
|
||||
|
||||
def extract_frames(
|
||||
video_path,
|
||||
frame_inds=None,
|
||||
points=None,
|
||||
backend="opencv",
|
||||
return_length=False,
|
||||
num_frames=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
video_path (str): path to video
|
||||
frame_inds (List[int]): indices of frames to extract
|
||||
points (List[float]): values within [0, 1); multiply #frames to get frame indices
|
||||
Return:
|
||||
List[PIL.Image]
|
||||
"""
|
||||
assert backend in ["av", "opencv", "decord"]
|
||||
assert (frame_inds is None) or (points is None)
|
||||
|
||||
if backend == "av":
|
||||
import av
|
||||
|
||||
container = av.open(video_path)
|
||||
if num_frames is not None:
|
||||
total_frames = num_frames
|
||||
else:
|
||||
total_frames = container.streams.video[0].frames
|
||||
|
||||
if points is not None:
|
||||
frame_inds = [int(p * total_frames) for p in points]
|
||||
|
||||
frames = []
|
||||
for idx in frame_inds:
|
||||
if idx >= total_frames:
|
||||
idx = total_frames - 1
|
||||
target_timestamp = int(idx * av.time_base / container.streams.video[0].average_rate)
|
||||
container.seek(target_timestamp) # return the nearest key frame, not the precise timestamp!!!
|
||||
frame = next(container.decode(video=0)).to_image()
|
||||
frames.append(frame)
|
||||
|
||||
if return_length:
|
||||
return frames, total_frames
|
||||
return frames
|
||||
|
||||
elif backend == "decord":
|
||||
import decord
|
||||
|
||||
container = decord.VideoReader(video_path, num_threads=1)
|
||||
if num_frames is not None:
|
||||
total_frames = num_frames
|
||||
else:
|
||||
total_frames = len(container)
|
||||
|
||||
if points is not None:
|
||||
frame_inds = [int(p * total_frames) for p in points]
|
||||
|
||||
frame_inds = np.array(frame_inds).astype(np.int32)
|
||||
frame_inds[frame_inds >= total_frames] = total_frames - 1
|
||||
frames = container.get_batch(frame_inds).asnumpy() # [N, H, W, C]
|
||||
frames = [Image.fromarray(x) for x in frames]
|
||||
|
||||
if return_length:
|
||||
return frames, total_frames
|
||||
return frames
|
||||
|
||||
elif backend == "opencv":
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if num_frames is not None:
|
||||
total_frames = num_frames
|
||||
else:
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
if points is not None:
|
||||
frame_inds = [int(p * total_frames) for p in points]
|
||||
|
||||
frames = []
|
||||
for idx in frame_inds:
|
||||
if idx >= total_frames:
|
||||
idx = total_frames - 1
|
||||
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||||
|
||||
# HACK: sometimes OpenCV fails to read frames, return a black frame instead
|
||||
try:
|
||||
ret, frame = cap.read()
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame = Image.fromarray(frame)
|
||||
except Exception as e:
|
||||
print(f"[Warning] Error reading frame {idx} from {video_path}: {e}")
|
||||
# First, try to read the first frame
|
||||
try:
|
||||
print(f"[Warning] Try reading first frame.")
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
ret, frame = cap.read()
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame = Image.fromarray(frame)
|
||||
# If that fails, return a black frame
|
||||
except Exception as e:
|
||||
print(f"[Warning] Error in reading first frame from {video_path}: {e}")
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame = Image.new("RGB", (width, height), (0, 0, 0))
|
||||
|
||||
# HACK: if height or width is 0, return a black frame instead
|
||||
if frame.height == 0 or frame.width == 0:
|
||||
height = width = 256
|
||||
frame = Image.new("RGB", (width, height), (0, 0, 0))
|
||||
|
||||
frames.append(frame)
|
||||
|
||||
if return_length:
|
||||
return frames, total_frames
|
||||
return frames
|
||||
else:
|
||||
raise ValueError
|
||||
Loading…
Reference in New Issue