def get_ratio(name: str) -> float:
    """Return height divided by width for a ``"width:height"`` ratio name.

    Args:
        name: ratio written as two ``:``-separated numbers, width first.

    Returns:
        ``height / width`` as a float.
    """
    parts = [float(token) for token in name.split(":")]
    w, h = parts
    return h / w
def get_num_pexels_from_name(resolution: str) -> int:
    """Translate a resolution name into a pixel budget.

    Only the part before the first ``_`` is inspected, so ``"256px_ar1:1"``
    is handled like ``"256px"``.

    Args:
        resolution: ``"<N>px"`` (N * N pixels) or ``"<N>p"``
            (N * N * 16 / 9 pixels, truncated to an int).

    Returns:
        The number of pixels for that resolution.

    Raises:
        ValueError: if the name ends with neither ``px`` nor ``p``, or the
            numeric part cannot be parsed by ``int``.
    """
    base = resolution.split("_")[0]

    # "px" must be tested first: every "...px" name also ends with... "x", not
    # "p", but keeping the order mirrors the two supported suffixes explicitly.
    if base.endswith("px"):
        side = int(base[:-2])
        return side * side

    if base.endswith("p"):
        side = int(base[:-1])
        return int(side * side / 9 * 16)

    raise ValueError(f"Invalid resolution {base}")
def bucket_to_shapes(bucket_config, batch_size=None):
    """List every ``(batch, 3, num_frames, height, width)`` shape a bucket config implies.

    Args:
        bucket_config: mapping ``resolution name -> {num_frames: (x, bs)}``;
            only the second tuple element (the batch size) is used here.
        batch_size: when not None, replaces every per-bucket batch size.

    Returns:
        A list with one 5-tuple per (resolution, num_frames, aspect ratio)
        combination, in config order.
    """
    shapes = []
    for resolution, infos in bucket_config.items():
        if not infos:
            # Nothing to enumerate; also keeps unused resolution names unparsed.
            continue

        # The aspect-ratio table depends on the resolution only, so build it
        # once instead of once per num_frames entry.
        aspect_ratios = get_aspect_ratios_dict(get_num_pexels_from_name(resolution))

        for num_frames, (_, bucket_bs) in infos.items():
            bs = bucket_bs if batch_size is None else batch_size
            shapes.extend(
                (bs, 3, num_frames, height, width)
                for height, width in aspect_ratios.values()
            )
    return shapes
index 0000000..ce38bf6 --- /dev/null +++ b/opensora/datasets/bucket.py @@ -0,0 +1,139 @@ +from collections import OrderedDict + +import numpy as np + +from opensora.utils.logger import log_message + +from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio +from .utils import map_target_fps + + +class Bucket: + def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]): + """ + Args: + bucket_config (dict): A dictionary containing the bucket configuration. + The dictionary should be in the following format: + { + "bucket_name": { + "time": (probability, batch_size), + "time": (probability, batch_size), + ... + }, + ... + } + + Or in the following format: + { + "bucket_name": { + "time": ((probability, next_probability), batch_size), + "time": ((probability, next_probability), batch_size), + ... + }, + ... + } + The bucket_name should be the name of the bucket, and the time should be the number of frames in the video. + The probability should be a float between 0 and 1, and the batch_size should be an integer. + If the probability is a tuple, the second value should be the probability to skip to the next time. 
+ """ + + aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()} + bucket_probs = OrderedDict() + bucket_bs = OrderedDict() + bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True) + + for key in bucket_names: + bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True) + bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names}) + bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names}) + + self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names} + self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names} + self.ar_criteria = { + k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()} + for k1 in bucket_names + } + + bucket_id_cnt = num_bucket = 0 + bucket_id = dict() + for k1, v1 in bucket_probs.items(): + bucket_id[k1] = dict() + for k2, _ in v1.items(): + bucket_id[k1][k2] = bucket_id_cnt + bucket_id_cnt += 1 + num_bucket += len(aspect_ratios[k1][1]) + + self.bucket_probs = bucket_probs + self.bucket_bs = bucket_bs + self.bucket_id = bucket_id + self.num_bucket = num_bucket + + log_message("Number of buckets: %s", num_bucket) + + def get_bucket_id( + self, + T: int, + H: int, + W: int, + fps: float, + path: str | None = None, + seed: int | None = None, + fps_max: int = 16, + ) -> tuple[str, int, int] | None: + approx = 0.8 + _, sampling_interval = map_target_fps(fps, fps_max) + T = T // sampling_interval + resolution = H * W + rng = np.random.default_rng(seed) + + # Reference to probabilities and criteria for faster access + bucket_probs = self.bucket_probs + hw_criteria = self.hw_criteria + ar_criteria = self.ar_criteria + + # Start searching for the appropriate bucket + for hw_id, t_criteria in bucket_probs.items(): + # if resolution is too low, skip + if resolution < hw_criteria[hw_id] * approx: + continue + + # if 
sample is an image + if T == 1: + if 1 in t_criteria: + if rng.random() < t_criteria[1]: + return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1]) + continue + + # Look for suitable t_id for video + for t_id, prob in t_criteria.items(): + if T >= t_id and t_id != 1: + # if prob is a tuple, use the second value as the threshold to skip + # to the next t_id + if isinstance(prob, tuple): + next_hw_prob, next_t_prob = prob + if next_t_prob >= 1 or rng.random() <= next_t_prob: + continue + else: + next_hw_prob = prob + if next_hw_prob >= 1 or rng.random() <= next_hw_prob: + ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id]) + return hw_id, t_id, ar_id + else: + break + + return None + + def get_thw(self, bucket_idx: tuple[str, int, int]) -> tuple[int, int, int]: + assert len(bucket_idx) == 3 + T = self.t_criteria[bucket_idx[0]][bucket_idx[1]] + H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]] + return T, H, W + + def get_prob(self, bucket_idx: tuple[str, int]) -> float: + return self.bucket_probs[bucket_idx[0]][bucket_idx[1]] + + def get_batch_size(self, bucket_idx: tuple[str, int]) -> int: + return self.bucket_bs[bucket_idx[0]][bucket_idx[1]] + + def __len__(self) -> int: + return self.num_bucket diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py new file mode 100644 index 0000000..90fabe4 --- /dev/null +++ b/opensora/datasets/dataloader.py @@ -0,0 +1,402 @@ +import collections +import functools +import os +import queue +import random +import threading + +import numpy as np +import torch +import torch.multiprocessing as multiprocessing +from torch._utils import ExceptionWrapper +from torch.distributed import ProcessGroup +from torch.utils.data import DataLoader, _utils +from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL +from torch.utils.data.dataloader import ( + IterDataPipe, + MapDataPipe, + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _sharding_worker_init_fn, + 
def _pin_memory_loop(
    in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
):
    """Thread body that moves one key of each batch into a cached buffer.

    Repeatedly takes ``(idx, data)`` items from ``in_queue``; when ``data`` is a
    dict containing ``pin_memory_key``, that value is copied into the buffer
    returned by ``pin_memory_cache.get(val)`` and swapped into the dict. The
    item is then forwarded to ``out_queue``. The loop ends once ``done_event``
    is set.

    Args:
        in_queue: source of ``(idx, data)`` items.
        out_queue: destination for the processed items.
        device_id: device index activated for this thread.
        done_event: event signalling that the loop must stop.
        device: device type name (``"cuda"``, ``"xpu"`` or the private-use backend).
        pin_memory_cache: provider of destination buffers (presumably pinned
            host memory; PinMemoryCache is defined elsewhere, TODO confirm).
        pin_memory_key: dict key whose value is copied into the cached buffer.
    """
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        # Poll with a timeout so that done_event is re-checked regularly.
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        # Errors raised upstream (ExceptionWrapper) are forwarded untouched.
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                assert isinstance(data, dict)
                if pin_memory_key in data:
                    val = data[pin_memory_key]
                    pin_memory_value = pin_memory_cache.get(val)
                    pin_memory_value.copy_(val)
                    data[pin_memory_key] = pin_memory_value
            except Exception:
                # Ship the failure to the consumer instead of killing the thread.
                data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
            r = (idx, data)
        # Retry the put until it succeeds or shutdown is requested.
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()
+ index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + if self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() # type: ignore[attr-defined] + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + current_device = custom_device_mod.current_device() + else: + current_device = torch.cuda.current_device() # choose cuda for default + pin_memory_thread = threading.Thread( + target=_pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + self.pin_memory_cache, + self.pin_memory_key, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. 
def get_seed_worker(seed):
    """Build a ``worker_init_fn`` that seeds NumPy, PyTorch and ``random``.

    Every worker is seeded with the same ``seed``; ``worker_id`` is accepted
    but not mixed in. When ``seed`` is None the returned function is a no-op.

    Args:
        seed: value handed to each library's seeding function, or None.

    Returns:
        A function taking ``worker_id`` and returning None.
    """

    def seed_worker(worker_id):
        if seed is None:
            return
        # Same order as before: NumPy, then PyTorch, then the stdlib generator.
        for reseed in (np.random.seed, torch.manual_seed, random.seed):
            reseed(seed)

    return seed_worker
def collate_fn_default(batch):
    """Collate samples, skipping ``None`` entries.

    When the first remaining sample carries an integer ``"mask"``, the
    ``"mask"`` and ``"text"`` entries are removed from every sample before the
    default collation: masks are returned as a plain list and texts are
    concatenated along dim 1. All other keys go through
    ``torch.utils.data.default_collate``.

    Args:
        batch: list of sample dicts; ``None`` entries are dropped.

    Returns:
        The collated batch.

    Raises:
        AssertionError: if no sample is left after dropping ``None``.
    """
    samples = list(filter(lambda item: item is not None, batch))
    assert len(samples) > 0, "batch is empty"

    head = samples[0]
    has_int_mask = "mask" in head and isinstance(head["mask"], int)
    if not has_int_mask:
        return torch.utils.data.default_collate(samples)

    # HACK: for loading text features
    masks = [item.pop("mask") for item in samples]
    texts = torch.cat([item.pop("text") for item in samples], dim=1)

    collated = torch.utils.data.default_collate(samples)
    collated["mask"] = masks
    collated["text"] = texts
    return collated
class Iloc:
    """Positional indexer that returns an ``Item`` view for each requested index."""

    def __init__(self, data, sharded_folder, sharded_folders, rows_per_shard):
        self.data, self.sharded_folder = data, sharded_folder
        self.sharded_folders, self.rows_per_shard = sharded_folders, rows_per_shard

    def __getitem__(self, index):
        # No row is read here; the Item receives everything it needs to look
        # values up later.
        return Item(
            index=index,
            data=self.data,
            sharded_folder=self.sharded_folder,
            sharded_folders=self.sharded_folders,
            rows_per_shard=self.rows_per_shard,
        )
class EfficientParquet:
    """Pairs an in-memory frame with a folder of shard files, indexed via ``iloc``."""

    def __init__(self, df, sharded_folder):
        row_count = len(df)
        # Ceiling division: rows are spread over at most K shards.
        full, remainder = divmod(row_count, K)

        self.data = df
        self.total_rows = row_count
        self.rows_per_shard = full + (1 if remainder else 0)
        self.sharded_folder = sharded_folder
        assert os.path.exists(sharded_folder), f"Sharded folder {sharded_folder} does not exist."
        # Sorted so that a shard's position in the list is stable across runs.
        self.sharded_folders = sorted(os.listdir(sharded_folder))

    def __len__(self):
        return self.total_rows

    @property
    def iloc(self):
        """Return a fresh positional indexer over this data set."""
        return Iloc(
            self.data,
            self.sharded_folder,
            self.sharded_folders,
            self.rows_per_shard,
        )
sample_fps = sample.get("fps", np.nan) + new_fps, sampling_interval = map_target_fps(sample_fps, self.fps_max) + ret.update({"sampling_interval": sampling_interval}) + + if "text" in sample: + ret["text"] = sample.pop("text") + postfixs = [] + if new_fps != 0 and self.fps_max < 999: + postfixs.append(f"{new_fps} FPS") + if self.vmaf and "score_vmafmotion" in sample and not np.isnan(sample["score_vmafmotion"]): + postfixs.append(f"{int(sample['score_vmafmotion'] + 0.5)} motion score") + postfix = " " + ", ".join(postfixs) + "." if postfixs else "" + ret["text"] = ret["text"] + postfix + if self.tokenize_fn is not None: + ret.update({k: v.squeeze(0) for k, v in self.tokenize_fn(ret["text"]).items()}) + + if "ref" in sample: # i2v & v2v reference + ret["ref"] = sample.pop("ref") + + # name of the generated sample + if "name" in sample: # sample name (`dataset_idx`) + ret["name"] = sample.pop("name") + else: + ret["index"] = index # use index for name + valid_sample = {k: v for k, v in sample.items() if k in VALID_KEYS} + ret.update(valid_sample) + return ret + + def __getitem__(self, index): + return self.getitem(index) + + def __len__(self): + return len(self.data) + + +@DATASETS.register_module("video_text") +class VideoTextDataset(TextDataset): + def __init__( + self, + transform_name: str = None, + bucket_class: str = "Bucket", + rand_sample_interval: int = None, # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)] + **kwargs, + ): + super().__init__(**kwargs) + self.transform_name = transform_name + self.bucket_class = bucket_class + self.rand_sample_interval = rand_sample_interval + + def get_image(self, index: int, height: int, width: int) -> dict: + sample = self.data.iloc[index] + path = sample["path"] + # loading + image = pil_loader(path) + + # transform + transform = get_transforms_image(self.transform_name, (height, width)) + image = transform(image) + + # CHW -> CTHW + video = image.unsqueeze(1) + + return {"video": 
@DATASETS.register_module("cached_video_text")
class CachedVideoTextDataset(VideoTextDataset):
    """Video/text dataset that can also return pre-computed latents from disk.

    Args:
        transform_name: stored on the instance as ``transform_name``.
        bucket_class: stored on the instance as ``bucket_class``.
        rand_sample_interval: stored on the instance as ``rand_sample_interval``.
        cached_video: load ``latents_path`` into ``"video_latents"``.
        cached_text: load ``text_t5_path`` / ``text_clip_path`` into
            ``"text_t5"`` / ``"text_clip"``.
        return_latents_path: also return the three paths themselves.
        load_original_video: additionally load the sample through the parent
            class.
        **kwargs: forwarded to the parent constructor.
    """

    def __init__(
        self,
        transform_name: str = None,
        bucket_class: str = "Bucket",
        rand_sample_interval: int = None,  # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
        cached_video: bool = False,
        cached_text: bool = False,
        return_latents_path: bool = False,
        load_original_video: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The parent is initialised with its defaults; the values are set here.
        self.transform_name = transform_name
        self.bucket_class = bucket_class
        self.rand_sample_interval = rand_sample_interval
        self.cached_video = cached_video
        self.cached_text = cached_text
        self.return_latents_path = return_latents_path
        self.load_original_video = load_original_video

    def get_latents(self, path):
        """Load a tensor file onto the CPU; return a ``(1, 1, 1, 1)`` zero tensor on failure."""
        try:
            # NOTE(security): torch.load unpickles the file; only use with
            # trusted latent files (consider weights_only=True where supported).
            latents = torch.load(path, map_location=torch.device("cpu"))
        except Exception as e:
            print(f"Error loading latents from {path}: {e}")
            # Plain zeros: no need to draw random numbers (and disturb the
            # global RNG state) just to get a placeholder.
            return torch.zeros(1, 1, 1, 1)
        return latents

    def get_conditioning_latents(self, index: int) -> dict:
        """Collect the cached tensors and/or their paths for row ``index``."""
        sample = self.data.iloc[index]
        latents_path = sample["latents_path"]
        text_t5_path = sample["text_t5_path"]
        text_clip_path = sample["text_clip_path"]
        res = dict()
        if self.cached_video:
            res["video_latents"] = self.get_latents(latents_path)
        if self.cached_text:
            res["text_t5"] = self.get_latents(text_t5_path)
            res["text_clip"] = self.get_latents(text_clip_path)
        if self.return_latents_path:
            res["latents_path"] = latents_path
            res["text_t5_path"] = text_t5_path
            res["text_clip_path"] = text_clip_path
        return res

    def getitem(self, index: str) -> dict | None:
        """Load one sample described by an ``"index-frames-height-width"`` string.

        Returns:
            The sample dict, or None when the original video or the cached
            latents could not be loaded.
        """
        # a hack to pass in the (time, height, width) info from sampler
        real_index, _num_frames, _height, _width = [int(val) for val in index.split("-")]
        ret = dict()
        if self.load_original_video:
            original = super().getitem(index)
            # The parent reports a failed load by returning None; updating a
            # dict with None would raise TypeError instead of skipping it.
            if original is None:
                return None
            ret.update(original)
        try:
            ret.update(self.get_conditioning_latents(real_index))
        except Exception as e:
            path = self.data.iloc[real_index]["path"]
            print(f"video {path}: {e}")
            return None
        return ret

    def __getitem__(self, index):
        return self.getitem(index)
class WrapWorkFunctionForPipe:
    """Callable that runs ``work_function`` on one chunk taken from ``TMP``.

    The chunk is selected by ``worker_index``; success or failure is reported
    on ``master_workers_queue``.
    """

    def __init__(
        self,
        work_function: Callable[
            [Any, Callable, tuple, Dict[str, Any], Dict[str, Any]],
            Any,
        ],
    ) -> None:
        self.work_function = work_function

    def __call__(
        self,
        progress_bars_type: ProgressBarsType,
        worker_index: int,
        master_workers_queue: multiprocessing.Queue,
        dilled_user_defined_function: bytes,
        user_defined_function_args: tuple,
        user_defined_function_kwargs: Dict[str, Any],
        extra: Dict[str, Any],
    ) -> Any:
        in_function_modes = (
            ProgressBarsType.InUserDefinedFunction,
            ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns,
        )
        try:
            chunk = TMP[worker_index]
            chunk_size = len(chunk)
            plain_function: Callable = dill.loads(dilled_user_defined_function)
            reporting_function = progress_wrapper(
                plain_function, master_workers_queue, worker_index, chunk_size
            )

            # Progress is reported from inside the user function only for the
            # two "in user defined function" modes.
            if progress_bars_type in in_function_modes:
                chosen_function = reporting_function
            else:
                chosen_function = plain_function

            results = self.work_function(
                chunk,
                chosen_function,
                user_defined_function_args,
                user_defined_function_kwargs,
                extra,
            )
            master_workers_queue.put((worker_index, WorkerStatus.Success, None))
            return results
        except BaseException:
            # Report every failure (including interrupts), then let it propagate.
            master_workers_queue.put((worker_index, WorkerStatus.Error, None))
            raise
Type[DataType], + progress_bars_type: ProgressBarsType, +): + def closure( + data: Any, + user_defined_function: Callable, + *user_defined_function_args: tuple, + **user_defined_function_kwargs: Dict[str, Any], + ): + wrapped_work_function = WrapWorkFunctionForPipe(data_type.work) + dilled_user_defined_function = dill.dumps(user_defined_function) + manager: SyncManager = CONTEXT.Manager() + master_workers_queue = manager.Queue() + + chunks = list( + data_type.get_chunks( + nb_requested_workers, + data, + user_defined_function_kwargs=user_defined_function_kwargs, + ) + ) + TMP.extend(chunks) + + nb_workers = len(chunks) + + multiplicator_factor = ( + len(cast(pd.DataFrame, data).columns) + if progress_bars_type == ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns + else 1 + ) + + progresses_length = [len(chunk_) * multiplicator_factor for chunk_ in chunks] + + work_extra = data_type.get_work_extra(data) + reduce_extra = data_type.get_reduce_extra(data, user_defined_function_kwargs) + + show_progress_bars = progress_bars_type != ProgressBarsType.No + + progress_bars = get_progress_bars(progresses_length, show_progress_bars) + progresses = [0] * nb_workers + workers_status = [WorkerStatus.Running] * nb_workers + + work_args_list = [ + ( + progress_bars_type, + worker_index, + master_workers_queue, + dilled_user_defined_function, + user_defined_function_args, + user_defined_function_kwargs, + { + **work_extra, + **{ + "master_workers_queue": master_workers_queue, + "show_progress_bars": show_progress_bars, + "worker_index": worker_index, + }, + }, + ) + for worker_index in range(nb_workers) + ] + + pool = CONTEXT.Pool(nb_workers) + results_promise = pool.starmap_async(wrapped_work_function, work_args_list) + pool.close() + + generation = count() + + while any((worker_status == WorkerStatus.Running for worker_status in workers_status)): + message: Tuple[int, WorkerStatus, Any] = master_workers_queue.get() + worker_index, worker_status, payload = message + 
class PinMemoryCache:
    """Pool of reusable pinned CPU buffers.

    ``get`` hands out a view (shaped like the request) into a flat pinned
    buffer and ``remove`` marks that buffer free again. Buffers and the views
    handed out are tracked by ``id()``, so ``remove`` must be called with the
    very object that ``get`` returned.

    Class-level knobs (shared by all instances):
        force_dtype: dtype used for every buffer; ``None`` means "use the
            dtype of the requested tensor".
        min_cache_numel: lower bound for the size of newly created buffers.
        pre_alloc_numels: buffer sizes created eagerly in ``__init__`` (only
            honoured when ``force_dtype`` is set).
    """

    force_dtype: Optional[torch.dtype] = None
    min_cache_numel: int = 0
    pre_alloc_numels: List[int] = []

    def __init__(self):
        self.cache: Dict[int, torch.Tensor] = {}
        self.output_to_cache: Dict[int, int] = {}
        self.cache_to_output: Dict[int, int] = {}
        self.lock = threading.Lock()
        self.total_cnt = 0
        self.hit_cnt = 0

        if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None:
            for n in self.pre_alloc_numels:
                cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
                with self.lock:
                    self.cache[id(cache_tensor)] = cache_tensor

    def get(self, tensor: torch.Tensor) -> torch.Tensor:
        """Receive a cpu tensor and return the corresponding pinned tensor.

        Note that this only manages memory allocation, it doesn't copy content.

        Args:
            tensor (torch.Tensor): The tensor to be pinned.

        Returns:
            torch.Tensor: A pinned tensor with the same shape as ``tensor``.
        """
        wanted_numel = tensor.numel()
        wanted_dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype
        with self.lock:
            # Counters are updated under the lock so concurrent callers
            # cannot lose increments.
            self.total_cnt += 1
            # find free cache
            for cache_id, cache_tensor in self.cache.items():
                if cache_id in self.cache_to_output:
                    continue
                # A buffer of another dtype must not be reused: the returned
                # view would silently carry the wrong dtype.
                if cache_tensor.dtype != wanted_dtype or cache_tensor.numel() < wanted_numel:
                    continue
                target_cache_tensor = cache_tensor[:wanted_numel].view(tensor.shape)
                out_id = id(target_cache_tensor)
                self.output_to_cache[out_id] = cache_id
                self.cache_to_output[cache_id] = out_id
                self.hit_cnt += 1
                return target_cache_tensor
        # no free cache, create a new one (allocation happens outside the lock)
        cache_numel = max(wanted_numel, self.min_cache_numel)
        cache_tensor = torch.empty(cache_numel, dtype=wanted_dtype, device="cpu", pin_memory=True)
        target_cache_tensor = cache_tensor[:wanted_numel].view(tensor.shape)
        out_id = id(target_cache_tensor)
        with self.lock:
            self.cache[id(cache_tensor)] = cache_tensor
            self.output_to_cache[out_id] = id(cache_tensor)
            self.cache_to_output[id(cache_tensor)] = out_id
        return target_cache_tensor

    def remove(self, output_tensor: torch.Tensor) -> None:
        """Release corresponding cache tensor.

        Args:
            output_tensor (torch.Tensor): The tensor to be released; must be
                the object previously returned by ``get``.

        Raises:
            ValueError: if ``output_tensor`` was not handed out by this cache.
        """
        out_id = id(output_tensor)
        with self.lock:
            if out_id not in self.output_to_cache:
                raise ValueError("Tensor not found in cache.")
            cache_id = self.output_to_cache.pop(out_id)
            del self.cache_to_output[cache_id]

    def __str__(self):
        with self.lock:
            num_cached = len(self.cache)
            num_used = len(self.output_to_cache)
            total_cache_size = sum(v.numel() * v.element_size() for v in self.cache.values())
            # Before the first get() there is nothing to divide by.
            hit_rate = self.hit_cnt / self.total_cnt if self.total_cnt else 0.0
        return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={hit_rate:.2f})"
try our best to avoid memory leak + + Args: + filename (str): path to the video file + start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The start presentation time of the video + end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The end presentation time + pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, + either 'pts' or 'sec'. Defaults to 'pts'. + output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW". + + Returns: + vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames + aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points + info (dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) + """ + # format + output_format = output_format.upper() + if output_format not in ("THWC", "TCHW"): + raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") + # file existence + if not os.path.exists(filename): + raise RuntimeError(f"File not found: {filename}") + # backend check + assert get_video_backend() == "pyav", "pyav backend is required for read_video_av" + _check_av_available() + # end_pts check + if end_pts is None: + end_pts = float("inf") + if end_pts < start_pts: + raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") + + # == get video info == + info = {} + # TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + # fps + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) + iter_video = container.decode(**{"video": 0}) + frame = next(iter_video).to_rgb().to_ndarray() + height, width = frame.shape[:2] 
+ total_frames = container.streams.video[0].frames + if total_frames == 0: + total_frames = MAX_NUM_FRAMES + warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback") + container.close() + del container + + # HACK: must create before iterating stream + # use np.zeros will not actually allocate memory + # use np.ones will lead to a little memory leak + video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8) + + # == read == + try: + # TODO: The reading has memory leak (4G for 8 workers 1 GPU) + container = av.open(filename, metadata_errors="ignore") + assert container.streams.video is not None + video_frames = _read_from_stream( + video_frames, + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + filename=filename, + ) + except av.AVError as e: + print(f"[Warning] Error while reading video {filename}: {e}") + + vframes = torch.from_numpy(video_frames).clone() + del video_frames + if output_format == "TCHW": + # [T,H,W,C] --> [T,C,H,W] + vframes = vframes.permute(0, 3, 1, 2) + + aframes = torch.empty((1, 0), dtype=torch.float32) + return vframes, aframes, info + + +def _read_from_stream( + video_frames, + container: "av.container.Container", + start_offset: float, + end_offset: float, + pts_unit: str, + stream: "av.stream.Stream", + stream_name: dict[str, int | tuple[int, ...] | list[int] | None], + filename: str | None = None, +) -> list["av.frame.Frame"]: + if pts_unit == "sec": + # TODO: we should change all of this from ground up to simply take + # sec and convert to MS in C++ + start_offset = int(math.floor(start_offset * (1 / stream.time_base))) + if end_offset != float("inf"): + end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) + else: + warnings.warn("The pts_unit 'pts' gives wrong results. 
Please use pts_unit 'sec'.") + + should_buffer = True + max_buffer_size = 5 + if stream.type == "video": + # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt) + # so need to buffer some extra frames to sort everything + # properly + extradata = stream.codec_context.extradata + # overly complicated way of finding if `divx_packed` is set, following + # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263 + if extradata and b"DivX" in extradata: + # can't use regex directly because of some weird characters sometimes... + pos = extradata.find(b"DivX") + d = extradata[pos:] + o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d) + if o is None: + o = re.search(rb"DivX(\d+)b(\d+)(\w)", d) + if o is not None: + should_buffer = o.group(3) == b"p" + seek_offset = start_offset + # some files don't seek to the right location, so better be safe here + seek_offset = max(seek_offset - 1, 0) + if should_buffer: + # FIXME this is kind of a hack, but we will jump to the previous keyframe + # so this will be safe + seek_offset = max(seek_offset - max_buffer_size, 0) + try: + # TODO check if stream needs to always be the video stream here or not + container.seek(seek_offset, any_frame=False, backward=True, stream=stream) + except av.AVError as e: + print(f"[Warning] Error while seeking video {filename}: {e}") + return [] + + # == main == + buffer_count = 0 + frames_pts = [] + cnt = 0 + try: + for _idx, frame in enumerate(container.decode(**stream_name)): + frames_pts.append(frame.pts) + video_frames[cnt] = frame.to_rgb().to_ndarray() + cnt += 1 + if cnt >= len(video_frames): + break + if frame.pts >= end_offset: + if should_buffer and buffer_count < max_buffer_size: + buffer_count += 1 + continue + break + except av.AVError as e: + print(f"[Warning] Error while reading video {filename}: {e}") + + # garbage collection for thread leakage + container.close() + del container + # NOTE: manually garbage collect to close pyav threads + 
def read_video_cv2(video_path):
    """Decode every frame of ``video_path`` with OpenCV.

    Args:
        video_path: path given to ``cv2.VideoCapture``.

    Returns:
        ``(frames, vinfo)`` where ``frames`` is a tensor laid out as
        ``[T, C=3, H, W]`` in RGB order and ``vinfo`` is
        ``{"video_fps": <CAP_PROP_FPS>}``.

    Raises:
        ValueError: if the file cannot be opened or yields no frame.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Unable to open video: {video_path}")

    frames = []
    try:
        vinfo = {
            "video_fps": cap.get(cv2.CAP_PROP_FPS),
        }
        while True:
            # Read a frame from the video
            ret, frame = cap.read()

            # If frame is not read correctly, stop decoding
            if not ret:
                break

            frames.append(frame[:, :, ::-1])  # BGR to RGB
    finally:
        # Always release the capture, even if decoding raised.
        # No waitKey()/destroyAllWindows() here: this function never opens a
        # window, and those GUI calls are not available on headless builds.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be decoded from video: {video_path}")

    frames = np.stack(frames)  # also makes the reversed-channel views contiguous
    frames = torch.from_numpy(frames)  # [T, H, W, C=3]
    frames = frames.permute(0, 3, 1, 2)
    return frames, vinfo


def read_video(video_path, backend="av"):
    """Read a whole video with the chosen backend.

    Args:
        video_path: path of the video file.
        backend: ``"av"`` (default, uses ``read_video_av``) or ``"cv2"``.

    Returns:
        ``(vframes, vinfo)``: the frames as a ``[T, C, H, W]`` tensor and the
        backend's info dict.

    Raises:
        ValueError: if ``backend`` is not one of the supported names.
    """
    if backend == "cv2":
        vframes, vinfo = read_video_cv2(video_path)
    elif backend == "av":
        vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
    else:
        raise ValueError(f"Unsupported video backend: {backend!r} (expected 'av' or 'cv2')")

    return vframes, vinfo
defaultdict +from typing import Iterator + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DistributedSampler + +from opensora.utils.logger import log_message +from opensora.utils.misc import format_numel_str + +from .aspect import get_num_pexels_from_name +from .bucket import Bucket +from .datasets import VideoTextDataset +from .parallel import pandarallel +from .utils import sync_object_across_devices + + +# use pandarallel to accelerate bucket processing +# NOTE: pandarallel should only access local variables +def apply(data, method=None, seed=None, num_bucket=None, fps_max=16): + return method( + data["num_frames"], + data["height"], + data["width"], + data["fps"], + data["path"], + seed + data["id"] * num_bucket, + fps_max, + ) + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: int | None = None, + rank: int | None = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def reset(self) -> None: + self.start_index = 0 + + def state_dict(self, step) -> dict: + return {"start_index": step} + + def load_state_dict(self, state_dict: dict) -> None: + self.__dict__.update(state_dict) + + +class VariableVideoBatchSampler(DistributedSampler): + def __init__( + self, + dataset: VideoTextDataset, + bucket_config: dict, + num_replicas: int | None = None, + rank: int | None = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + verbose: bool = False, + num_bucket_build_workers: int = 1, + num_groups: int = 1, + ) -> None: + super().__init__( + dataset=dataset, 
num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last + ) + self.dataset = dataset + assert dataset.bucket_class == "Bucket", "Only support Bucket class for now" + self.bucket = Bucket(bucket_config) + self.verbose = verbose + self.last_micro_batch_access_index = 0 + self.num_bucket_build_workers = num_bucket_build_workers + self._cached_bucket_sample_dict = None + self._cached_num_total_batch = None + self.num_groups = num_groups + + if dist.get_rank() == 0: + pandarallel.initialize( + nb_workers=self.num_bucket_build_workers, + progress_bar=False, + verbose=0, + use_memory_fs=False, + ) + + def __iter__(self) -> Iterator[list[int]]: + bucket_sample_dict, _ = self.group_by_bucket() + self.clear_cache() + + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + bucket_micro_batch_count = OrderedDict() + bucket_last_consumed = OrderedDict() + + # process the samples + for bucket_id, data_list in bucket_sample_dict.items(): + # handle droplast + bs_per_gpu = self.bucket.get_batch_size(bucket_id) + remainder = len(data_list) % bs_per_gpu + + if remainder > 0: + if not self.drop_last: + # if there is remainder, we pad to make it divisible + data_list += data_list[: bs_per_gpu - remainder] + else: + # we just drop the remainder to make it divisible + data_list = data_list[:-remainder] + bucket_sample_dict[bucket_id] = data_list + + # handle shuffle + if self.shuffle: + data_indices = torch.randperm(len(data_list), generator=g).tolist() + data_list = [data_list[i] for i in data_indices] + bucket_sample_dict[bucket_id] = data_list + + # compute how many micro-batches each bucket has + num_micro_batches = len(data_list) // bs_per_gpu + bucket_micro_batch_count[bucket_id] = num_micro_batches + + # compute the bucket access order + # each bucket may have more than one batch of data + # thus bucket_id may appear more than 1 time + bucket_id_access_order = [] + for bucket_id, num_micro_batch in bucket_micro_batch_count.items(): + 
bucket_id_access_order.extend([bucket_id] * num_micro_batch) + + # randomize the access order + if self.shuffle: + bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist() + bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices] + + # make the number of bucket accesses divisible by dp size + remainder = len(bucket_id_access_order) % self.num_replicas + if remainder > 0: + if self.drop_last: + bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder] + else: + bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder] + + # prepare each batch from its bucket + # according to the predefined bucket access order + num_iters = len(bucket_id_access_order) // self.num_replicas + start_iter_idx = self.last_micro_batch_access_index // self.num_replicas + + # re-compute the micro-batch consumption + # this is useful when resuming from a state dict with a different number of GPUs + self.last_micro_batch_access_index = start_iter_idx * self.num_replicas + for i in range(self.last_micro_batch_access_index): + bucket_id = bucket_id_access_order[i] + bucket_bs = self.bucket.get_batch_size(bucket_id) + if bucket_id in bucket_last_consumed: + bucket_last_consumed[bucket_id] += bucket_bs + else: + bucket_last_consumed[bucket_id] = bucket_bs + + for i in range(start_iter_idx, num_iters): + bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas] + self.last_micro_batch_access_index += self.num_replicas + + # compute the data samples consumed by each access + bucket_access_boundaries = [] + for bucket_id in bucket_access_list: + bucket_bs = self.bucket.get_batch_size(bucket_id) + last_consumed_index = bucket_last_consumed.get(bucket_id, 0) + bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs]) + + # update consumption + if bucket_id in bucket_last_consumed: + 
bucket_last_consumed[bucket_id] += bucket_bs + else: + bucket_last_consumed[bucket_id] = bucket_bs + + # compute the range of data accessed by each GPU + bucket_id = bucket_access_list[self.rank] + boundary = bucket_access_boundaries[self.rank] + cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]] + + # encode t, h, w into the sample index + real_t, real_h, real_w = self.bucket.get_thw(bucket_id) + cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch] + yield cur_micro_batch + + self.reset() + + def __len__(self) -> int: + return self.get_num_batch() // self.num_groups + + def get_num_batch(self) -> int: + _, num_total_batch = self.group_by_bucket() + return num_total_batch + + def clear_cache(self): + self._cached_bucket_sample_dict = None + self._cached_num_total_batch = 0 + + def group_by_bucket(self) -> dict: + """ + Group the dataset samples into buckets. + This method will set `self._cached_bucket_sample_dict` to the bucket sample dict. 
+ + Returns: + dict: a dictionary with bucket id as key and a list of sample indices as value + """ + if self._cached_bucket_sample_dict is not None: + return self._cached_bucket_sample_dict, self._cached_num_total_batch + + # use pandarallel to accelerate bucket processing + log_message("Building buckets using %d workers...", self.num_bucket_build_workers) + bucket_ids = None + if dist.get_rank() == 0: + data = self.dataset.data.copy(deep=True) + data["id"] = data.index + bucket_ids = data.parallel_apply( + apply, + axis=1, + method=self.bucket.get_bucket_id, + seed=self.seed + self.epoch, + num_bucket=self.bucket.num_bucket, + fps_max=self.dataset.fps_max, + ) + dist.barrier() + bucket_ids = sync_object_across_devices(bucket_ids) + dist.barrier() + + # group by bucket + # each data sample is put into a bucket with a similar image/video size + bucket_sample_dict = defaultdict(list) + bucket_ids_np = np.array(bucket_ids) + valid_indices = np.where(bucket_ids_np != None)[0] + for i in valid_indices: + bucket_sample_dict[bucket_ids_np[i]].append(i) + + # cache the bucket sample dict + self._cached_bucket_sample_dict = bucket_sample_dict + + # num total batch + num_total_batch = self.print_bucket_info(bucket_sample_dict) + self._cached_num_total_batch = num_total_batch + + return bucket_sample_dict, num_total_batch + + def print_bucket_info(self, bucket_sample_dict: dict) -> int: + # collect statistics + num_total_samples = num_total_batch = 0 + num_total_img_samples = num_total_vid_samples = 0 + num_total_img_batch = num_total_vid_batch = 0 + num_total_vid_batch_256 = num_total_vid_batch_768 = 0 + num_aspect_dict = defaultdict(lambda: [0, 0]) + num_hwt_dict = defaultdict(lambda: [0, 0]) + for k, v in bucket_sample_dict.items(): + size = len(v) + num_batch = size // self.bucket.get_batch_size(k[:-1]) + + num_total_samples += size + num_total_batch += num_batch + + if k[1] == 1: + num_total_img_samples += size + num_total_img_batch += num_batch + else: + if k[0] == 
"256px": + num_total_vid_batch_256 += num_batch + elif k[0] == "768px": + num_total_vid_batch_768 += num_batch + num_total_vid_samples += size + num_total_vid_batch += num_batch + + num_aspect_dict[k[-1]][0] += size + num_aspect_dict[k[-1]][1] += num_batch + num_hwt_dict[k[:-1]][0] += size + num_hwt_dict[k[:-1]][1] += num_batch + + # sort + num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0])) + num_hwt_dict = dict( + sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True) + ) + num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1} + num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1} + + # log + if dist.get_rank() == 0 and self.verbose: + log_message("Bucket Info:") + log_message("Bucket [#sample, #batch] by aspect ratio:") + for k, v in num_aspect_dict.items(): + log_message("(%s): #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1])) + log_message("===== Image Info =====") + log_message("Image Bucket by HxWxT:") + for k, v in num_hwt_img_dict.items(): + log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1])) + log_message("--------------------------------") + log_message( + "#image sample: %s, #image batch: %s", + format_numel_str(num_total_img_samples), + format_numel_str(num_total_img_batch), + ) + log_message("===== Video Info =====") + log_message("Video Bucket by HxWxT:") + for k, v in num_hwt_vid_dict.items(): + log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1])) + log_message("--------------------------------") + log_message( + "#video sample: %s, #video batch: %s", + format_numel_str(num_total_vid_samples), + format_numel_str(num_total_vid_batch), + ) + log_message("===== Summary =====") + log_message("#non-empty buckets: %s", len(bucket_sample_dict)) + log_message( + "Img/Vid sample ratio: %.2f", + num_total_img_samples / num_total_vid_samples if 
class BatchDistributedSampler(DistributedSampler):
    """
    Sampler that gives every rank one contiguous range of sample indices.

    The dataset must expose ``num_buffers`` and ``len_buffer``. Each rank owns
    ``len_buffer * (num_buffers // num_replicas)`` consecutive indices, e.g.
    with len_buffer == 5, num_buffers == 6 and 3 ranks:

        rank 0 ->  0 ..  9
        rank 1 -> 10 .. 19
        rank 2 -> 20 .. 29

    ``start_index`` skips that many leading indices of the rank's range.
    """

    def __init__(self, dataset: Dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        self.start_index = 0

    def __iter__(self):
        buffers_total = self.dataset.num_buffers
        buffer_length = self.dataset.len_buffer
        # samples owned by a single rank
        per_rank = buffer_length * (buffers_total // self.num_replicas)
        # local positions still to visit, shifted into this rank's range
        local_positions = np.arange(self.start_index, per_rank)
        return iter((local_positions + self.rank * per_rank).tolist())

    def reset(self):
        self.start_index = 0

    def state_dict(self, step) -> dict:
        return dict(start_index=step)

    def load_state_dict(self, state_dict: dict):
        # resume one past the recorded index
        saved = state_dict["start_index"]
        self.start_index = saved + 1
def is_img(path):
    """Return True when the (lower-cased) extension of ``path`` is in IMG_EXTENSIONS."""
    ext = os.path.splitext(path)[-1].lower()
    return ext in IMG_EXTENSIONS


def is_vid(path):
    """Return True when the (lower-cased) extension of ``path`` is in VID_EXTENSIONS."""
    ext = os.path.splitext(path)[-1].lower()
    return ext in VID_EXTENSIONS


def is_url(url):
    """Return True when ``url`` matches the module-level URL pattern."""
    return re.match(regex, url) is not None


def read_file(input_path, memory_efficient=False):
    """Load a metadata table from a ``.csv`` or ``.parquet`` file.

    Args:
        input_path: path ending in ``.csv`` or ``.parquet``.
        memory_efficient: parquet only; when True just the columns
            path/num_frames/height/width/aspect_ratio/fps/resolution are read.

    Returns:
        A pandas DataFrame.

    Raises:
        AssertionError: if ``memory_efficient`` is requested for a CSV file.
        NotImplementedError: for any other file extension.
    """
    if input_path.endswith(".csv"):
        assert not memory_efficient, "Memory efficient mode is not supported for CSV files"
        return pd.read_csv(input_path)
    elif input_path.endswith(".parquet"):
        columns = None
        if memory_efficient:
            columns = ["path", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution"]
        if SUPPORT_DASK:
            ret = dd.read_parquet(input_path, columns=columns).compute()
        else:
            ret = pd.read_parquet(input_path, columns=columns)
        return ret
    else:
        raise NotImplementedError(f"Unsupported file format: {input_path}")


def download_url(input_path):
    """Download ``input_path`` into the local ``cache`` directory.

    The file is stored under its URL basename.

    Args:
        input_path: URL to fetch.

    Returns:
        The path of the downloaded file.

    Raises:
        requests.RequestException: on connection problems, timeouts or a
            non-success HTTP status.
    """
    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, base_name)
    # A timeout keeps a stalled server from hanging the caller forever, and
    # raise_for_status() avoids saving an HTTP error page as if it were media.
    response = requests.get(input_path, timeout=60)
    response.raise_for_status()
    # The payload is raw bytes: binary mode must not be given an encoding
    # (open() rejects that combination with ValueError).
    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path
def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the video preprocessing pipeline selected by ``name``.

    Every pipeline is ``ToTensorVideo`` -> crop step -> ``Normalize`` with
    mean and std 0.5 per channel; only the crop step differs:

      * ``"center"``         -> ``UCFCenterCropVideo(image_size[0])``
        (``image_size`` must be square)
      * ``"resize_crop"``    -> ``ResizeCrop(image_size)``
      * ``"rand_size_crop"`` -> ``RandomSizedCrop(image_size)``

    Returns ``None`` when ``name`` is ``None`` and raises
    ``NotImplementedError`` for any other name.
    """
    if name is None:
        return None
    if name not in ("center", "resize_crop", "rand_size_crop"):
        raise NotImplementedError(f"Transform {name} not implemented")
    if name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"

    steps = [video_transforms.ToTensorVideo()]  # TCHW
    if name == "center":
        steps.append(video_transforms.UCFCenterCropVideo(image_size[0]))
    elif name == "resize_crop":
        steps.append(video_transforms.ResizeCrop(image_size))
    else:
        steps.append(video_transforms.RandomSizedCrop(image_size))
    steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True))
    return transforms.Compose(steps)
pil_image: rand_size_crop_arr(pil_image, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + else: + raise NotImplementedError(f"Transform {name} not implemented") + return transform + + +def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)): + image = pil_loader(path) + if transform is None: + transform = get_transforms_image(image_size=image_size, name=transform_name) + image = transform(image) + video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1) + video = video.permute(1, 0, 2, 3) + return video + + +def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + if transform is None: + transform = get_transforms_video(image_size=image_size, name=transform_name) + video = transform(vframes) # T C H W + video = video.permute(1, 0, 2, 3) + return video + + +def read_from_path(path, image_size, transform_name="center"): + if is_url(path): + path = download_url(path) + ext = os.path.splitext(path)[-1].lower() + if ext.lower() in VID_EXTENSIONS: + return read_video_from_path(path, image_size=image_size, transform_name=transform_name) + else: + assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return read_image_from_path(path, image_size=image_size, transform_name=transform_name) + + +def save_sample( + x, + save_path=None, + fps=8, + normalize=True, + value_range=(-1, 1), + force_video=False, + verbose=True, + crf=23, +): + """ + Args: + x (Tensor): shape [C, T, H, W] + """ + assert x.ndim == 4 + + if not force_video and x.shape[1] == 1: # T = 1: save as image + save_path += ".png" + x = x.squeeze(1) + save_image([x], save_path, normalize=normalize, value_range=value_range) + else: + save_path += ".mp4" + if normalize: + low, high = value_range + x.clamp_(min=low, 
max=high) + x.sub_(low).div_(max(high - low, 1e-5)) + + x = x.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8) + + write_video(save_path, x, fps=fps, video_codec="h264", options={"crf": str(crf)}) + if verbose: + print(f"Saved to {save_path}") + return save_path + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) + + +def rand_size_crop_arr(pil_image, image_size): + """ + Randomly crop image for height and width, ranging from image_size[0] to image_size[1] + """ + arr = np.array(pil_image) + + # get random target h w + height = random.randint(image_size[0], image_size[1]) + width = random.randint(image_size[0], image_size[1]) + # ensure that h w are multiples of 8 + height = height - height % 8 + width = width - width % 8 + + # get random start pos + h_start = random.randint(0, max(len(arr) - height, 0)) + # the horizontal offset is bounded by the crop width, not the crop height + w_start = random.randint(0, max(len(arr[0]) - width, 0)) + + # crop + return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width]) + + +def resize_crop_to_fill(pil_image, image_size): + w, h = pil_image.size # PIL is (W, H) + th, tw = image_size + rh, rw = th / h, tw / w + if rh > rw: + sh, sw = th, round(w * rh) + image = pil_image.resize((sw, sh), Image.BICUBIC) + i = 0 + j = int(round((sw - tw) / 2.0)) + else: + sh, sw = round(h * rw), tw +
image = pil_image.resize((sw, sh), Image.BICUBIC) + i = int(round((sh - th) / 2.0)) + j = 0 + arr = np.array(image) + assert i + th <= arr.shape[0] and j + tw <= arr.shape[1] + return Image.fromarray(arr[i : i + th, j : j + tw]) + + +def map_target_fps( + fps: float, + max_fps: float, +) -> tuple[float, int]: + """ + Map fps to a new fps that is less than max_fps. + + Args: + fps (float): Original fps. + max_fps (float): Maximum fps. + + Returns: + tuple[float, int]: New fps and sampling interval. + """ + if math.isnan(fps): + return 0, 1 + if fps < max_fps: + return fps, 1 + sampling_interval = math.ceil(fps / max_fps) + new_fps = math.floor(fps / sampling_interval) + return new_fps, sampling_interval + + +def sync_object_across_devices(obj: Any, rank: int = 0): + """ + Synchronizes any picklable object across devices in a PyTorch distributed setting + using `broadcast_object_list` with CUDA support. + + Parameters: + obj (Any): The object to synchronize. Can be any picklable object (e.g., list, dict, custom class). + rank (int): The rank of the device from which to broadcast the object state. Default is 0. + + Note: Ensure torch.distributed is initialized before using this function and CUDA is available. + """ + + # Move the object to a list for broadcasting + object_list = [obj] + + # Broadcast the object list from the source rank to all other ranks + dist.broadcast_object_list(object_list, src=rank, device="cuda") + + # Retrieve the synchronized object + obj = object_list[0] + + return obj + + +def rescale_image_by_path(path: str, height: int, width: int): + """ + Rescales an image to the specified height and width and saves it back to the original path. + + Args: + path (str): The file path of the image. + height (int): The target height of the image. + width (int): The target width of the image. 
+ """ + try: + # read image + image = Image.open(path) + + # check if image is valid + if image is None: + raise ValueError("The image is invalid or empty.") + + # resize image; torchvision Resize expects size as (height, width) + resize_transform = transforms.Resize((height, width)) + resized_image = resize_transform(image) + + # save resized image back to the original path + resized_image.save(path) + + except Exception as e: + print(f"Error rescaling image: {e}") + + +def rescale_video_by_path(path: str, height: int, width: int): + """ + Rescales an MP4 video (without audio) to the specified height and width. + + Args: + path (str): The file path of the video. + height (int): The target height of the video. + width (int): The target width of the video. + """ + try: + # Read video and metadata + video, info = read_video(path, backend="av") + + # Check if video is valid + if video is None or video.size(0) == 0: + raise ValueError("The video is invalid or empty.") + + # Resize video frames one by one; torchvision Resize expects size as (height, width) + resize_transform = transforms.Compose([transforms.Resize((height, width))]) + resized_video = torch.stack([resize_transform(frame) for frame in video]) + + # Save resized video back to the original path + resized_video = resized_video.permute(0, 2, 3, 1) + write_video(path, resized_video, fps=int(info["video_fps"]), video_codec="h264") + except Exception as e: + print(f"Error rescaling video: {e}") + + +def save_tensor_to_disk(tensor, path, exist_handling="overwrite"): + if os.path.exists(path): + if exist_handling == "ignore": + return + elif exist_handling == "raise": + raise UserWarning(f"File {path} already exists, rewriting!") + torch.save(tensor, path) + + +def save_tensor_to_internet(tensor, path): + raise NotImplementedError("save_tensor_to_internet is not implemented yet!") + + +def save_latent(latent, path, exist_handling="overwrite"): + if path.startswith(("http://", "https://", "ftp://", "sftp://")): + save_tensor_to_internet(latent, path) + else: + save_tensor_to_disk(latent, path, +
exist_handling=exist_handling) + + +def cache_latents(latents, path, exist_handling="overwrite"): + for i in range(latents.shape[0]): + save_latent(latents[i], path[i], exist_handling=exist_handling) diff --git a/opensora/datasets/video_transforms.py b/opensora/datasets/video_transforms.py new file mode 100644 index 0000000..5da870d --- /dev/null +++ b/opensora/datasets/video_transforms.py @@ -0,0 +1,595 @@ +# Copyright 2024 Vchitect/Latte + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte + +import numbers + +# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py +import random + +import numpy as np +import torch + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + th, tw = int(round(H * scale_)), int(round(W * scale_)) + return torch.nn.functional.interpolate(clip, size=(th, tw), mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def resize_crop_to_fill(clip, target_size): + """ + Resize the clip so it covers target_size (height, width), then center crop to target_size + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = target_size[0], target_size[1] + rh, rw = th / h, tw / w + if rh > rw: + sh, sw = th, round(w * rh) + clip = resize(clip, (sh, sw), "bilinear") + i = 0 + # round the half-difference, same as center_crop above + j = int(round((sw - tw) / 2.0)) + else: + sh, sw = round(h * rw), tw + clip = resize(clip, (sh, sw), "bilinear") + i = int(round((sh - th) / 2.0)) + j = 0 + assert i + th <= clip.size(-2) and j + tw <= clip.size(-1) + return crop(clip, i, j, th, tw) + + +# def rand_crop_h_w(clip, target_size_range, multiples_of: int = 8): +# # NOTE: for some reason, if don't re-import, gives same randint results +# import sys + +# del sys.modules["random"] +# import random + +# if not _is_tensor_video_clip(clip): +# raise ValueError("clip should be a 4D torch.tensor") +# h, w = clip.size(-2), clip.size(-1) + +# # get random target h w +# th = random.randint(target_size_range[0], target_size_range[1]) +# tw =
random.randint(target_size_range[0], target_size_range[1]) + +# # ensure that h w are factors of 8 +# th = th - th % multiples_of +# tw = tw - tw % multiples_of + +# # get random start pos +# i = random.randint(0, h-th) if h > th else 0 +# j = random.randint(0, w-tw) if w > tw else 0 + +# th = th if th < h else h +# tw = tw if tw < w else w + +# # print("target size range:",target_size_range) +# # print("original size:", h, w) +# # print("crop size:", th, tw) +# # print(f"crop:{i}-{i+th}, {j}-{j+tw}") + +# return (crop(clip, i, j, th, tw), th, tw) + + +def random_shift_crop(clip): + """ + Slide along the long edge, with the short edge as crop size + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + short_edge = h + else: + short_edge = w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. 
Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class ResizeCrop: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + clip = resize_crop_to_fill(clip, self.size) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class RandomSizedCrop: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + i, j, h, w = self.get_params(clip) + # self.size = (h, w) + return crop(clip, i, j, h, w) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + def get_params(self, clip, multiples_of=8): + h, w = clip.shape[-2:] + + # get random target h w + th = random.randint(self.size[0], self.size[1]) + tw = random.randint(self.size[0], self.size[1]) + # ensure that h w are factors of 8 + th = th - th % multiples_of + tw = tw - tw % multiples_of + + if h < th: + th = h - h % multiples_of + if w < tw: + tw = w - w % multiples_of + + if w == tw and h == th: + return 0, 0, h, w + + else: + # get random start pos + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + + return i, j, th, tw 
+ + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class CenterCropResizeVideo: + """ + First use the short side for cropping length, + center crop video, then resize to the specified size + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. 
+ size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop_using_short_edge(clip) + clip_center_crop_resize = resize( + clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode + ) + return clip_center_crop_resize + + def __repr__(self) -> str: + # close the parenthesis opened after the class name + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})" + + +class UCFCenterCropVideo: + """ + First scale to the specified size in equal proportion to the short edge, + then center cropping + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + # close the parenthesis opened after the class name + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})" + + +class KineticsRandomCropResizeVideo: + """ + Slide along the long edge, with the short edge as crop size. And resize to the desired size.
+ """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + clip_random_crop = random_shift_crop(clip) + clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) + return clip_resize + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + # close the parenthesis opened after the class name + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized.
Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (T, C, H, W) + Return: + clip (torch.tensor): Size is (T, C, H, W) + """ + if random.random() < self.p: + clip = hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. 
+ """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + + +if __name__ == "__main__": + import os + + import numpy as np + import torchvision.io as io + from torchvision import transforms + from torchvision.utils import save_image + + vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW") + + trans = transforms.Compose( + [ + ToTensorVideo(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + target_video_len = 32 + frame_interval = 1 + total_frames = len(vframes) + print(total_frames) + + temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) + + # Sampling video frames + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + # print(start_frame_ind) + # print(end_frame_ind) + assert end_frame_ind - start_frame_ind >= target_video_len + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) + print(frame_indice) + + select_vframes = vframes[frame_indice] + print(select_vframes.shape) + print(select_vframes.dtype) + + select_vframes_trans = trans(select_vframes) + print(select_vframes_trans.shape) + print(select_vframes_trans.dtype) + + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) + print(select_vframes_trans_int.dtype) + print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) + + io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) + + for i in range(target_video_len): + save_image( + select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, 
value_range=(-1, 1) + ) diff --git a/tools/datasets/README.md b/tools/datasets/README.md new file mode 100644 index 0000000..73f25b7 --- /dev/null +++ b/tools/datasets/README.md @@ -0,0 +1,282 @@ +# Dataset Management + +- [Dataset Management](#dataset-management) + - [Dataset Format](#dataset-format) + - [Dataset to CSV](#dataset-to-csv) + - [Manage datasets](#manage-datasets) + - [Requirement](#requirement) + - [Basic Usage](#basic-usage) + - [Score filtering](#score-filtering) + - [Documentation](#documentation) + - [Transform datasets](#transform-datasets) + - [Resize](#resize) + - [Frame extraction](#frame-extraction) + - [Crop Midjourney 4 grid](#crop-midjourney-4-grid) + - [Analyze datasets](#analyze-datasets) + - [Data Process Pipeline](#data-process-pipeline) + +After preparing the raw dataset according to the [instructions](/docs/datasets.md), you can use the following commands to manage the dataset. + +## Dataset Format + +All datasets should be provided in a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should follow the words below: + +- `path`: the relative/absolute path or url to the image or video file. Required. +- `text`: the caption or description of the image or video. Required for training. +- `num_frames`: the number of frames in the video. Required for training. +- `width`: the width of the video frame. Required for dynamic bucket. +- `height`: the height of the video frame. Required for dynamic bucket. +- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket. +- `resolution`: height x width. For analysis. +- `text_len`: the number of tokens in the text. For analysis. +- `aes`: aesthetic score calculated by [aesthetic scorer](/tools/aesthetic/README.md). For filtering. +- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering.
+- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering. +- `fps`: the frame rate of the video. Optional. +- `cmotion`: the camera motion. + +An example ready for training: + +```csv +path, text, num_frames, width, height, aspect_ratio +/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625 +/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625 +/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1 +``` + +We use pandas to manage the `.csv` or `.parquet` files. The following code is for reading and writing files: + +```python +df = pd.read_csv(input_path) +df = df.to_csv(output_path, index=False) +# or use parquet, which is smaller +df = pd.read_parquet(input_path) +df = df.to_parquet(output_path, index=False) +``` + +## Dataset to CSV + +As a start point, `convert.py` is used to convert the dataset to a CSV file. You can use the following commands to convert the dataset to a CSV file: + +```bash +python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER + +# general video folder +python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv +# general image folder +python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv +# imagenet +python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train +# ucf101 +python -m tools.datasets.convert ucf101 UCF101_FOLDER --split videos +# vidprom +python -m tools.datasets.convert vidprom VIDPROM_FOLDER --info VidProM_semantic_unique.csv +``` + +## Manage datasets + +Use `datautil` to manage the dataset. + +### Requirement + +Follow our [installation guide](../../docs/installation.md)'s "Data Dependencies" and "Datasets" section to install the required packages. + + + + + + + + + + +### Basic Usage + +You can use the following commands to process the `csv` or `parquet` files. The output file will be saved in the same directory as the input, with different suffixes indicating the processed method. 
+ +```bash +# datautil takes multiple CSV files as input and merge them into one CSV file +# output: DATA1+DATA2.csv +python -m tools.datasets.datautil DATA1.csv DATA2.csv + +# shard CSV files into multiple CSV files +# output: DATA1_0.csv, DATA1_1.csv, ... +python -m tools.datasets.datautil DATA1.csv --shard 10 + +# filter frames between 128 and 256, with captions +# output: DATA1_fmin_128_fmax_256.csv +python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 + +# Disable parallel processing +python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 --disable-parallel + +# Compute num_frames, height, width, fps, aspect_ratio for videos or images +# output: IMG_DATA+VID_DATA_vinfo.csv +python -m tools.datasets.datautil IMG_DATA.csv VID_DATA.csv --video-info + +# You can run multiple operations at the same time. +python -m tools.datasets.datautil DATA.csv --video-info --remove-empty-caption --remove-url --lang en +``` + +### Score filtering + +To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands: + +```bash +# sort the dataset by aesthetic score +# output: DATA_sort.csv +python -m tools.datasets.datautil DATA.csv --sort aesthetic_score +# View examples of high aesthetic score +head -n 10 DATA_sort.csv +# View examples of low aesthetic score +tail -n 10 DATA_sort.csv + +# sort the dataset by clip score +# output: DATA_sort.csv +python -m tools.datasets.datautil DATA.csv --sort clip_score + +# filter the dataset by aesthetic score +# output: DATA_aesmin_0.5.csv +python -m tools.datasets.datautil DATA.csv --aesmin 0.5 +# filter the dataset by clip score +# output: DATA_matchmin_0.5.csv +python -m tools.datasets.datautil DATA.csv --matchmin 0.5 +``` + +### Documentation + +You can also use `python -m tools.datasets.datautil --help` to see usage. 
+ +| Args | File suffix | Description | +| --------------------------- | -------------- | ------------------------------------------------------------- | +| `--output OUTPUT` | | Output path | +| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) | +| `--disable-parallel` | | Disable `pandarallel` | +| `--seed SEED` | | Random seed | +| `--shard SHARD` | `_0`,`_1`, ... | Shard the dataset | +| `--sort KEY` | `_sort` | Sort the dataset by KEY | +| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order | +| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset | +| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns | +| `--info` | `_info` | Get the basic information of each video and image (cv2) | +| `--ext` | `_ext` | Remove rows if the file does not exist | +| `--relpath` | `_relpath` | Modify the path to relative path by root given | +| `--abspath` | `_abspath` | Modify the path to absolute path by root given | +| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption | +| `--remove-url` | `_nourl` | Remove rows with url in caption | +| `--lang LANG` | `_lang` | Remove rows with other language | +| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path | +| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption | +| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM | +| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training | +| `--unescape` | `_unescape` | Unescape the caption | +| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption | +| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption | +| `--load-caption EXT` | `_load` | Load the caption from the file | +| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames | +| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum 
number of frames | +| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width | +| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score | +| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score | +| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score | + +## Transform datasets + +The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows: + +```bash +python -m tools.datasets.transform TRANSFORM_TYPE META.csv ORIGINAL_DATA_FOLDER DATA_FOLDER_TO_SAVE_RESULTS --additional-args +``` + +### Resize + +Sometimes you may need to resize the images or videos to a specific resolution. You can use the following commands to resize the dataset: + +```bash +python -m tools.datasets.transform meta.csv /path/to/raw/data /path/to/new/data --length 2160 +``` + +### Frame extraction + +To extract frames from videos, you can use the following commands: + +```bash +python -m tools.datasets.transform vid_frame_extract meta.csv /path/to/raw/data /path/to/new/data --points 0.1 0.5 0.9 +``` + +### Crop Midjourney 4 grid + +Randomly select one of the 4 images in the 4 grid generated by Midjourney. + +```bash +python -m tools.datasets.transform img_rand_crop meta.csv /path/to/raw/data /path/to/new/data +``` + +## Analyze datasets + +You can easily get basic information about a `.csv` dataset by using the following commands: + +```bash +# examine the first 10 rows of the CSV file +head -n 10 DATA1.csv +# count the number of data in the CSV file (approximately) +wc -l DATA1.csv +``` + +For the dataset provided in a `.csv` or `.parquet` file, you can easily analyze the dataset using the following commands. Plots will be automatically saved. + +```bash +python -m tools.datasets.analyze DATA_info.csv +``` + +## Data Process Pipeline + +```bash +# Suppose videos and images are under ~/dataset/ +# 1.
Convert dataset to CSV +python -m tools.datasets.convert video ~/dataset --output meta.csv + +# 2. Get video information +python -m tools.datasets.datautil meta.csv --info --fmin 1 + +# 3. Get caption +# 3.1. generate caption +torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava meta_info_fmin1.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video +# merge generated results +python -m tools.datasets.datautil meta_info_fmin1_caption_part*.csv --output meta_caption.csv +# merge caption and info +python -m tools.datasets.datautil meta_info_fmin1.csv --intersection meta_caption.csv --output meta_caption_info.csv +# clean caption +python -m tools.datasets.datautil meta_caption_info.csv --clean-caption --refine-llm-caption --remove-empty-caption --output meta_caption_processed.csv +# 3.2. extract caption +python -m tools.datasets.datautil meta_info_fmin1.csv --load-caption json --remove-empty-caption --clean-caption + +# 4. Scoring +# aesthetic scoring +torchrun --standalone --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta_caption_processed.csv +python -m tools.datasets.datautil meta_caption_processed_part*.csv --output meta_caption_processed_aes.csv +# optical flow scoring +torchrun --standalone --nproc_per_node 8 -m tools.scoring.optical_flow.inference meta_caption_processed.csv +# matching scoring +torchrun --standalone --nproc_per_node 8 -m tools.scoring.matching.inference meta_caption_processed.csv +# camera motion +python -m tools.caption.camera_motion_detect meta_caption_processed.csv +``` diff --git a/tools/datasets/__init__.py b/tools/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/datasets/analyze.py b/tools/datasets/analyze.py new file mode 100644 index 0000000..7151689 --- /dev/null +++ b/tools/datasets/analyze.py @@ -0,0 +1,96 @@ +import argparse +import os + +import matplotlib.pyplot as plt +import pandas as pd + + +def read_file(input_path): + if 
input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, help="Path to the input dataset") + parser.add_argument("--save-img", type=str, default="samples/infos/", help="Path to save the image") + return parser.parse_args() + + +def plot_data(data, column, bins, name): + plt.clf() + data.hist(column=column, bins=bins) + os.makedirs(os.path.dirname(name), exist_ok=True) + plt.savefig(name) + print(f"Saved {name}") + + +def plot_categorical_data(data, column, name): + plt.clf() + data[column].value_counts().plot(kind="bar") + os.makedirs(os.path.dirname(name), exist_ok=True) + plt.savefig(name) + print(f"Saved {name}") + + +COLUMNS = { + "num_frames": 100, + "resolution": 100, + "text_len": 100, + "aes": 100, + "match": 100, + "flow": 100, + "cmotion": None, +} + + +def main(args): + data = read_file(args.input) + + # === Image Data Info === + image_index = data["num_frames"] == 1 + if image_index.sum() > 0: + print("=== Image Data Info ===") + img_data = data[image_index] + print(f"Number of images: {len(img_data)}") + print(img_data.head()) + print(img_data.describe()) + if args.save_img: + for column in COLUMNS: + if column in img_data.columns and column not in ["num_frames", "cmotion"]: + if COLUMNS[column] is None: + plot_categorical_data(img_data, column, os.path.join(args.save_img, f"image_{column}.png")) + else: + plot_data(img_data, column, COLUMNS[column], os.path.join(args.save_img, f"image_{column}.png")) + + # === Video Data Info === + if not image_index.all(): + print("=== Video Data Info ===") + video_data = data[~image_index] + print(f"Number of videos: {len(video_data)}") + if "num_frames" in video_data.columns: + total_num_frames = video_data["num_frames"].sum() + print(f"Number of frames: 
{total_num_frames}") + DEFAULT_FPS = 30 + total_hours = total_num_frames / DEFAULT_FPS / 3600 + print(f"Total hours (30 FPS): {int(total_hours)}") + print(video_data.head()) + print(video_data.describe()) + if args.save_img: + for column in COLUMNS: + if column in video_data.columns: + if COLUMNS[column] is None: + plot_categorical_data(video_data, column, os.path.join(args.save_img, f"video_{column}.png")) + else: + plot_data( + video_data, column, COLUMNS[column], os.path.join(args.save_img, f"video_{column}.png") + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tools/datasets/check_integrity.py b/tools/datasets/check_integrity.py new file mode 100644 index 0000000..4735fcc --- /dev/null +++ b/tools/datasets/check_integrity.py @@ -0,0 +1,79 @@ +import argparse +import subprocess + +import pandas as pd +from tqdm import tqdm + +tqdm.pandas() + +try: + from pandarallel import pandarallel + + PANDA_USE_PARALLEL = True +except ImportError: + PANDA_USE_PARALLEL = False + +import shutil + +if not shutil.which("ffmpeg"): + raise ImportError("FFmpeg is not installed") + + +def apply(df, func, **kwargs): + if PANDA_USE_PARALLEL: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +def check_video_integrity(video_path): + # try: + can_open_result = subprocess.run( + ["ffmpeg", "-v", "error", "-i", video_path, "-t", "0", "-f", "null", "-"], # open video and capture 0 seconds + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + fast_scan_result = subprocess.run( + ["ffmpeg", "-v", "error", "-analyzeduration", "10M", "-probesize", "10M", "-i", video_path, "-f", "null", "-"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if can_open_result.stderr == "" and fast_scan_result.stderr == "": + return True + else: + return False + # except Exception as e: + # return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + 
parser.add_argument("input", type=str, help="path to the input dataset") + parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing") + parser.add_argument("--num-workers", type=int, default=None, help="number of workers") + args = parser.parse_args() + + if args.disable_parallel: + PANDA_USE_PARALLEL = False + if PANDA_USE_PARALLEL: + if args.num_workers is not None: + pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True) + else: + pandarallel.initialize(progress_bar=True) + + data = pd.read_csv(args.input) + assert "path" in data.columns + data["integrity"] = apply(data["path"], check_video_integrity) + + integrity_file_path = args.input.replace(".csv", "_intact.csv") + broken_file_path = args.input.replace(".csv", "_broken.csv") + + intact_data = data[data["integrity"] == True].drop(columns=["integrity"]) + intact_data.to_csv(integrity_file_path, index=False) + broken_data = data[data["integrity"] == False].drop(columns=["integrity"]) + broken_data.to_csv(broken_file_path, index=False) + + print( + f"Integrity check completed. Intact videos saved to: {integrity_file_path}, broken videos saved to {broken_file_path}." 
+ ) diff --git a/tools/datasets/convert.py b/tools/datasets/convert.py new file mode 100644 index 0000000..76f81d1 --- /dev/null +++ b/tools/datasets/convert.py @@ -0,0 +1,144 @@ +import argparse +import os +import time + +import pandas as pd +from torchvision.datasets import ImageNet + +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".m2ts") + + +def scan_recursively(root): + num = 0 + for entry in os.scandir(root): + if entry.is_file(): + yield entry + elif entry.is_dir(): + num += 1 + if num % 100 == 0: + print(f"Scanned {num} directories.") + yield from scan_recursively(entry.path) + + +def get_filelist(file_path, exts=None): + filelist = [] + time_start = time.time() + + # == OS Walk == + # for home, dirs, files in os.walk(file_path): + # for filename in files: + # ext = os.path.splitext(filename)[-1].lower() + # if exts is None or ext in exts: + # filelist.append(os.path.join(home, filename)) + + # == Scandir == + obj = scan_recursively(file_path) + for entry in obj: + if entry.is_file(): + ext = os.path.splitext(entry.name)[-1].lower() + if exts is None or ext in exts: + filelist.append(entry.path) + + time_end = time.time() + print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.") + return filelist + + +def split_by_capital(name): + # BoxingPunchingBag -> Boxing Punching Bag + new_name = "" + for i in range(len(name)): + if name[i].isupper() and i != 0: + new_name += " " + new_name += name[i] + return new_name + + +def process_imagenet(root, split): + root = os.path.expanduser(root) + data = ImageNet(root, split=split) + samples = [(path, data.classes[label][0]) for path, label in data.samples] + output = f"imagenet_{split}.csv" + + df = pd.DataFrame(samples, columns=["path", "text"]) + df.to_csv(output, index=False) + print(f"Saved {len(samples)} samples to {output}.") + + +def process_ucf101(root, split): + root = 
os.path.expanduser(root) + video_lists = get_filelist(os.path.join(root, split)) + classes = [x.split("/")[-2] for x in video_lists] + classes = [split_by_capital(x) for x in classes] + samples = list(zip(video_lists, classes)) + output = f"ucf101_{split}.csv" + + df = pd.DataFrame(samples, columns=["path", "text"]) + df.to_csv(output, index=False) + print(f"Saved {len(samples)} samples to {output}.") + + +def process_vidprom(root, info): + root = os.path.expanduser(root) + video_lists = get_filelist(root) + video_set = set(video_lists) + # read info csv + infos = pd.read_csv(info) + abs_path = infos["uuid"].apply(lambda x: os.path.join(root, f"pika-{x}.mp4")) + is_exist = abs_path.apply(lambda x: x in video_set) + df = pd.DataFrame(dict(path=abs_path[is_exist], text=infos["prompt"][is_exist])) + df.to_csv("vidprom.csv", index=False) + print(f"Saved {len(df)} samples to vidprom.csv.") + + +def process_general_images(root, output): + root = os.path.expanduser(root) + if not os.path.exists(root): + return + path_list = get_filelist(root, IMG_EXTENSIONS) + fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list] + relpath_list = [os.path.relpath(x, root) for x in path_list] + df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list)) + + os.makedirs(os.path.dirname(output), exist_ok=True) + df.to_csv(output, index=False) + print(f"Saved {len(df)} samples to {output}.") + + +def process_general_videos(root, output): + root = os.path.expanduser(root) + if not os.path.exists(root): + return + path_list = get_filelist(root, VID_EXTENSIONS) + path_list = list(set(path_list)) # remove duplicates + fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list] + relpath_list = [os.path.relpath(x, root) for x in path_list] + df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list)) + + os.makedirs(os.path.dirname(output), exist_ok=True) + df.to_csv(output, index=False) + print(f"Saved {len(df)} samples to 
{output}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"]) + parser.add_argument("root", type=str) + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--info", type=str, default=None) + parser.add_argument("--output", type=str, default=None, required=True, help="Output path") + args = parser.parse_args() + + if args.dataset == "imagenet": + process_imagenet(args.root, args.split) + elif args.dataset == "ucf101": + process_ucf101(args.root, args.split) + elif args.dataset == "vidprom": + process_vidprom(args.root, args.info) + elif args.dataset == "image": + process_general_images(args.root, args.output) + elif args.dataset == "video": + process_general_videos(args.root, args.output) + else: + raise ValueError("Invalid dataset") diff --git a/tools/datasets/csv2txt.py b/tools/datasets/csv2txt.py new file mode 100644 index 0000000..116e7e0 --- /dev/null +++ b/tools/datasets/csv2txt.py @@ -0,0 +1,14 @@ +import argparse + +import pandas as pd + +parser = argparse.ArgumentParser(description="Convert CSV file to txt file") +parser.add_argument("csv_file", type=str, help="CSV file to convert") +parser.add_argument("txt_file", type=str, help="TXT file to save") +args = parser.parse_args() + +data = pd.read_csv(args.csv_file) +text = data["text"].to_list() +text = "\n".join(text) +with open(args.txt_file, "w") as f: + f.write(text) diff --git a/tools/datasets/datautil.py b/tools/datasets/datautil.py new file mode 100644 index 0000000..3d78f30 --- /dev/null +++ b/tools/datasets/datautil.py @@ -0,0 +1,1089 @@ +import argparse +import html +import json +import math +import os +import random +import re +from functools import partial +from glob import glob + +import cv2 +import numpy as np +import pandas as pd +from PIL import Image +from tqdm import tqdm + +from opensora.datasets.read_video import read_video + +from .utils 
import IMG_EXTENSIONS + +tqdm.pandas() + +try: + from pandarallel import pandarallel + + PANDA_USE_PARALLEL = True +except ImportError: + PANDA_USE_PARALLEL = False + + +def apply(df, func, **kwargs): + if PANDA_USE_PARALLEL: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"] +PRE_TRAIN_COLUMNS = [ + "path", + "text", + "num_frames", + "fps", + "height", + "width", + "aspect_ratio", + "resolution", + "text_len", + "aes", + "flow", + "pred_score", +] + +# ====================================================== +# --info +# ====================================================== + + +def get_video_length(cap, method="header"): + assert method in ["header", "set"] + if method == "header": + length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + else: + cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1) + length = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) + return length + + +def get_info(path): + try: + ext = os.path.splitext(path)[1].lower() + if ext in IMG_EXTENSIONS: + return get_image_info(path) + else: + return get_video_info(path) + except: + return 0, 0, 0, np.nan, np.nan, np.nan + + +def get_image_info(path, backend="pillow"): + if backend == "pillow": + try: + with open(path, "rb") as f: + img = Image.open(f) + img = img.convert("RGB") + width, height = img.size + num_frames, fps = 1, np.nan + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + elif backend == "cv2": + try: + im = cv2.imread(path) + if im is None: + return 0, 0, 0, np.nan, np.nan, np.nan + height, width = im.shape[:2] + num_frames, fps = 1, np.nan + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, 
np.nan + else: + raise ValueError + + +def get_video_info(path, backend="torchvision"): + if backend == "torchvision": + try: + vframes, infos = read_video(path) + num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3] + if "video_fps" in infos: + fps = infos["video_fps"] + else: + fps = np.nan + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + elif backend == "cv2": + try: + cap = cv2.VideoCapture(path) + num_frames, height, width, fps = ( + get_video_length(cap, method="header"), + int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), + int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + float(cap.get(cv2.CAP_PROP_FPS)), + ) + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + else: + raise ValueError + + +# ====================================================== +# --refine-llm-caption +# ====================================================== + +LLAVA_PREFIX = [ + "The video shows ", + "The video captures ", + "The video features ", + "The video depicts ", + "The video presents ", + "The video features ", + "The video is ", + "In the video, ", + "The image shows ", + "The image captures ", + "The image features ", + "The image depicts ", + "The image presents ", + "The image features ", + "The image is ", + "The image portrays ", + "In the image, ", +] + + +def remove_caption_prefix(caption): + for prefix in LLAVA_PREFIX: + if caption.startswith(prefix) or caption.startswith(prefix.lower()): + caption = caption[len(prefix) :].strip() + if caption[0].islower(): + caption = caption[0].upper() + caption[1:] + return caption + return caption + + +# ====================================================== +# --merge-cmotion +# ====================================================== + +CMOTION_TEXT 
= { + "static": "static", + "pan_right": "pan right", + "pan_left": "pan left", + "zoom_in": "zoom in", + "zoom_out": "zoom out", + "tilt_up": "tilt up", + "tilt_down": "tilt down", + # "pan/tilt": "The camera is panning.", + # "dynamic": "The camera is moving.", + # "unknown": None, +} +CMOTION_PROBS = { + # hard-coded probabilities + "static": 1.0, + "zoom_in": 1.0, + "zoom_out": 1.0, + "pan_left": 1.0, + "pan_right": 1.0, + "tilt_up": 1.0, + "tilt_down": 1.0, + # "dynamic": 1.0, + # "unknown": 0.0, + # "pan/tilt": 1.0, +} + + +def merge_cmotion(caption, cmotion): + text = CMOTION_TEXT[cmotion] + prob = CMOTION_PROBS[cmotion] + if text is not None and random.random() < prob: + caption = f"{caption} Camera motion: {text}." + return caption + + +# ====================================================== +# --lang +# ====================================================== + + +def build_lang_detector(lang_to_detect): + from lingua import Language, LanguageDetectorBuilder + + lang_dict = dict(en=Language.ENGLISH) + assert lang_to_detect in lang_dict + valid_lang = lang_dict[lang_to_detect] + detector = LanguageDetectorBuilder.from_all_spoken_languages().with_low_accuracy_mode().build() + + def detect_lang(caption): + confidence_values = detector.compute_language_confidence_values(caption) + confidence = [x.language for x in confidence_values[:5]] + if valid_lang not in confidence: + return False + return True + + return detect_lang + + +# ====================================================== +# --clean-caption +# ====================================================== + + +def basic_clean(text): + import ftfy + + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +BAD_PUNCT_REGEX = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" +) # noqa + + +def text_refine_t5(caption): + import urllib.parse as ul + + from bs4 import BeautifulSoup + + caption = 
str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = 
re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... 
+ + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + +def text_preprocessing(text, use_text_preprocessing: bool = True): + if use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = text_refine_t5(text) + text = text_refine_t5(text) + return text + else: + return text.lower().strip() + + +def has_human(text): + first_sentence = text.split(".")[0] + human_words = ["man", "woman", "child", "girl", "boy"] + for word in human_words: + if word in first_sentence: + return True + return False + + +# ====================================================== +# load caption +# ====================================================== + + +def load_caption(path, ext): + try: + assert ext in ["json"] + json_path = path.split(".")[0] + ".json" + with open(json_path, "r") as f: + data = json.load(f) + caption = data["caption"] + return caption + except: + return "" + + +# ====================================================== +# --clean-caption +# ====================================================== + +DROP_SCORE_PROB = 0.2 + + +def transform_aes(aes): + # < 4 filter out + if aes < 4: + return "terrible" + elif aes < 4.5: + return "very poor" + elif aes < 5: + return "poor" + elif aes < 5.5: + return "fair" + elif aes < 6: + return "good" + elif aes < 6.5: + return "very good" + else: + return "excellent" + + +def transform_motion(motion): + # < 0.3 filter out + if motion < 0.5: + return "very low" + elif motion < 2: + return "low" + elif motion < 5: + return "fair" + elif motion < 10: + return "high" + elif motion < 20: + return 
"very high" + else: + return "extremely high" + + +def score2text(data): + text = data["text"] + if not text.endswith("."): + text += "." + # aesthetic + if "aes" in data: + aes = transform_aes(data["aes"]) + if random.random() > DROP_SCORE_PROB: + score_text = f" the aesthetic score is {aes}." + text += score_text + # flow + if "flow" in data: + flow = transform_motion(data["flow"]) + if random.random() > DROP_SCORE_PROB: + score_text = f" the motion strength is {flow}." + text += score_text + return text + + +def undo_score2text(data): + text = data["text"] + sentences = text.strip().split(".")[:-1] + + keywords = ["aesthetic score", "motion strength"] + num_scores = len(keywords) + num_texts_from_score = 0 + for idx in range(1, num_scores + 1): + s = sentences[-idx] + + for key in keywords: + if key in s: + num_texts_from_score += 1 + break + + new_sentences = sentences[:-num_texts_from_score] if num_texts_from_score > 0 else sentences + new_text = ".".join(new_sentences) + if not new_text.endswith("."): + new_text += "." + return new_text + + +# ====================================================== +# read & write +# ====================================================== + + +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") + + +def save_file(data, output_path): + output_dir = os.path.dirname(output_path) + if not os.path.exists(output_dir) and output_dir != "": + os.makedirs(output_dir) + if output_path.endswith(".csv"): + return data.to_csv(output_path, index=False) + elif output_path.endswith(".parquet"): + return data.to_parquet(output_path, index=False) + else: + raise NotImplementedError(f"Unsupported file format: {output_path}") + + +def read_data(input_paths): + if len(input_paths) == 0: + print(f"No meta file to process. 
Exit.") + exit() + + data = [] + input_name = "" + input_list = [] + for input_path in input_paths: + input_list.extend(glob(input_path)) + cnt = len(input_list) + print(f"==> Total {cnt} input files:") + for x in input_list: + print(x) + + for i, input_path in enumerate(input_list): + if not os.path.exists(input_path): + raise FileNotFoundError + data.append(read_file(input_path)) + basename = os.path.basename(input_path) + input_name += os.path.splitext(basename)[0] + if i != len(input_list) - 1: + input_name += "+" + print(f"==> Loaded meta (shape={data[-1].shape}) from '{input_path}'") + + data = pd.concat(data, ignore_index=True, sort=False) + print(f"==> Merged {cnt} files. shape={data.shape}") + return data, input_name + + +def is_verbose_sentence(s): + if ("is no " in s) or ("are no " in s): + return True + if "does not " in s: + return True + if "solely " in s: + return True + if "only " in s: + return True + if "not visible" in s: + return True + if "no " in s and "visible" in s: + return True + return False + + +def is_verbose_caption(caption): + caption = caption.strip() + sentences = caption.split(".") + if not caption.endswith("."): + sentences = sentences[:-1] + + cnt = 0 + for sentence in sentences: + if is_verbose_sentence(sentence): + cnt += 1 + if cnt >= 2: + return True + return False + + +def refine_sentences(caption): + caption = caption.strip() + sentences = caption.split(".") + if not caption.endswith("."): + sentences = sentences[:-1] + + new_caption = "" + for i, sentence in enumerate(sentences): + if sentence.strip() == "": + continue + if is_verbose_sentence(sentence): + continue + new_caption += f"{sentence}." 
+ return new_caption + + +# ====================================================== +# main +# ====================================================== +# To add a new method, register it in the main, parse_args, and get_output_path functions, and update the doc at /tools/datasets/README.md#documentation + + +def main(args): + # reading data + data, input_name = read_data(args.input) + + # get output path + output_path = get_output_path(args, input_name) + + # path subtract (difference set) + if args.path_subtract is not None: + data_diff = pd.read_csv(args.path_subtract) + print(f"Meta to subtract: shape={data_diff.shape}.") + data = data[~data["path"].isin(data_diff["path"])] + + # path intersect + if args.path_intersect is not None: + data_new = pd.read_csv(args.path_intersect) + print(f"Meta to intersect: shape={data_new.shape}.") + + new_cols = data_new.columns.difference(data.columns) + col_on = "path" + new_cols = new_cols.insert(0, col_on) + data = pd.merge(data, data_new[new_cols], on=col_on, how="inner") + + # preparation + if args.lang is not None: + detect_lang = build_lang_detector(args.lang) + if args.count_num_token == "t5": + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl") + + # IO-related + if args.load_caption is not None: + assert "path" in data.columns + data["text"] = apply(data["path"], load_caption, ext=args.load_caption) + if args.info: + info = apply(data["path"], get_info) + ( + data["num_frames"], + data["height"], + data["width"], + data["aspect_ratio"], + data["fps"], + data["resolution"], + ) = zip(*info) + if args.video_info: + info = apply(data["path"], get_video_info) + ( + data["num_frames"], + data["height"], + data["width"], + data["aspect_ratio"], + data["fps"], + data["resolution"], + ) = zip(*info) + + # filtering path + if args.path_filter_empty: + assert "path" in data.columns + data = data[data["path"].str.len() > 0] + data = data[~data["path"].isna()] + if 
args.path_filter_substr: + data = data[~data["path"].str.contains(args.path_filter_substr)] + if args.path_keep_substr: + data = data[data["path"].str.contains(args.path_keep_substr)] + if args.path_dedup: + assert "path" in data.columns + data = data.drop_duplicates(subset=["path"]) + + # filtering text + if args.text_filter_url: + assert "text" in data.columns + data = data[~data["text"].str.contains(r"(?Phttps?://[^\s]+)", regex=True)] + if args.lang is not None: + assert "text" in data.columns + data = data[data["text"].progress_apply(detect_lang)] # cannot parallelize + if args.text_filter_empty: + assert "text" in data.columns + data = data[data["text"].str.len() > 0] + data = data[~data["text"].isna()] + if args.text_filter_substr: + assert "text" in data.columns + data = data[~data["text"].str.contains(args.text_filter_substr)] + + # processing + if args.relpath is not None: + data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath)) + if args.abspath is not None: + data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x)) + if args.path_to_id: + data["id"] = apply(data["path"], lambda x: os.path.splitext(os.path.basename(x))[0]) + if args.merge_cmotion: + data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1) + if args.text_remove_prefix: + assert "text" in data.columns + data["text"] = apply(data["text"], remove_caption_prefix) + if args.text_append is not None: + assert "text" in data.columns + data["text"] = data["text"] + args.text_append + if args.text_refine_t5: + assert "text" in data.columns + data["text"] = apply( + data["text"], + partial(text_preprocessing, use_text_preprocessing=True), + ) + if args.text_image2video: + assert "text" in data.columns + data["text"] = apply(data["text"], lambda x: x.replace("still image", "video").replace("image", "video")) + if args.count_num_token is not None: + assert "text" in data.columns + data["text_len"] = apply(data["text"], lambda x: 
len(tokenizer(x)["input_ids"])) + if args.update_text is not None: + data_new = pd.read_csv(args.update_text) + num_updated = data.path.isin(data_new.path).sum() + print(f"Number of updated samples: {num_updated}.") + data = data.set_index("path") + data_new = data_new[["path", "text"]].set_index("path") + data.update(data_new) + data = data.reset_index() + if args.text_refine_sentences: + data["text"] = apply(data["text"], refine_sentences) + if args.text_score2text: + data["text"] = apply(data, score2text, axis=1) + if args.text_undo_score2text: + data["text"] = apply(data, undo_score2text, axis=1) + + # sort + if args.sort is not None: + data = data.sort_values(by=args.sort, ascending=False) + if args.sort_ascending is not None: + data = data.sort_values(by=args.sort_ascending, ascending=True) + + # filtering + if args.filesize: + assert "path" in data.columns + data["filesize"] = apply(data["path"], lambda x: os.stat(x).st_size / 1024 / 1024) + if args.fsmax is not None: + assert "filesize" in data.columns + data = data[data["filesize"] <= args.fsmax] + if args.fsmin is not None: + assert "filesize" in data.columns + data = data[data["filesize"] >= args.fsmin] + if args.text_filter_empty: + assert "text" in data.columns + data = data[data["text"].str.len() > 0] + data = data[~data["text"].isna()] + if args.fmin is not None: + assert "num_frames" in data.columns + data = data[data["num_frames"] >= args.fmin] + if args.fmax is not None: + assert "num_frames" in data.columns + data = data[data["num_frames"] <= args.fmax] + if args.filter_dyn_fps is not False: + assert "fps" in data.columns and "num_frames" in data.columns + + def dyn_fps_filter(row, max_fps=args.dyn_fps_max_fps, keep_frames=args.dyn_fps_keep_frames): + # get scale factor + if math.isnan(row["fps"]): # image + return True + scale_factor = math.ceil(row["fps"] / max_fps) + min_frames = keep_frames * scale_factor + return row["num_frames"] >= min_frames + + dyn_fps_mask = data.apply(dyn_fps_filter, 
axis=1) + data = data[dyn_fps_mask] + if args.fpsmax is not None: + assert "fps" in data.columns + data = data[(data["fps"] <= args.fpsmax) | np.isnan(data["fps"])] + if args.hwmax is not None: + if "resolution" not in data.columns: + height = data["height"] + width = data["width"] + data["resolution"] = height * width + data = data[data["resolution"] <= args.hwmax] + if args.aesmin is not None: + assert "aes" in data.columns + data = data[data["aes"] >= args.aesmin] + if args.prefmin is not None: + assert "pred_score" in data.columns + data = data[data["pred_score"] >= args.prefmin] + if args.matchmin is not None: + assert "match" in data.columns + data = data[data["match"] >= args.matchmin] + if args.flowmin is not None: + assert "flow" in data.columns + data = data[data["flow"] >= args.flowmin] + if args.facemin is not None: + assert "face_area_ratio" in data.columns + data = data[data["face_area_ratio"] >= args.facemin] + if args.text_dedup: + data = data.drop_duplicates(subset=["text"], keep="first") + if args.img_only: + data = data[data["path"].str.lower().str.endswith(IMG_EXTENSIONS)] + if args.vid_only: + data = data[~data["path"].str.lower().str.endswith(IMG_EXTENSIONS)] + if args.filter_too_verbose: + data = data[data["text"].apply(is_verbose_caption)] + if args.h_le_w: + data = data[data["height"] <= data["width"]] + if args.filter_human: + data = data[data["text"].apply(has_human)] + + if args.ext: + assert "path" in data.columns + data = data[apply(data["path"], os.path.exists)] + + # process data + if args.shuffle: + data = data.sample(frac=1).reset_index(drop=True) # shuffle + if args.head is not None: + data = data.head(args.head) + if args.sample is not None: + data = data.sample(args.sample).reset_index(drop=True) + + # train columns + if args.train_column: + assert args.pre_train_column is False + all_columns = data.columns + columns_to_drop = all_columns.difference(TRAIN_COLUMNS) + data = data.drop(columns=columns_to_drop) + elif 
args.pre_train_column: + assert args.train_column is False + all_columns = data.columns + columns_to_drop = all_columns.difference(PRE_TRAIN_COLUMNS) + data = data.drop(columns=columns_to_drop) + + if args.chunk is not None: + assert len(args.input) == 1 + input_path = args.input[0] + res = np.array_split(data, args.chunk) + for idx in range(args.chunk): + out_path = f"_chunk-{idx}-{args.chunk}".join(os.path.splitext(input_path)) + shape = res[idx].shape + print(f"==> Saving meta file (shape={shape}) to '{out_path}'") + if args.format == "csv": + res[idx].to_csv(out_path, index=False) + elif args.format == "parquet": + res[idx].to_parquet(out_path, index=False) + else: + raise NotImplementedError + print(f"New meta (shape={shape}) saved to '{out_path}'") + else: + shape = data.shape + print(f"==> Saving meta file (shape={shape}) to '{output_path}'") + save_file(data, output_path) + print(f"==> New meta (shape={shape}) saved to '{output_path}'") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, nargs="+", help="path to the input dataset") + parser.add_argument("--output", type=str, default=None, help="output path") + parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"]) + parser.add_argument("--disable_parallel", action="store_true", help="disable parallel processing") + parser.add_argument("--num-workers", type=int, default=None, help="number of workers") + parser.add_argument("--seed", type=int, default=42, help="random seed") + + # special case + parser.add_argument("--shard", type=int, default=None, help="shard the dataset") + parser.add_argument("--sort", type=str, default=None, help="sort by column") + parser.add_argument("--sort-ascending", type=str, default=None, help="sort by column (ascending order)") + parser.add_argument("--path_subtract", type=str, default=None, help="substract path (difference set)") + parser.add_argument("--path_intersect", 
type=str, default=None, help="intersect path and merge columns") + parser.add_argument("--train_column", action="store_true", help="only keep the train column") + parser.add_argument("--pre_train_column", action="store_true", help="only keep the pre-train column") + + # IO-related + parser.add_argument("--info", action="store_true", help="get the basic information of each video and image") + parser.add_argument("--video-info", action="store_true", help="get the basic information of each video") + parser.add_argument("--ext", action="store_true", help="check if the file exists") + parser.add_argument( + "--load-caption", type=str, default=None, choices=["json", "txt"], help="load the caption from json or txt" + ) + + # path processing + parser.add_argument("--relpath", type=str, default=None, help="modify the path to relative path by root given") + parser.add_argument("--abspath", type=str, default=None, help="modify the path to absolute path by root given") + parser.add_argument("--path-to-id", action="store_true", help="add id based on path") + parser.add_argument( + "--path_filter_empty", action="store_true", help="remove rows with empty path" + ) # caused by transform, cannot read path + parser.add_argument( + "--path_filter_substr", type=str, default=None, help="remove rows whose path contains a substring" + ) + parser.add_argument("--path_keep_substr", type=str, default=None, help="keep rows whose path contains a substring") + parser.add_argument("--path_dedup", action="store_true", help="remove rows with duplicated path") + + # caption filtering + parser.add_argument("--text_filter_empty", action="store_true", help="remove rows with empty caption") + parser.add_argument("--text_filter_url", action="store_true", help="remove rows with url in caption") + parser.add_argument("--text_filter_substr", type=str, default=None, help="remove text with a substring") + parser.add_argument("--text_dedup", action="store_true", help="remove rows with duplicated caption") + 
parser.add_argument("--lang", type=str, default=None, help="remove rows with other language") + parser.add_argument("--filter-too-verbose", action="store_true", help="filter samples with too verbose caption") + parser.add_argument("--filter-human", action="store_true", help="filter samples with human preference score") + + # caption processing + parser.add_argument("--text_remove_prefix", action="store_true", help="remove prefix like 'The video shows '") + parser.add_argument( + "--text_refine_t5", action="store_true", help="refine the caption output by T5 with regular expression" + ) + parser.add_argument("--text_image2video", action="store_true", help="text.replace('image', 'video'") + parser.add_argument("--text_refine_sentences", action="store_true", help="refine every sentence in the caption") + parser.add_argument("--text_score2text", action="store_true", help="convert score to text and append to caption") + parser.add_argument("--text_undo_score2text", action="store_true", help="undo score2text") + parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption") + parser.add_argument( + "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption" + ) + parser.add_argument("--text_append", type=str, default=None, help="append text to the caption") + parser.add_argument("--update-text", type=str, default=None, help="update the text with the given text") + + # filter for dynamic fps + parser.add_argument( + "--filter_dyn_fps", action="store_true", help="filter data to contain enough frames for dynamic fps" + ) + parser.add_argument("--dyn_fps_max_fps", type=int, default=16, help="max fps for dynamic fps") + parser.add_argument("--dyn_fps_keep_frames", type=int, default=32, help="num frames to keep for dynamic fps") + + # score filtering + parser.add_argument("--filesize", action="store_true", help="get the filesize of each video and image in MB") + 
def get_output_path(args, input_name):
    """Build the output file path for the processed meta file.

    When ``args.output`` is set it is returned unchanged. Otherwise the file
    name is ``input_name`` followed by one suffix per enabled option (in the
    fixed order below) plus the ``args.format`` extension, and the file is
    placed in the directory of the first input file.

    Args:
        args: parsed command-line namespace holding the option attributes
            read below.
        input_name (str): base name derived from the input file names.

    Returns:
        str: path of the output meta file.

    Raises:
        AssertionError: if both ``args.sort`` and ``args.sort_ascending`` are set.
    """
    if args.output is not None:
        return args.output
    name = input_name
    dir_path = os.path.dirname(args.input[0])

    if args.path_subtract is not None:
        name += "_subtract"
    if args.path_intersect is not None:
        name += "_intersect"

    # sort (the two sort options are mutually exclusive)
    if args.sort is not None:
        assert args.sort_ascending is None
        name += "_sort"
    if args.sort_ascending is not None:
        assert args.sort is None
        name += "_sort"

    # IO-related
    # for IO-related, the function must be wrapped in try-except
    if args.info:
        name += "_info"
    if args.video_info:
        name += "_vinfo"
    if args.ext:
        name += "_ext"
    if args.load_caption:
        name += f"_load{args.load_caption}"

    # path processing
    if args.relpath is not None:
        name += "_relpath"
    if args.abspath is not None:
        name += "_abspath"
    if args.path_filter_empty:
        name += "_path-filter-empty"
    if args.path_filter_substr is not None:
        name += "_path-filter-substr"
    if args.path_keep_substr is not None:
        name += "_path-keep-substr"
    if args.path_dedup:
        name += "_path-dedup"

    # caption filtering
    if args.text_filter_empty:
        name += "_text-filter-empty"
    if args.text_filter_url:
        name += "_text-filter-url"
    if args.lang is not None:
        name += f"_{args.lang}"
    if args.text_dedup:
        name += "_text-dedup"
    if args.text_filter_substr is not None:
        name += "_text-filter-substr"
    if args.filter_too_verbose:
        name += "_noverbose"
    if args.h_le_w:
        name += "_h-le-w"
    if args.filter_human:
        name += "_human"

    # caption processing
    if args.text_remove_prefix:
        name += "_text-remove-prefix"
    if args.text_refine_t5:
        name += "_text-refine-t5"
    if args.text_image2video:
        name += "_text-image2video"
    if args.text_refine_sentences:
        name += "_text-refine-sentences"
    if args.text_score2text:
        name += "_text-score2text"
    if args.text_undo_score2text:
        name += "_text-undo-score2text"
    if args.merge_cmotion:
        name += "_cmcaption"
    if args.count_num_token:
        name += "_ntoken"
    if args.text_append is not None:
        name += "_text-append"
    if args.update_text is not None:
        name += "_update"

    # dynamic fps
    if args.filter_dyn_fps is not False:
        name += "_dynfps"

    # score filtering
    if args.filesize:
        name += "_filesize"
    if args.fsmax is not None:
        name += f"_fsmax{args.fsmax}"
    if args.fsmin is not None:
        name += f"_fsmin{args.fsmin}"
    if args.fmin is not None:
        name += f"_fmin{args.fmin}"
    if args.fmax is not None:
        name += f"_fmax{args.fmax}"
    if args.fpsmax is not None:
        name += f"_fpsmax{args.fpsmax}"
    if args.hwmax is not None:
        name += f"_hwmax{args.hwmax}"
    if args.aesmin is not None:
        name += f"_aesmin{args.aesmin}"
    if args.prefmin is not None:
        name += f"_prefmin{args.prefmin}"
    if args.matchmin is not None:
        name += f"_matchmin{args.matchmin}"
    if args.flowmin is not None:
        name += f"_flowmin{args.flowmin}"
    if args.facemin is not None:
        # was `args.facewmin`, an attribute that does not exist on the namespace
        name += f"_facemin{args.facemin}"
    if args.img_only:
        name += "_img"
    if args.vid_only:
        name += "_vid"

    # processing
    if args.shuffle:
        name += f"_shuffled_seed{args.seed}"
    if args.head is not None:
        name += f"_first_{args.head}_data"
    if args.sample is not None:
        name += f"_sample-{args.sample}"

    output_path = os.path.join(dir_path, f"{name}.{args.format}")
    return output_path
import tqdm + +tqdm.pandas() + +try: + from pandarallel import pandarallel + + pandarallel.initialize(progress_bar=True) + pandas_has_parallel = True +except ImportError: + pandas_has_parallel = False + + +def apply(df, func, **kwargs): + if pandas_has_parallel: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +def basic_clean(text): + import ftfy + + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +BAD_PUNCT_REGEX = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" +) # noqa + + +def clean_caption(caption): + import urllib.parse as ul + + from bs4 import BeautifulSoup + + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = 
re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . 
" + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + +def get_10m_set(): + meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv" + meta_10m = pd.read_csv(meta_path_10m) + + def process_single_caption(row): + text_list = eval(row["caption"]) + clean_list = [clean_caption(x) for x in text_list] + return str(clean_list) + + ret = apply(meta_10m, process_single_caption, axis=1) + # ret = meta_10m.progress_apply(process_single_caption, axis=1) + print("==> text processed.") + + text_list = [] + for x in ret: + text_list += eval(x) + # text_set = text_set.union(set(eval(x))) + text_set = set(text_list) + # meta_10m['caption_new'] = ret + # 
def filter_panda10m_text(meta_path, text_set):
    """Keep only the rows of a meta CSV whose ``text`` is in ``text_set``.

    The filtered table is written next to the input file as
    ``<name>_filter-10m<ext>``.

    Args:
        meta_path (str): path to a CSV file with a ``text`` column.
        text_set (set): captions to keep.
    """
    meta = pd.read_csv(meta_path)

    # Vectorised membership test: gives the same mask as checking
    # `row["text"] in text_set` row by row, without a Python-level call per
    # row, and it also works on an empty table.
    keep = meta["text"].isin(text_set)

    meta = meta[keep]
    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f"{wo_ext}_filter-10m{ext}"
    meta.to_csv(out_path, index=False)
    print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
def parse_args():
    """Parse the command-line options of the Panda-10M filter script.

    Returns:
        argparse.Namespace: ``meta_path`` (list of str, ``None`` when the
        option is omitted) and ``num_workers`` (int, default 5).
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ("--meta_path", {"type": str, "nargs": "+"}),
        ("--num_workers", {"default": 5, "type": int}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def split_by_bucket(
    bucket: "Bucket",
    input_files: List[str],
    output_path: str,
    limit: int,
    frame_interval: int,
):
    """Select up to ``limit`` rows per bucket from the input CSVs and save them.

    Rows are visited in file order. Each row is assigned a bucket through
    ``bucket.get_bucket_id(num_frames, height, width, frame_interval)``; rows
    without a bucket, or whose bucket already holds ``limit`` rows, are
    skipped. Reading stops as soon as ``len(bucket) * limit`` rows are kept.

    Args:
        bucket: object exposing ``ar_criteria`` (nested mapping whose three
            key levels form the bucket ids), ``get_bucket_id`` and ``__len__``.
        input_files: CSV files with ``num_frames``, ``height`` and ``width`` columns.
        output_path: destination CSV path.
        limit: maximum number of rows kept per bucket.
        frame_interval: forwarded to ``bucket.get_bucket_id``.
    """
    print(f"Split {len(input_files)} files into {len(bucket)} buckets")
    total_limit = len(bucket) * limit

    # one counter per (hw_id, t_id, ar_id) bucket id
    bucket_cnt = {
        (hw_id, t_id, ar_id): 0
        for hw_id, d in bucket.ar_criteria.items()
        for t_id, v in d.items()
        for ar_id in v
    }

    selected = []  # non-empty per-file selections, concatenated once at the end
    header = None  # empty frame carrying the first file's columns
    num_selected = 0
    for path in input_files:
        df = pd.read_csv(path)
        if header is None:
            header = df.iloc[0:0]

        keep = []  # positional indices of the rows kept from this file
        for pos, (t, h, w) in enumerate(zip(df["num_frames"], df["height"], df["width"])):
            bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
            if bucket_id is None:
                continue
            if bucket_cnt[bucket_id] < limit:
                bucket_cnt[bucket_id] += 1
                keep.append(pos)
                num_selected += 1
                if num_selected >= total_limit:
                    break
        if keep:
            # slicing the source frame keeps the original column dtypes
            selected.append(df.iloc[keep])
        if num_selected >= total_limit:
            break

    if selected:
        output_df = pd.concat(selected, ignore_index=True)
    elif header is not None:
        output_df = header
    else:
        # no input files at all: write an empty table instead of crashing
        output_df = pd.DataFrame()

    assert len(output_df) <= total_limit
    if len(output_df) == total_limit:
        print(f"All buckets are full ({total_limit} samples)")
    else:
        print(f"Only {len(output_df)} files are used")
    output_df.to_csv(output_path, index=False)
def get_new_path(path, input_dir, output):
    """Map ``path`` under ``input_dir`` to the same relative location under ``output``.

    The parent directory of the returned path is created when missing.
    """
    path_new = os.path.join(output, os.path.relpath(path, input_dir))
    os.makedirs(os.path.dirname(path_new), exist_ok=True)
    return path_new


def resize_longer(path, length, input_dir, output_dir):
    """Write a copy of the image whose longer side is at most ``length``.

    Images whose longer side already fits are written unchanged.

    Returns:
        str: the output path, or "" when the image cannot be read.
    """
    path_new = get_new_path(path, input_dir, output_dir)
    ext = os.path.splitext(path)[1].lower()
    assert ext in IMG_EXTENSIONS
    img = cv2.imread(path)
    if img is None:
        return ""

    h, w = img.shape[:2]
    # The guard must look at the longer side: checking the shorter side (as
    # resize_shorter does) leaves images with one oversized side untouched.
    if max(h, w) > length:
        if h > w:
            new_h = length
            new_w = max(1, int(w / h * length))  # never request a 0-pixel side
        else:
            new_w = length
            new_h = max(1, int(h / w * length))
        img = cv2.resize(img, (new_w, new_h))
    cv2.imwrite(path_new, img)
    return path_new


def resize_shorter(path, length, input_dir, output_dir):
    """Write a copy of the image whose shorter side is at most ``length``.

    An already existing output file is reused without reading the input.

    Returns:
        str: the output path, or "" when the image cannot be read.
    """
    path_new = get_new_path(path, input_dir, output_dir)
    if os.path.exists(path_new):
        return path_new

    ext = os.path.splitext(path)[1].lower()
    assert ext in IMG_EXTENSIONS
    img = cv2.imread(path)
    if img is None:
        return ""

    h, w = img.shape[:2]
    if min(h, w) > length:
        if h > w:
            new_w = length
            new_h = int(h / w * length)
        else:
            new_h = length
            new_w = int(w / h * length)
        img = cv2.resize(img, (new_w, new_h))
    cv2.imwrite(path_new, img)
    return path_new


def rand_crop(path, input_dir, output):
    """Write one randomly chosen quadrant of the image.

    Returns:
        str: the output path, or "" when the image cannot be read.
    """
    ext = os.path.splitext(path)[1].lower()
    path_new = get_new_path(path, input_dir, output)
    assert ext in IMG_EXTENSIONS
    img = cv2.imread(path)
    if img is None:
        return ""

    # axis 0 indexes rows (height), axis 1 indexes columns (width)
    h, w = img.shape[:2]
    pos = random.randint(0, 3)
    if pos == 0:
        img_cropped = img[: h // 2, : w // 2]
    elif pos == 1:
        img_cropped = img[h // 2 :, : w // 2]
    elif pos == 2:
        img_cropped = img[: h // 2, w // 2 :]
    else:
        img_cropped = img[h // 2 :, w // 2 :]
    cv2.imwrite(path_new, img_cropped)
    return path_new
+def m2ts_to_mp4(row, output_dir): + input_path = row["path"] + output_name = os.path.basename(input_path).replace(".m2ts", ".mp4") + output_path = os.path.join(output_dir, output_name) + # create directory if it doesn't exist + os.makedirs(os.path.dirname(output_path), exist_ok=True) + try: + ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run( + capture_stdout=True + ) + row["path"] = output_path + row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4" + except Exception as e: + print(f"Error converting {input_path} to mp4: {e}") + row["path"] = "" + row["relpath"] = "" + return row + return row + + +def mkv_to_mp4(row, output_dir): + # str_to_replace and str_to_replace_with account for the different directory structure + input_path = row["path"] + output_name = os.path.basename(input_path).replace(".mkv", ".mp4") + output_path = os.path.join(output_dir, output_name) + + # create directory if it doesn't exist + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + try: + ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run( + capture_stdout=True + ) + row["path"] = output_path + row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4" + except Exception as e: + print(f"Error converting {input_path} to mp4: {e}") + row["path"] = "" + row["relpath"] = "" + return row + return row + + +def mp4_to_mp4(row, output_dir): + # str_to_replace and str_to_replace_with account for the different directory structure + input_path = row["path"] + + # 检查输入文件是否为.mp4文件 + if not input_path.lower().endswith(".mp4"): + print(f"Error: {input_path} is not an .mp4 file.") + row["path"] = "" + row["relpath"] = "" + return row + output_name = os.path.basename(input_path) + output_path = os.path.join(output_dir, output_name) + + # create directory if it doesn't exist + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + try: + shutil.copy2(input_path, output_path) # 
使用shutil复制文件 + row["path"] = output_path + row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4" + except Exception as e: + print(f"Error coy {input_path} to mp4: {e}") + row["path"] = "" + row["relpath"] = "" + return row + return row + + +def crop_to_square(input_path, output_path): + cmd = ( + f"ffmpeg -i {input_path} " + f"-vf \"crop='min(in_w,in_h)':'min(in_w,in_h)':'(in_w-min(in_w,in_h))/2':'(in_h-min(in_w,in_h))/2'\" " + f"-c:v libx264 -an " + f"-map 0:v {output_path}" + ) + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) + stdout, stderr = proc.communicate() + + +def vid_crop_center(row, input_dir, output_dir): + input_path = row["path"] + relpath = os.path.relpath(input_path, input_dir) + assert not relpath.startswith("..") + output_path = os.path.join(output_dir, relpath) + + # create directory if it doesn't exist + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + try: + crop_to_square(input_path, output_path) + size = min(row["height"], row["width"]) + row["path"] = output_path + row["height"] = size + row["width"] = size + row["aspect_ratio"] = 1.0 + row["resolution"] = size**2 + except Exception as e: + print(f"Error cropping {input_path} to center: {e}") + row["path"] = "" + return row + + +def main(): + args = parse_args() + global USE_PANDARALLEL + + assert args.num_workers is None or not args.disable_parallel + if args.disable_parallel: + USE_PANDARALLEL = False + if args.num_workers is not None: + pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers) + else: + pandarallel.initialize(progress_bar=True) + + random.seed(args.seed) + data = pd.read_csv(args.meta_path) + if args.task == "img_rand_crop": + data["path"] = apply(data["path"], lambda x: rand_crop(x, args.input_dir, args.output_dir)) + output_csv = args.meta_path.replace(".csv", "_rand_crop.csv") + elif args.task == "img_resize_longer": + data["path"] = apply(data["path"], lambda x: resize_longer(x, args.length, 
args.input_dir, args.output_dir)) + output_csv = args.meta_path.replace(".csv", f"_resize-longer-{args.length}.csv") + elif args.task == "img_resize_shorter": + data["path"] = apply(data["path"], lambda x: resize_shorter(x, args.length, args.input_dir, args.output_dir)) + output_csv = args.meta_path.replace(".csv", f"_resize-shorter-{args.length}.csv") + elif args.task == "vid_frame_extract": + points = args.points if args.points is not None else args.points_index + data = pd.DataFrame(np.repeat(data.values, 3, axis=0), columns=data.columns) + num_points = len(points) + data["point"] = np.nan + for i, point in enumerate(points): + if isinstance(point, int): + data.loc[i::num_points, "point"] = point + else: + data.loc[i::num_points, "point"] = data.loc[i::num_points, "num_frames"] * point + data["path"] = apply( + data, lambda x: extract_frames(x["path"], args.input_dir, args.output_dir, x["point"]), axis=1 + ) + output_csv = args.meta_path.replace(".csv", "_vid_frame_extract.csv") + elif args.task == "m2ts_to_mp4": + print(f"m2ts_to_mp4作业开始:{args.output_dir}") + assert args.meta_path.endswith("_m2ts.csv"), "Input file must end with '_m2ts.csv'" + m2ts_to_mp4_partial = lambda x: m2ts_to_mp4(x, args.output_dir) + data = apply(data, m2ts_to_mp4_partial, axis=1) + data = data[data["path"] != ""] + output_csv = args.meta_path.replace("_m2ts.csv", ".csv") + elif args.task == "mkv_to_mp4": + print(f"mkv_to_mp4作业开始:{args.output_dir}") + assert args.meta_path.endswith("_mkv.csv"), "Input file must end with '_mkv.csv'" + mkv_to_mp4_partial = lambda x: mkv_to_mp4(x, args.output_dir) + data = apply(data, mkv_to_mp4_partial, axis=1) + data = data[data["path"] != ""] + output_csv = args.meta_path.replace("_mkv.csv", ".csv") + elif args.task == "mp4_to_mp4": + # assert args.meta_path.endswith("meta.csv"), "Input file must end with '_mkv.csv'" + print(f"MP4复制作业开始:{args.output_dir}") + mkv_to_mp4_partial = lambda x: mp4_to_mp4(x, args.output_dir) + data = apply(data, 
mkv_to_mp4_partial, axis=1) + data = data[data["path"] != ""] + output_csv = args.meta_path + elif args.task == "vid_crop_center": + vid_crop_center_partial = lambda x: vid_crop_center(x, args.input_dir, args.output_dir) + data = apply(data, vid_crop_center_partial, axis=1) + data = data[data["path"] != ""] + output_csv = args.meta_path.replace(".csv", "_center-crop.csv") + else: + raise ValueError + data.to_csv(output_csv, index=False) + print(f"Saved to {output_csv}") + raise SystemExit(0) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", + type=str, + required=True, + choices=[ + "img_resize_longer", + "img_resize_shorter", + "img_rand_crop", + "vid_frame_extract", + "m2ts_to_mp4", + "mkv_to_mp4", + "mp4_to_mp4", + "vid_crop_center", + ], + ) + parser.add_argument("--meta_path", type=str, required=True) + parser.add_argument("--input_dir", type=str) + parser.add_argument("--output_dir", type=str) + parser.add_argument("--length", type=int, default=1080) + parser.add_argument("--disable-parallel", action="store_true") + parser.add_argument("--num_workers", type=int, default=None) + parser.add_argument("--seed", type=int, default=42, help="seed for random") + parser.add_argument("--points", nargs="+", type=float, default=None) + parser.add_argument("--points_index", nargs="+", type=int, default=None) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + main() diff --git a/tools/datasets/utils.py b/tools/datasets/utils.py new file mode 100644 index 0000000..4fbe3f6 --- /dev/null +++ b/tools/datasets/utils.py @@ -0,0 +1,130 @@ +import os + +import cv2 +import numpy as np +from PIL import Image + +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") + + +def is_video(filename): + ext = os.path.splitext(filename)[-1].lower() + return ext in VID_EXTENSIONS + + +def extract_frames( + video_path, + 
frame_inds=None, + points=None, + backend="opencv", + return_length=False, + num_frames=None, +): + """ + Args: + video_path (str): path to video + frame_inds (List[int]): indices of frames to extract + points (List[float]): values within [0, 1); multiply #frames to get frame indices + Return: + List[PIL.Image] + """ + assert backend in ["av", "opencv", "decord"] + assert (frame_inds is None) or (points is None) + + if backend == "av": + import av + + container = av.open(video_path) + if num_frames is not None: + total_frames = num_frames + else: + total_frames = container.streams.video[0].frames + + if points is not None: + frame_inds = [int(p * total_frames) for p in points] + + frames = [] + for idx in frame_inds: + if idx >= total_frames: + idx = total_frames - 1 + target_timestamp = int(idx * av.time_base / container.streams.video[0].average_rate) + container.seek(target_timestamp) # return the nearest key frame, not the precise timestamp!!! + frame = next(container.decode(video=0)).to_image() + frames.append(frame) + + if return_length: + return frames, total_frames + return frames + + elif backend == "decord": + import decord + + container = decord.VideoReader(video_path, num_threads=1) + if num_frames is not None: + total_frames = num_frames + else: + total_frames = len(container) + + if points is not None: + frame_inds = [int(p * total_frames) for p in points] + + frame_inds = np.array(frame_inds).astype(np.int32) + frame_inds[frame_inds >= total_frames] = total_frames - 1 + frames = container.get_batch(frame_inds).asnumpy() # [N, H, W, C] + frames = [Image.fromarray(x) for x in frames] + + if return_length: + return frames, total_frames + return frames + + elif backend == "opencv": + cap = cv2.VideoCapture(video_path) + if num_frames is not None: + total_frames = num_frames + else: + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if points is not None: + frame_inds = [int(p * total_frames) for p in points] + + frames = [] + for idx in 
frame_inds: + if idx >= total_frames: + idx = total_frames - 1 + + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + + # HACK: sometimes OpenCV fails to read frames, return a black frame instead + try: + ret, frame = cap.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + except Exception as e: + print(f"[Warning] Error reading frame {idx} from {video_path}: {e}") + # First, try to read the first frame + try: + print(f"[Warning] Try reading first frame.") + cap.set(cv2.CAP_PROP_POS_FRAMES, 0) + ret, frame = cap.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + # If that fails, return a black frame + except Exception as e: + print(f"[Warning] Error in reading first frame from {video_path}: {e}") + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame = Image.new("RGB", (width, height), (0, 0, 0)) + + # HACK: if height or width is 0, return a black frame instead + if frame.height == 0 or frame.width == 0: + height = width = 256 + frame = Image.new("RGB", (width, height), (0, 0, 0)) + + frames.append(frame) + + if return_length: + return frames, total_frames + return frames + else: + raise ValueError