feat: add opensora/datasets module and tools/datasets

- Add opensora/datasets (aspect, bucket, dataloader, datasets, parallel,
  pin_memory_cache, read_video, sampler, utils, video_transforms)
- Add tools/datasets pipeline scripts
- Fix .gitignore: scope /datasets to root-level only, whitelist opensora/datasets/

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-03-06 02:00:19 -08:00
parent 916ee2126d
commit bdeb2870d4
24 changed files with 5467 additions and 1 deletions

3
.gitignore vendored
View File

@ -195,4 +195,5 @@ package.json
exps
ckpts
flash-attention
datasets
/datasets
!opensora/datasets/

View File

@ -0,0 +1,2 @@
from .datasets import TextDataset, VideoTextDataset
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample

151
opensora/datasets/aspect.py Normal file
View File

@ -0,0 +1,151 @@
import math
import os
ASPECT_RATIO_LD_LIST = [ # width:height
"2.39:1", # cinemascope, 2.39
"2:1", # rare, 2
"16:9", # rare, 1.89
"1.85:1", # american widescreen, 1.85
"9:16", # popular, 1.78
"5:8", # rare, 1.6
"3:2", # rare, 1.5
"4:3", # classic, 1.33
"1:1", # square
]
def get_ratio(name: str) -> float:
width, height = map(float, name.split(":"))
return height / width
def get_aspect_ratios_dict(
total_pixels: int = 256 * 256, training: bool = True
) -> dict[str, tuple[int, int]]:
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
aspect_ratios_dict = {}
aspect_ratios_vertical_dict = {}
for ratio in ASPECT_RATIO_LD_LIST:
width_ratio, height_ratio = map(float, ratio.split(":"))
width = int(math.sqrt(total_pixels * (width_ratio / height_ratio)) // D) * D
height = int((total_pixels / width) // D) * D
if training:
# adjust aspect ratio to match total pixels
diff = abs(height * width - total_pixels)
candidate = [
(height - D, width),
(height + D, width),
(height, width - D),
(height, width + D),
]
for h, w in candidate:
if abs(h * w - total_pixels) < diff:
height, width = h, w
diff = abs(h * w - total_pixels)
# remove duplicated aspect ratio
if (height, width) not in aspect_ratios_dict.values() or not training:
aspect_ratios_dict[ratio] = (height, width)
vertial_ratios = ":".join(ratio.split(":")[::-1])
aspect_ratios_vertical_dict[vertial_ratios] = (width, height)
aspect_ratios_dict.update(aspect_ratios_vertical_dict)
return aspect_ratios_dict
def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
return {ratio: h * w for ratio, (h, w) in aspect_ratios_dict.items()}
def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
return {ratio: h * w // D // D for ratio, (h, w) in aspect_ratios_dict.items()}
def get_num_pexels_from_name(resolution: str) -> int:
resolution = resolution.split("_")[0]
if resolution.endswith("px"):
size = int(resolution[:-2])
num_pexels = size * size
elif resolution.endswith("p"):
size = int(resolution[:-1])
num_pexels = int(size * size / 9 * 16)
else:
raise ValueError(f"Invalid resolution {resolution}")
return num_pexels
def get_resolution_with_aspect_ratio(
resolution: str,
) -> tuple[int, dict[str, tuple[int, int]]]:
"""Get resolution with aspect ratio
Args:
resolution (str): resolution name. The format is name only or "{name}_{setting}".
name supports "256px" or "360p". setting supports "ar1:1" or "max".
Returns:
tuple[int, dict[str, tuple[int, int]]]: resolution with aspect ratio
"""
keys = resolution.split("_")
if len(keys) == 1:
resolution = keys[0]
setting = ""
else:
resolution, setting = keys
assert setting == "max" or setting.startswith(
"ar"
), f"Invalid setting {setting}"
# get resolution
num_pexels = get_num_pexels_from_name(resolution)
# get aspect ratio
aspect_ratio_dict = get_aspect_ratios_dict(num_pexels)
# handle setting
if setting == "max":
aspect_ratio = max(
aspect_ratio_dict,
key=lambda x: aspect_ratio_dict[x][0] * aspect_ratio_dict[x][1],
)
aspect_ratio_dict = {aspect_ratio: aspect_ratio_dict[aspect_ratio]}
elif setting.startswith("ar"):
aspect_ratio = setting[2:]
assert (
aspect_ratio in aspect_ratio_dict
), f"Aspect ratio {aspect_ratio} not found"
aspect_ratio_dict = {aspect_ratio: aspect_ratio_dict[aspect_ratio]}
return num_pexels, aspect_ratio_dict
def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
aspect_ratio = height / width
closest_ratio = min(
ratios.keys(), key=lambda ratio: abs(aspect_ratio - get_ratio(ratio))
)
return closest_ratio
def get_image_size(
resolution: str, ar_ratio: str, training: bool = True
) -> tuple[int, int]:
num_pexels = get_num_pexels_from_name(resolution)
ar_dict = get_aspect_ratios_dict(num_pexels, training)
assert ar_ratio in ar_dict, f"Aspect ratio {ar_ratio} not found"
return ar_dict[ar_ratio]
def bucket_to_shapes(bucket_config, batch_size=None):
shapes = []
for resolution, infos in bucket_config.items():
for num_frames, (_, bs) in infos.items():
aspect_ratios = get_aspect_ratios_dict(get_num_pexels_from_name(resolution))
for ar, (height, width) in aspect_ratios.items():
if batch_size is not None:
bs = batch_size
shapes.append((bs, 3, num_frames, height, width))
return shapes

139
opensora/datasets/bucket.py Normal file
View File

@ -0,0 +1,139 @@
from collections import OrderedDict
import numpy as np
from opensora.utils.logger import log_message
from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio
from .utils import map_target_fps
class Bucket:
def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]):
"""
Args:
bucket_config (dict): A dictionary containing the bucket configuration.
The dictionary should be in the following format:
{
"bucket_name": {
"time": (probability, batch_size),
"time": (probability, batch_size),
...
},
...
}
Or in the following format:
{
"bucket_name": {
"time": ((probability, next_probability), batch_size),
"time": ((probability, next_probability), batch_size),
...
},
...
}
The bucket_name should be the name of the bucket, and the time should be the number of frames in the video.
The probability should be a float between 0 and 1, and the batch_size should be an integer.
If the probability is a tuple, the second value should be the probability to skip to the next time.
"""
aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()}
bucket_probs = OrderedDict()
bucket_bs = OrderedDict()
bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True)
for key in bucket_names:
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names}
self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names}
self.ar_criteria = {
k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()}
for k1 in bucket_names
}
bucket_id_cnt = num_bucket = 0
bucket_id = dict()
for k1, v1 in bucket_probs.items():
bucket_id[k1] = dict()
for k2, _ in v1.items():
bucket_id[k1][k2] = bucket_id_cnt
bucket_id_cnt += 1
num_bucket += len(aspect_ratios[k1][1])
self.bucket_probs = bucket_probs
self.bucket_bs = bucket_bs
self.bucket_id = bucket_id
self.num_bucket = num_bucket
log_message("Number of buckets: %s", num_bucket)
def get_bucket_id(
self,
T: int,
H: int,
W: int,
fps: float,
path: str | None = None,
seed: int | None = None,
fps_max: int = 16,
) -> tuple[str, int, int] | None:
approx = 0.8
_, sampling_interval = map_target_fps(fps, fps_max)
T = T // sampling_interval
resolution = H * W
rng = np.random.default_rng(seed)
# Reference to probabilities and criteria for faster access
bucket_probs = self.bucket_probs
hw_criteria = self.hw_criteria
ar_criteria = self.ar_criteria
# Start searching for the appropriate bucket
for hw_id, t_criteria in bucket_probs.items():
# if resolution is too low, skip
if resolution < hw_criteria[hw_id] * approx:
continue
# if sample is an image
if T == 1:
if 1 in t_criteria:
if rng.random() < t_criteria[1]:
return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1])
continue
# Look for suitable t_id for video
for t_id, prob in t_criteria.items():
if T >= t_id and t_id != 1:
# if prob is a tuple, use the second value as the threshold to skip
# to the next t_id
if isinstance(prob, tuple):
next_hw_prob, next_t_prob = prob
if next_t_prob >= 1 or rng.random() <= next_t_prob:
continue
else:
next_hw_prob = prob
if next_hw_prob >= 1 or rng.random() <= next_hw_prob:
ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id])
return hw_id, t_id, ar_id
else:
break
return None
def get_thw(self, bucket_idx: tuple[str, int, int]) -> tuple[int, int, int]:
assert len(bucket_idx) == 3
T = self.t_criteria[bucket_idx[0]][bucket_idx[1]]
H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]]
return T, H, W
def get_prob(self, bucket_idx: tuple[str, int]) -> float:
return self.bucket_probs[bucket_idx[0]][bucket_idx[1]]
def get_batch_size(self, bucket_idx: tuple[str, int]) -> int:
return self.bucket_bs[bucket_idx[0]][bucket_idx[1]]
def __len__(self) -> int:
return self.num_bucket

View File

@ -0,0 +1,402 @@
import collections
import functools
import os
import queue
import random
import threading
import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader, _utils
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataloader import (
IterDataPipe,
MapDataPipe,
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
_sharding_worker_init_fn,
_SingleProcessDataLoaderIter,
)
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.registry import DATASETS, build_module
from opensora.utils.config import parse_configs
from opensora.utils.logger import create_logger
from opensora.utils.misc import format_duration
from opensora.utils.train import setup_device
from .datasets import TextDataset, VideoTextDataset
from .pin_memory_cache import PinMemoryCache
from .sampler import DistributedSampler, VariableVideoBatchSampler
def _pin_memory_loop(
in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
):
# This setting is thread local, and prevents the copy in pin_memory from
# consuming all CPU cores.
torch.set_num_threads(1)
if device == "cuda":
torch.cuda.set_device(device_id)
elif device == "xpu":
torch.xpu.set_device(device_id) # type: ignore[attr-defined]
elif device == torch._C._get_privateuse1_backend_name():
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
custom_device_mod.set_device(device_id)
def do_one_step():
try:
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
return
idx, data = r
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
try:
assert isinstance(data, dict)
if pin_memory_key in data:
val = data[pin_memory_key]
pin_memory_value = pin_memory_cache.get(val)
pin_memory_value.copy_(val)
data[pin_memory_key] = pin_memory_value
except Exception:
data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
r = (idx, data)
while not done_event.is_set():
try:
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
break
except queue.Full:
continue
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
while not done_event.is_set():
# Make sure that we don't preserve any object from one iteration
# to the next
do_one_step()
class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
pin_memory_key: str = "video"
def __init__(self, loader):
_BaseDataLoaderIter.__init__(self, loader)
self.pin_memory_cache = PinMemoryCache()
self._prefetch_factor = loader.prefetch_factor
assert self._num_workers > 0
assert self._prefetch_factor > 0
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context
self._worker_init_fn = loader.worker_init_fn
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
# Additional worker init function will take care of sharding in MP and Distributed
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
self._worker_init_fn = functools.partial(
_sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
)
# No certainty which module multiprocessing_context is
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
self._worker_pids_set = False
self._shutdown = False
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = []
self._workers = []
for i in range(self._num_workers):
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
# Need to `cancel_join_thread` here!
# See sections (2) and (3b) above.
index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(
self._dataset_kind,
self._dataset,
index_queue,
self._worker_result_queue,
self._workers_done_event,
self._auto_collation,
self._collate_fn,
self._drop_last,
self._base_seed,
self._worker_init_fn,
i,
self._num_workers,
self._persistent_workers,
self._shared_seed,
),
)
w.daemon = True
# NB: Process.start() actually take some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self._workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
# Queue is not type-annotated
self._data_queue = queue.Queue() # type: ignore[var-annotated]
if self._pin_memory_device == "xpu":
current_device = torch.xpu.current_device() # type: ignore[attr-defined]
elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
current_device = custom_device_mod.current_device()
else:
current_device = torch.cuda.current_device() # choose cuda for default
pin_memory_thread = threading.Thread(
target=_pin_memory_loop,
args=(
self._worker_result_queue,
self._data_queue,
current_device,
self._pin_memory_thread_done_event,
self._pin_memory_device,
self.pin_memory_cache,
self.pin_memory_key,
),
)
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue # type: ignore[assignment]
# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked
# when main process exits
# It would cause failure when pin_memory_thread tries to read
# corrupted data from worker_result_queue
# atexit is used to shutdown thread and child processes in the
# right sequence before main process exits
if self._persistent_workers and self._pin_memory:
import atexit
for w in self._workers:
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
# .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
self._reset(loader, first_iter=True)
def remove_cache(self, output_tensor: torch.Tensor):
self.pin_memory_cache.remove(output_tensor)
def get_cache_info(self) -> str:
return str(self.pin_memory_cache)
class DataloaderForVideo(DataLoader):
def _get_iterator(self) -> "_BaseDataLoaderIter":
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIterForVideo(self)
# Deterministic dataloader
def get_seed_worker(seed):
def seed_worker(worker_id):
worker_seed = seed
if seed is not None:
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return seed_worker
def prepare_dataloader(
dataset,
batch_size=None,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group: ProcessGroup | None = None,
bucket_config=None,
num_bucket_build_workers=1,
prefetch_factor=None,
cache_pin_memory=False,
num_groups=1,
**kwargs,
):
_kwargs = kwargs.copy()
if isinstance(dataset, VideoTextDataset):
batch_sampler = VariableVideoBatchSampler(
dataset,
bucket_config,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
verbose=True,
num_bucket_build_workers=num_bucket_build_workers,
num_groups=num_groups,
)
dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
return (
dl_cls(
dataset,
batch_sampler=batch_sampler,
worker_init_fn=get_seed_worker(seed),
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
prefetch_factor=prefetch_factor,
**_kwargs,
),
batch_sampler,
)
elif isinstance(dataset, TextDataset):
if process_group is None:
return (
DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
worker_init_fn=get_seed_worker(seed),
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
**_kwargs,
),
None,
)
else:
sampler = DistributedSampler(
dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
return (
DataLoader(
dataset,
sampler=sampler,
worker_init_fn=get_seed_worker(seed),
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
prefetch_factor=prefetch_factor,
**_kwargs,
),
sampler,
)
else:
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
def collate_fn_default(batch):
# filter out None
batch = [x for x in batch if x is not None]
assert len(batch) > 0, "batch is empty"
# HACK: for loading text features
use_mask = False
if "mask" in batch[0] and isinstance(batch[0]["mask"], int):
masks = [x.pop("mask") for x in batch]
texts = [x.pop("text") for x in batch]
texts = torch.cat(texts, dim=1)
use_mask = True
ret = torch.utils.data.default_collate(batch)
if use_mask:
ret["mask"] = masks
ret["text"] = texts
return ret
def collate_fn_batch(batch):
"""
Used only with BatchDistributedSampler
"""
# filter out None
batch = [x for x in batch if x is not None]
res = torch.utils.data.default_collate(batch)
# squeeze the first dimension, which is due to torch.stack() in default_collate()
if isinstance(res, collections.abc.Mapping):
for k, v in res.items():
if isinstance(v, torch.Tensor):
res[k] = v.squeeze(0)
elif isinstance(res, collections.abc.Sequence):
res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
elif isinstance(res, torch.Tensor):
res = res.squeeze(0)
else:
raise TypeError
return res
if __name__ == "__main__":
# NUM_GPU: number of GPUs for training
# TIME_PER_STEP: time per step in seconds
# Example usage:
# torchrun --nproc_per_node 1 -m opensora.datasets.dataloader configs/diffusion/train/video_cond.py
cfg = parse_configs()
setup_device()
logger = create_logger()
# == build dataset ==
dataset = build_module(cfg.dataset, DATASETS)
# == build dataloader ==
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.get("batch_size", None),
num_workers=cfg.get("num_workers", 4),
seed=cfg.get("seed", 1024),
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
prefetch_factor=cfg.get("prefetch_factor", None),
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
**dataloader_args,
)
num_steps_per_epoch = len(dataloader)
num_machines = int(os.environ.get("NUM_MACHINES", 28))
num_gpu = num_machines * 8
logger.info("Number of GPUs: %d", num_gpu)
logger.info("Number of steps per epoch: %d", num_steps_per_epoch // num_gpu)
time_per_step = int(os.environ.get("TIME_PER_STEP", 20))
time_training = num_steps_per_epoch // num_gpu * time_per_step
logger.info("Time per step: %s", format_duration(time_per_step))
logger.info("Time for training: %s", format_duration(time_training))

View File

@ -0,0 +1,315 @@
import os
import random
import numpy as np
import pandas as pd
import torch
from PIL import ImageFile
from torchvision.datasets.folder import pil_loader
from opensora.registry import DATASETS
from .read_video import read_video
from .utils import get_transforms_image, get_transforms_video, is_img, map_target_fps, read_file, temporal_random_crop
ImageFile.LOAD_TRUNCATED_IMAGES = True
VALID_KEYS = ("neg", "path")
K = 10000
class Iloc:
def __init__(self, data, sharded_folder, sharded_folders, rows_per_shard):
self.data = data
self.sharded_folder = sharded_folder
self.sharded_folders = sharded_folders
self.rows_per_shard = rows_per_shard
def __getitem__(self, index):
return Item(
index,
self.data,
self.sharded_folder,
self.sharded_folders,
self.rows_per_shard,
)
class Item:
def __init__(self, index, data, sharded_folder, sharded_folders, rows_per_shard):
self.index = index
self.data = data
self.sharded_folder = sharded_folder
self.sharded_folders = sharded_folders
self.rows_per_shard = rows_per_shard
def __getitem__(self, key):
index = self.index
if key in self.data.columns:
return self.data[key].iloc[index]
else:
shard_idx = index // self.rows_per_shard
idx = index % self.rows_per_shard
shard_parquet = os.path.join(self.sharded_folder, self.sharded_folders[shard_idx])
try:
text_parquet = pd.read_parquet(shard_parquet, engine="fastparquet")
path = text_parquet["path"].iloc[idx]
assert path == self.data["path"].iloc[index]
except Exception as e:
print(f"Error reading {shard_parquet}: {e}")
raise
return text_parquet[key].iloc[idx]
def to_dict(self):
index = self.index
ret = {}
ret.update(self.data.iloc[index].to_dict())
shard_idx = index // self.rows_per_shard
idx = index % self.rows_per_shard
shard_parquet = os.path.join(self.sharded_folder, self.sharded_folders[shard_idx])
try:
text_parquet = pd.read_parquet(shard_parquet, engine="fastparquet")
path = text_parquet["path"].iloc[idx]
assert path == self.data["path"].iloc[index]
ret.update(text_parquet.iloc[idx].to_dict())
except Exception as e:
print(f"Error reading {shard_parquet}: {e}")
ret.update({"text": ""})
return ret
class EfficientParquet:
def __init__(self, df, sharded_folder):
self.data = df
self.total_rows = len(df)
self.rows_per_shard = (self.total_rows + K - 1) // K
self.sharded_folder = sharded_folder
assert os.path.exists(sharded_folder), f"Sharded folder {sharded_folder} does not exist."
self.sharded_folders = os.listdir(sharded_folder)
self.sharded_folders = sorted(self.sharded_folders)
def __len__(self):
return self.total_rows
@property
def iloc(self):
return Iloc(self.data, self.sharded_folder, self.sharded_folders, self.rows_per_shard)
@DATASETS.register_module("text")
class TextDataset(torch.utils.data.Dataset):
"""
Dataset for text data
"""
def __init__(
self,
data_path: str = None,
tokenize_fn: callable = None,
fps_max: int = 16,
vmaf: bool = False,
memory_efficient: bool = False,
**kwargs,
):
self.data_path = data_path
self.data = read_file(data_path, memory_efficient=memory_efficient)
self.memory_efficient = memory_efficient
self.tokenize_fn = tokenize_fn
self.vmaf = vmaf
if fps_max is not None:
self.fps_max = fps_max
else:
self.fps_max = 999999999
def to_efficient(self):
if self.memory_efficient:
addition_data_path = self.data_path.split(".")[0]
self._data = self.data
self.data = EfficientParquet(self._data, addition_data_path)
def getitem(self, index: int) -> dict:
ret = dict()
sample = self.data.iloc[index].to_dict()
sample_fps = sample.get("fps", np.nan)
new_fps, sampling_interval = map_target_fps(sample_fps, self.fps_max)
ret.update({"sampling_interval": sampling_interval})
if "text" in sample:
ret["text"] = sample.pop("text")
postfixs = []
if new_fps != 0 and self.fps_max < 999:
postfixs.append(f"{new_fps} FPS")
if self.vmaf and "score_vmafmotion" in sample and not np.isnan(sample["score_vmafmotion"]):
postfixs.append(f"{int(sample['score_vmafmotion'] + 0.5)} motion score")
postfix = " " + ", ".join(postfixs) + "." if postfixs else ""
ret["text"] = ret["text"] + postfix
if self.tokenize_fn is not None:
ret.update({k: v.squeeze(0) for k, v in self.tokenize_fn(ret["text"]).items()})
if "ref" in sample: # i2v & v2v reference
ret["ref"] = sample.pop("ref")
# name of the generated sample
if "name" in sample: # sample name (`dataset_idx`)
ret["name"] = sample.pop("name")
else:
ret["index"] = index # use index for name
valid_sample = {k: v for k, v in sample.items() if k in VALID_KEYS}
ret.update(valid_sample)
return ret
def __getitem__(self, index):
return self.getitem(index)
def __len__(self):
return len(self.data)
@DATASETS.register_module("video_text")
class VideoTextDataset(TextDataset):
def __init__(
self,
transform_name: str = None,
bucket_class: str = "Bucket",
rand_sample_interval: int = None, # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
**kwargs,
):
super().__init__(**kwargs)
self.transform_name = transform_name
self.bucket_class = bucket_class
self.rand_sample_interval = rand_sample_interval
def get_image(self, index: int, height: int, width: int) -> dict:
sample = self.data.iloc[index]
path = sample["path"]
# loading
image = pil_loader(path)
# transform
transform = get_transforms_image(self.transform_name, (height, width))
image = transform(image)
# CHW -> CTHW
video = image.unsqueeze(1)
return {"video": video}
def get_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
sample = self.data.iloc[index]
path = sample["path"]
# loading
vframes, vinfo = read_video(path, backend="av")
if self.rand_sample_interval is not None:
# randomly sample from 1 - self.rand_sample_interval
video_allowed_max = min(len(vframes) // num_frames, self.rand_sample_interval)
sampling_interval = random.randint(1, video_allowed_max)
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, sampling_interval)
video = video.clone()
del vframes
# transform
transform = get_transforms_video(self.transform_name, (height, width))
video = transform(video) # T C H W
video = video.permute(1, 0, 2, 3)
ret = {"video": video}
return ret
def get_image_or_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
sample = self.data.iloc[index]
path = sample["path"]
if is_img(path):
return self.get_image(index, height, width)
return self.get_video(index, num_frames, height, width, sampling_interval)
def getitem(self, index: str) -> dict:
# a hack to pass in the (time, height, width) info from sampler
index, num_frames, height, width = [int(val) for val in index.split("-")]
ret = dict()
ret.update(super().getitem(index))
try:
ret.update(self.get_image_or_video(index, num_frames, height, width, ret["sampling_interval"]))
except Exception as e:
path = self.data.iloc[index]["path"]
print(f"video {path}: {e}")
return None
return ret
def __getitem__(self, index):
return self.getitem(index)
@DATASETS.register_module("cached_video_text")
class CachedVideoTextDataset(VideoTextDataset):
def __init__(
self,
transform_name: str = None,
bucket_class: str = "Bucket",
rand_sample_interval: int = None, # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
cached_video: bool = False,
cached_text: bool = False,
return_latents_path: bool = False,
load_original_video: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.transform_name = transform_name
self.bucket_class = bucket_class
self.rand_sample_interval = rand_sample_interval
self.cached_video = cached_video
self.cached_text = cached_text
self.return_latents_path = return_latents_path
self.load_original_video = load_original_video
def get_latents(self, path):
try:
latents = torch.load(path, map_location=torch.device("cpu"))
except Exception as e:
print(f"Error loading latents from {path}: {e}")
return torch.zeros_like(torch.randn(1, 1, 1, 1))
return latents
def get_conditioning_latents(self, index: int) -> dict:
sample = self.data.iloc[index]
latents_path = sample["latents_path"]
text_t5_path = sample["text_t5_path"]
text_clip_path = sample["text_clip_path"]
res = dict()
if self.cached_video:
latents = self.get_latents(latents_path)
res["video_latents"] = latents
if self.cached_text:
text_t5 = self.get_latents(text_t5_path)
text_clip = self.get_latents(text_clip_path)
res["text_t5"] = text_t5
res["text_clip"] = text_clip
if self.return_latents_path:
res["latents_path"] = latents_path
res["text_t5_path"] = text_t5_path
res["text_clip_path"] = text_clip_path
return res
def getitem(self, index: str) -> dict:
# a hack to pass in the (time, height, width) info from sampler
real_index, num_frames, height, width = [int(val) for val in index.split("-")]
ret = dict()
if self.load_original_video:
ret.update(super().getitem(index))
try:
ret.update(self.get_conditioning_latents(real_index))
except Exception as e:
path = self.data.iloc[real_index]["path"]
print(f"video {path}: {e}")
return None
return ret
def __getitem__(self, index):
return self.getitem(index)

View File

@ -0,0 +1,176 @@
import multiprocessing
from itertools import count
from multiprocessing.managers import SyncManager
from typing import Any, Callable, Dict, Tuple, Type, cast
import dill
import pandarallel
import pandas as pd
from pandarallel.data_types import DataType
from pandarallel.progress_bars import ProgressBarsType, get_progress_bars, progress_wrapper
from pandarallel.utils import WorkerStatus
CONTEXT = multiprocessing.get_context("fork")
TMP = []
class WrapWorkFunctionForPipe:
def __init__(
self,
work_function: Callable[
[
Any,
Callable,
tuple,
Dict[str, Any],
Dict[str, Any],
],
Any,
],
) -> None:
self.work_function = work_function
def __call__(
self,
progress_bars_type: ProgressBarsType,
worker_index: int,
master_workers_queue: multiprocessing.Queue,
dilled_user_defined_function: bytes,
user_defined_function_args: tuple,
user_defined_function_kwargs: Dict[str, Any],
extra: Dict[str, Any],
) -> Any:
try:
data = TMP[worker_index]
data_size = len(data)
user_defined_function: Callable = dill.loads(dilled_user_defined_function)
progress_wrapped_user_defined_function = progress_wrapper(
user_defined_function, master_workers_queue, worker_index, data_size
)
used_user_defined_function = (
progress_wrapped_user_defined_function
if progress_bars_type
in (
ProgressBarsType.InUserDefinedFunction,
ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns,
)
else user_defined_function
)
results = self.work_function(
data,
used_user_defined_function,
user_defined_function_args,
user_defined_function_kwargs,
extra,
)
master_workers_queue.put((worker_index, WorkerStatus.Success, None))
return results
except:
master_workers_queue.put((worker_index, WorkerStatus.Error, None))
raise
def parallelize_with_pipe(
nb_requested_workers: int,
data_type: Type[DataType],
progress_bars_type: ProgressBarsType,
):
def closure(
data: Any,
user_defined_function: Callable,
*user_defined_function_args: tuple,
**user_defined_function_kwargs: Dict[str, Any],
):
wrapped_work_function = WrapWorkFunctionForPipe(data_type.work)
dilled_user_defined_function = dill.dumps(user_defined_function)
manager: SyncManager = CONTEXT.Manager()
master_workers_queue = manager.Queue()
chunks = list(
data_type.get_chunks(
nb_requested_workers,
data,
user_defined_function_kwargs=user_defined_function_kwargs,
)
)
TMP.extend(chunks)
nb_workers = len(chunks)
multiplicator_factor = (
len(cast(pd.DataFrame, data).columns)
if progress_bars_type == ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns
else 1
)
progresses_length = [len(chunk_) * multiplicator_factor for chunk_ in chunks]
work_extra = data_type.get_work_extra(data)
reduce_extra = data_type.get_reduce_extra(data, user_defined_function_kwargs)
show_progress_bars = progress_bars_type != ProgressBarsType.No
progress_bars = get_progress_bars(progresses_length, show_progress_bars)
progresses = [0] * nb_workers
workers_status = [WorkerStatus.Running] * nb_workers
work_args_list = [
(
progress_bars_type,
worker_index,
master_workers_queue,
dilled_user_defined_function,
user_defined_function_args,
user_defined_function_kwargs,
{
**work_extra,
**{
"master_workers_queue": master_workers_queue,
"show_progress_bars": show_progress_bars,
"worker_index": worker_index,
},
},
)
for worker_index in range(nb_workers)
]
pool = CONTEXT.Pool(nb_workers)
results_promise = pool.starmap_async(wrapped_work_function, work_args_list)
pool.close()
generation = count()
while any((worker_status == WorkerStatus.Running for worker_status in workers_status)):
message: Tuple[int, WorkerStatus, Any] = master_workers_queue.get()
worker_index, worker_status, payload = message
workers_status[worker_index] = worker_status
if worker_status == WorkerStatus.Success:
progresses[worker_index] = progresses_length[worker_index]
progress_bars.update(progresses)
elif worker_status == WorkerStatus.Running:
progress = cast(int, payload)
progresses[worker_index] = progress
if next(generation) % nb_workers == 0:
progress_bars.update(progresses)
elif worker_status == WorkerStatus.Error:
progress_bars.set_error(worker_index)
results = results_promise.get()
TMP.clear()
return data_type.reduce(results, reduce_extra)
return closure
pandarallel.core.WrapWorkFunctionForPipe = WrapWorkFunctionForPipe
pandarallel.core.parallelize_with_pipe = parallelize_with_pipe
pandarallel = pandarallel.pandarallel

View File

@ -0,0 +1,76 @@
import threading
from typing import Dict, List, Optional
import torch
class PinMemoryCache:
force_dtype: Optional[torch.dtype] = None
min_cache_numel: int = 0
pre_alloc_numels: List[int] = []
def __init__(self):
self.cache: Dict[int, torch.Tensor] = {}
self.output_to_cache: Dict[int, int] = {}
self.cache_to_output: Dict[int, int] = {}
self.lock = threading.Lock()
self.total_cnt = 0
self.hit_cnt = 0
if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None:
for n in self.pre_alloc_numels:
cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
with self.lock:
self.cache[id(cache_tensor)] = cache_tensor
def get(self, tensor: torch.Tensor) -> torch.Tensor:
"""Receive a cpu tensor and return the corresponding pinned tensor. Note that this only manage memory allocation, doesn't copy content.
Args:
tensor (torch.Tensor): The tensor to be pinned.
Returns:
torch.Tensor: The pinned tensor.
"""
self.total_cnt += 1
with self.lock:
# find free cache
for cache_id, cache_tensor in self.cache.items():
if cache_id not in self.cache_to_output and cache_tensor.numel() >= tensor.numel():
target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
out_id = id(target_cache_tensor)
self.output_to_cache[out_id] = cache_id
self.cache_to_output[cache_id] = out_id
self.hit_cnt += 1
return target_cache_tensor
# no free cache, create a new one
dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype
cache_numel = max(tensor.numel(), self.min_cache_numel)
cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True)
target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
out_id = id(target_cache_tensor)
with self.lock:
self.cache[id(cache_tensor)] = cache_tensor
self.output_to_cache[out_id] = id(cache_tensor)
self.cache_to_output[id(cache_tensor)] = out_id
return target_cache_tensor
def remove(self, output_tensor: torch.Tensor) -> None:
"""Release corresponding cache tensor.
Args:
output_tensor (torch.Tensor): The tensor to be released.
"""
out_id = id(output_tensor)
with self.lock:
if out_id not in self.output_to_cache:
raise ValueError("Tensor not found in cache.")
cache_id = self.output_to_cache.pop(out_id)
del self.cache_to_output[cache_id]
def __str__(self):
with self.lock:
num_cached = len(self.cache)
num_used = len(self.output_to_cache)
total_cache_size = sum([v.numel() * v.element_size() for v in self.cache.values()])
return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={self.hit_cnt / self.total_cnt:.2f})"

View File

@ -0,0 +1,257 @@
import gc
import math
import os
import re
import warnings
from fractions import Fraction
import av
import cv2
import numpy as np
import torch
from torchvision import get_video_backend
from torchvision.io.video import _check_av_available
MAX_NUM_FRAMES = 2500
def read_video_av(
filename: str,
start_pts: float | Fraction = 0,
end_pts: float | Fraction | None = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> tuple[torch.Tensor, torch.Tensor, dict]:
"""
Reads a video from a file, returning both the video frames and the audio frames
This method is modified from torchvision.io.video.read_video, with the following changes:
1. will not extract audio frames and return empty for aframes
2. remove checks and only support pyav
3. add container.close() and gc.collect() to avoid thread leakage
4. try our best to avoid memory leak
Args:
filename (str): path to the video file
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The start presentation time of the video
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
Returns:
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
# format
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
# file existence
if not os.path.exists(filename):
raise RuntimeError(f"File not found: {filename}")
# backend check
assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
_check_av_available()
# end_pts check
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
# == get video info ==
info = {}
# TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
# fps
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
iter_video = container.decode(**{"video": 0})
frame = next(iter_video).to_rgb().to_ndarray()
height, width = frame.shape[:2]
total_frames = container.streams.video[0].frames
if total_frames == 0:
total_frames = MAX_NUM_FRAMES
warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
container.close()
del container
# HACK: must create before iterating stream
# use np.zeros will not actually allocate memory
# use np.ones will lead to a little memory leak
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
# == read ==
try:
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
assert container.streams.video is not None
video_frames = _read_from_stream(
video_frames,
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
filename=filename,
)
except av.AVError as e:
print(f"[Warning] Error while reading video {filename}: {e}")
vframes = torch.from_numpy(video_frames).clone()
del video_frames
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
aframes = torch.empty((1, 0), dtype=torch.float32)
return vframes, aframes, info
def _read_from_stream(
video_frames,
container: "av.container.Container",
start_offset: float,
end_offset: float,
pts_unit: str,
stream: "av.stream.Stream",
stream_name: dict[str, int | tuple[int, ...] | list[int] | None],
filename: str | None = None,
) -> list["av.frame.Frame"]:
if pts_unit == "sec":
# TODO: we should change all of this from ground up to simply take
# sec and convert to MS in C++
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
should_buffer = True
max_buffer_size = 5
if stream.type == "video":
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
# so need to buffer some extra frames to sort everything
# properly
extradata = stream.codec_context.extradata
# overly complicated way of finding if `divx_packed` is set, following
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
if extradata and b"DivX" in extradata:
# can't use regex directly because of some weird characters sometimes...
pos = extradata.find(b"DivX")
d = extradata[pos:]
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
if o is None:
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
if o is not None:
should_buffer = o.group(3) == b"p"
seek_offset = start_offset
# some files don't seek to the right location, so better be safe here
seek_offset = max(seek_offset - 1, 0)
if should_buffer:
# FIXME this is kind of a hack, but we will jump to the previous keyframe
# so this will be safe
seek_offset = max(seek_offset - max_buffer_size, 0)
try:
# TODO check if stream needs to always be the video stream here or not
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except av.AVError as e:
print(f"[Warning] Error while seeking video {filename}: {e}")
return []
# == main ==
buffer_count = 0
frames_pts = []
cnt = 0
try:
for _idx, frame in enumerate(container.decode(**stream_name)):
frames_pts.append(frame.pts)
video_frames[cnt] = frame.to_rgb().to_ndarray()
cnt += 1
if cnt >= len(video_frames):
break
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
except av.AVError as e:
print(f"[Warning] Error while reading video {filename}: {e}")
# garbage collection for thread leakage
container.close()
del container
# NOTE: manually garbage collect to close pyav threads
gc.collect()
# ensure that the results are sorted wrt the pts
# NOTE: here we assert frames_pts is sorted
start_ptr = 0
end_ptr = cnt
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
start_ptr += 1
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
end_ptr -= 1
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
# if there is no frame that exactly matches the pts of start_offset
# add the last frame smaller than start_offset, to guarantee that
# we will have all the necessary data. This is most useful for audio
if start_ptr > 0:
start_ptr -= 1
result = video_frames[start_ptr:end_ptr].copy()
return result
def read_video_cv2(video_path):
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
# print("Error: Unable to open video")
raise ValueError
else:
fps = cap.get(cv2.CAP_PROP_FPS)
vinfo = {
"video_fps": fps,
}
frames = []
while True:
# Read a frame from the video
ret, frame = cap.read()
# If frame is not read correctly, break the loop
if not ret:
break
frames.append(frame[:, :, ::-1]) # BGR to RGB
# Exit if 'q' is pressed
if cv2.waitKey(25) & 0xFF == ord("q"):
break
# Release the video capture object and close all windows
cap.release()
cv2.destroyAllWindows()
frames = np.stack(frames)
frames = torch.from_numpy(frames) # [T, H, W, C=3]
frames = frames.permute(0, 3, 1, 2)
return frames, vinfo
def read_video(video_path, backend="av"):
if backend == "cv2":
vframes, vinfo = read_video_cv2(video_path)
elif backend == "av":
vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
else:
raise ValueError
return vframes, vinfo

View File

@ -0,0 +1,393 @@
from collections import OrderedDict, defaultdict
from typing import Iterator
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler
from opensora.utils.logger import log_message
from opensora.utils.misc import format_numel_str
from .aspect import get_num_pexels_from_name
from .bucket import Bucket
from .datasets import VideoTextDataset
from .parallel import pandarallel
from .utils import sync_object_across_devices
# use pandarallel to accelerate bucket processing
# NOTE: pandarallel should only access local variables
def apply(data, method=None, seed=None, num_bucket=None, fps_max=16):
return method(
data["num_frames"],
data["height"],
data["width"],
data["fps"],
data["path"],
seed + data["id"] * num_bucket,
fps_max,
)
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def reset(self) -> None:
self.start_index = 0
def state_dict(self, step) -> dict:
return {"start_index": step}
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)
class VariableVideoBatchSampler(DistributedSampler):
def __init__(
self,
dataset: VideoTextDataset,
bucket_config: dict,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
verbose: bool = False,
num_bucket_build_workers: int = 1,
num_groups: int = 1,
) -> None:
super().__init__(
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
)
self.dataset = dataset
assert dataset.bucket_class == "Bucket", "Only support Bucket class for now"
self.bucket = Bucket(bucket_config)
self.verbose = verbose
self.last_micro_batch_access_index = 0
self.num_bucket_build_workers = num_bucket_build_workers
self._cached_bucket_sample_dict = None
self._cached_num_total_batch = None
self.num_groups = num_groups
if dist.get_rank() == 0:
pandarallel.initialize(
nb_workers=self.num_bucket_build_workers,
progress_bar=False,
verbose=0,
use_memory_fs=False,
)
def __iter__(self) -> Iterator[list[int]]:
bucket_sample_dict, _ = self.group_by_bucket()
self.clear_cache()
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
bucket_micro_batch_count = OrderedDict()
bucket_last_consumed = OrderedDict()
# process the samples
for bucket_id, data_list in bucket_sample_dict.items():
# handle droplast
bs_per_gpu = self.bucket.get_batch_size(bucket_id)
remainder = len(data_list) % bs_per_gpu
if remainder > 0:
if not self.drop_last:
# if there is remainder, we pad to make it divisible
data_list += data_list[: bs_per_gpu - remainder]
else:
# we just drop the remainder to make it divisible
data_list = data_list[:-remainder]
bucket_sample_dict[bucket_id] = data_list
# handle shuffle
if self.shuffle:
data_indices = torch.randperm(len(data_list), generator=g).tolist()
data_list = [data_list[i] for i in data_indices]
bucket_sample_dict[bucket_id] = data_list
# compute how many micro-batches each bucket has
num_micro_batches = len(data_list) // bs_per_gpu
bucket_micro_batch_count[bucket_id] = num_micro_batches
# compute the bucket access order
# each bucket may have more than one batch of data
# thus bucket_id may appear more than 1 time
bucket_id_access_order = []
for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
bucket_id_access_order.extend([bucket_id] * num_micro_batch)
# randomize the access order
if self.shuffle:
bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices]
# make the number of bucket accesses divisible by dp size
remainder = len(bucket_id_access_order) % self.num_replicas
if remainder > 0:
if self.drop_last:
bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
else:
bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]
# prepare each batch from its bucket
# according to the predefined bucket access order
num_iters = len(bucket_id_access_order) // self.num_replicas
start_iter_idx = self.last_micro_batch_access_index // self.num_replicas
# re-compute the micro-batch consumption
# this is useful when resuming from a state dict with a different number of GPUs
self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
for i in range(self.last_micro_batch_access_index):
bucket_id = bucket_id_access_order[i]
bucket_bs = self.bucket.get_batch_size(bucket_id)
if bucket_id in bucket_last_consumed:
bucket_last_consumed[bucket_id] += bucket_bs
else:
bucket_last_consumed[bucket_id] = bucket_bs
for i in range(start_iter_idx, num_iters):
bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
self.last_micro_batch_access_index += self.num_replicas
# compute the data samples consumed by each access
bucket_access_boundaries = []
for bucket_id in bucket_access_list:
bucket_bs = self.bucket.get_batch_size(bucket_id)
last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs])
# update consumption
if bucket_id in bucket_last_consumed:
bucket_last_consumed[bucket_id] += bucket_bs
else:
bucket_last_consumed[bucket_id] = bucket_bs
# compute the range of data accessed by each GPU
bucket_id = bucket_access_list[self.rank]
boundary = bucket_access_boundaries[self.rank]
cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]]
# encode t, h, w into the sample index
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
yield cur_micro_batch
self.reset()
def __len__(self) -> int:
return self.get_num_batch() // self.num_groups
def get_num_batch(self) -> int:
_, num_total_batch = self.group_by_bucket()
return num_total_batch
def clear_cache(self):
self._cached_bucket_sample_dict = None
self._cached_num_total_batch = 0
def group_by_bucket(self) -> dict:
"""
Group the dataset samples into buckets.
This method will set `self._cached_bucket_sample_dict` to the bucket sample dict.
Returns:
dict: a dictionary with bucket id as key and a list of sample indices as value
"""
if self._cached_bucket_sample_dict is not None:
return self._cached_bucket_sample_dict, self._cached_num_total_batch
# use pandarallel to accelerate bucket processing
log_message("Building buckets using %d workers...", self.num_bucket_build_workers)
bucket_ids = None
if dist.get_rank() == 0:
data = self.dataset.data.copy(deep=True)
data["id"] = data.index
bucket_ids = data.parallel_apply(
apply,
axis=1,
method=self.bucket.get_bucket_id,
seed=self.seed + self.epoch,
num_bucket=self.bucket.num_bucket,
fps_max=self.dataset.fps_max,
)
dist.barrier()
bucket_ids = sync_object_across_devices(bucket_ids)
dist.barrier()
# group by bucket
# each data sample is put into a bucket with a similar image/video size
bucket_sample_dict = defaultdict(list)
bucket_ids_np = np.array(bucket_ids)
valid_indices = np.where(bucket_ids_np != None)[0]
for i in valid_indices:
bucket_sample_dict[bucket_ids_np[i]].append(i)
# cache the bucket sample dict
self._cached_bucket_sample_dict = bucket_sample_dict
# num total batch
num_total_batch = self.print_bucket_info(bucket_sample_dict)
self._cached_num_total_batch = num_total_batch
return bucket_sample_dict, num_total_batch
def print_bucket_info(self, bucket_sample_dict: dict) -> int:
# collect statistics
num_total_samples = num_total_batch = 0
num_total_img_samples = num_total_vid_samples = 0
num_total_img_batch = num_total_vid_batch = 0
num_total_vid_batch_256 = num_total_vid_batch_768 = 0
num_aspect_dict = defaultdict(lambda: [0, 0])
num_hwt_dict = defaultdict(lambda: [0, 0])
for k, v in bucket_sample_dict.items():
size = len(v)
num_batch = size // self.bucket.get_batch_size(k[:-1])
num_total_samples += size
num_total_batch += num_batch
if k[1] == 1:
num_total_img_samples += size
num_total_img_batch += num_batch
else:
if k[0] == "256px":
num_total_vid_batch_256 += num_batch
elif k[0] == "768px":
num_total_vid_batch_768 += num_batch
num_total_vid_samples += size
num_total_vid_batch += num_batch
num_aspect_dict[k[-1]][0] += size
num_aspect_dict[k[-1]][1] += num_batch
num_hwt_dict[k[:-1]][0] += size
num_hwt_dict[k[:-1]][1] += num_batch
# sort
num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0]))
num_hwt_dict = dict(
sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True)
)
num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1}
num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1}
# log
if dist.get_rank() == 0 and self.verbose:
log_message("Bucket Info:")
log_message("Bucket [#sample, #batch] by aspect ratio:")
for k, v in num_aspect_dict.items():
log_message("(%s): #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
log_message("===== Image Info =====")
log_message("Image Bucket by HxWxT:")
for k, v in num_hwt_img_dict.items():
log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
log_message("--------------------------------")
log_message(
"#image sample: %s, #image batch: %s",
format_numel_str(num_total_img_samples),
format_numel_str(num_total_img_batch),
)
log_message("===== Video Info =====")
log_message("Video Bucket by HxWxT:")
for k, v in num_hwt_vid_dict.items():
log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
log_message("--------------------------------")
log_message(
"#video sample: %s, #video batch: %s",
format_numel_str(num_total_vid_samples),
format_numel_str(num_total_vid_batch),
)
log_message("===== Summary =====")
log_message("#non-empty buckets: %s", len(bucket_sample_dict))
log_message(
"Img/Vid sample ratio: %.2f",
num_total_img_samples / num_total_vid_samples if num_total_vid_samples > 0 else 0,
)
log_message(
"Img/Vid batch ratio: %.2f", num_total_img_batch / num_total_vid_batch if num_total_vid_batch > 0 else 0
)
log_message(
"vid batch 256: %s, vid batch 768: %s", format_numel_str(num_total_vid_batch_256), format_numel_str(num_total_vid_batch_768)
)
log_message(
"Vid batch ratio (256px/768px): %.2f", num_total_vid_batch_256 / num_total_vid_batch_768 if num_total_vid_batch_768 > 0 else 0
)
log_message(
"#training sample: %s, #training batch: %s",
format_numel_str(num_total_samples),
format_numel_str(num_total_batch),
)
return num_total_batch
def reset(self):
self.last_micro_batch_access_index = 0
def set_step(self, start_step: int):
self.last_micro_batch_access_index = start_step * self.num_replicas
def state_dict(self, num_steps: int) -> dict:
# the last_micro_batch_access_index in the __iter__ is often
# not accurate during multi-workers and data prefetching
# thus, we need the user to pass the actual steps which have been executed
# to calculate the correct last_micro_batch_access_index
return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)
class BatchDistributedSampler(DistributedSampler):
"""
Used with BatchDataset;
Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
| buffer {i} | buffer {i+1}
------ | ------------------- | -------------------
rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9
rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
"""
def __init__(self, dataset: Dataset, **kwargs):
super().__init__(dataset, **kwargs)
self.start_index = 0
def __iter__(self):
num_buffers = self.dataset.num_buffers
len_buffer = self.dataset.len_buffer
num_buffers_i = num_buffers // self.num_replicas
num_samples_i = len_buffer * num_buffers_i
indices_i = np.arange(self.start_index, num_samples_i) + self.rank * num_samples_i
indices_i = indices_i.tolist()
return iter(indices_i)
def reset(self):
self.start_index = 0
def state_dict(self, step) -> dict:
return {"start_index": step}
def load_state_dict(self, state_dict: dict):
self.start_index = state_dict["start_index"] + 1

419
opensora/datasets/utils.py Normal file
View File

@ -0,0 +1,419 @@
import math
import os
import random
import re
from typing import Any
import numpy as np
import pandas as pd
import requests
import torch
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image
from . import video_transforms
from .read_video import read_video
try:
import dask.dataframe as dd
SUPPORT_DASK = True
except:
SUPPORT_DASK = False
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
regex = re.compile(
r"^(?:http|ftp)s?://" # http:// or https://
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
r"(?::\d+)?" # optional port
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
def is_img(path):
ext = os.path.splitext(path)[-1].lower()
return ext in IMG_EXTENSIONS
def is_vid(path):
ext = os.path.splitext(path)[-1].lower()
return ext in VID_EXTENSIONS
def is_url(url):
return re.match(regex, url) is not None
def read_file(input_path, memory_efficient=False):
if input_path.endswith(".csv"):
assert not memory_efficient, "Memory efficient mode is not supported for CSV files"
return pd.read_csv(input_path)
elif input_path.endswith(".parquet"):
columns = None
if memory_efficient:
columns = ["path", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution"]
if SUPPORT_DASK:
ret = dd.read_parquet(input_path, columns=columns).compute()
else:
ret = pd.read_parquet(input_path, columns=columns)
return ret
else:
raise NotImplementedError(f"Unsupported file format: {input_path}")
def download_url(input_path):
output_dir = "cache"
os.makedirs(output_dir, exist_ok=True)
base_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, base_name)
img_data = requests.get(input_path).content
with open(output_path, "wb", encoding="utf-8") as handler:
handler.write(img_data)
print(f"URL {input_path} downloaded to {output_path}")
return output_path
def temporal_random_crop(
vframes: torch.Tensor, num_frames: int, frame_interval: int, return_frame_indices: bool = False
) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
total_frames = len(vframes)
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= num_frames
), f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}"
frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
video = vframes[frame_indices]
if return_frame_indices:
return video, frame_indices
else:
return video
def get_transforms_video(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "image_size must be square for center crop"
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(image_size[0]),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.ResizeCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "rand_size_crop":
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomSizedCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform_video
def get_transforms_image(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "Image size must be square for center crop"
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "rand_size_crop":
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: rand_size_crop_arr(pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
image = pil_loader(path)
if transform is None:
transform = get_transforms_image(image_size=image_size, name=transform_name)
image = transform(image)
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if transform is None:
transform = get_transforms_video(image_size=image_size, name=transform_name)
video = transform(vframes) # T C H W
video = video.permute(1, 0, 2, 3)
return video
def read_from_path(path, image_size, transform_name="center"):
if is_url(path):
path = download_url(path)
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
else:
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
def save_sample(
x,
save_path=None,
fps=8,
normalize=True,
value_range=(-1, 1),
force_video=False,
verbose=True,
crf=23,
):
"""
Args:
x (Tensor): shape [C, T, H, W]
"""
assert x.ndim == 4
if not force_video and x.shape[1] == 1: # T = 1: save as image
save_path += ".png"
x = x.squeeze(1)
save_image([x], save_path, normalize=normalize, value_range=value_range)
else:
save_path += ".mp4"
if normalize:
low, high = value_range
x.clamp_(min=low, max=high)
x.sub_(low).div_(max(high - low, 1e-5))
x = x.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
write_video(save_path, x, fps=fps, video_codec="h264", options={"crf": str(crf)})
if verbose:
print(f"Saved to {save_path}")
return save_path
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
def rand_size_crop_arr(pil_image, image_size):
"""
Randomly crop image for height and width, ranging from image_size[0] to image_size[1]
"""
arr = np.array(pil_image)
# get random target h w
height = random.randint(image_size[0], image_size[1])
width = random.randint(image_size[0], image_size[1])
# ensure that h w are factors of 8
height = height - height % 8
width = width - width % 8
# get random start pos
h_start = random.randint(0, max(len(arr) - height, 0))
w_start = random.randint(0, max(len(arr[0]) - height, 0))
# crop
return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width])
def resize_crop_to_fill(pil_image, image_size):
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, round(w * rh)
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = 0
j = int(round((sw - tw) / 2.0))
else:
sh, sw = round(h * rw), tw
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = int(round((sh - th) / 2.0))
j = 0
arr = np.array(image)
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
return Image.fromarray(arr[i : i + th, j : j + tw])
def map_target_fps(
fps: float,
max_fps: float,
) -> tuple[float, int]:
"""
Map fps to a new fps that is less than max_fps.
Args:
fps (float): Original fps.
max_fps (float): Maximum fps.
Returns:
tuple[float, int]: New fps and sampling interval.
"""
if math.isnan(fps):
return 0, 1
if fps < max_fps:
return fps, 1
sampling_interval = math.ceil(fps / max_fps)
new_fps = math.floor(fps / sampling_interval)
return new_fps, sampling_interval
def sync_object_across_devices(obj: Any, rank: int = 0):
"""
Synchronizes any picklable object across devices in a PyTorch distributed setting
using `broadcast_object_list` with CUDA support.
Parameters:
obj (Any): The object to synchronize. Can be any picklable object (e.g., list, dict, custom class).
rank (int): The rank of the device from which to broadcast the object state. Default is 0.
Note: Ensure torch.distributed is initialized before using this function and CUDA is available.
"""
# Move the object to a list for broadcasting
object_list = [obj]
# Broadcast the object list from the source rank to all other ranks
dist.broadcast_object_list(object_list, src=rank, device="cuda")
# Retrieve the synchronized object
obj = object_list[0]
return obj
def rescale_image_by_path(path: str, height: int, width: int):
"""
Rescales an image to the specified height and width and saves it back to the original path.
Args:
path (str): The file path of the image.
height (int): The target height of the image.
width (int): The target width of the image.
"""
try:
# read image
image = Image.open(path)
# check if image is valid
if image is None:
raise ValueError("The image is invalid or empty.")
# resize image
resize_transform = transforms.Resize((width, height))
resized_image = resize_transform(image)
# save resized image back to the original path
resized_image.save(path)
except Exception as e:
print(f"Error rescaling image: {e}")
def rescale_video_by_path(path: str, height: int, width: int):
"""
Rescales an MP4 video (without audio) to the specified height and width.
Args:
path (str): The file path of the video.
height (int): The target height of the video.
width (int): The target width of the video.
"""
try:
# Read video and metadata
video, info = read_video(path, backend="av")
# Check if video is valid
if video is None or video.size(0) == 0:
raise ValueError("The video is invalid or empty.")
# Resize video frames using a performant method
resize_transform = transforms.Compose([transforms.Resize((width, height))])
resized_video = torch.stack([resize_transform(frame) for frame in video])
# Save resized video back to the original path
resized_video = resized_video.permute(0, 2, 3, 1)
write_video(path, resized_video, fps=int(info["video_fps"]), video_codec="h264")
except Exception as e:
print(f"Error rescaling video: {e}")
def save_tensor_to_disk(tensor, path, exist_handling="overwrite"):
if os.path.exists(path):
if exist_handling == "ignore":
return
elif exist_handling == "raise":
raise UserWarning(f"File {path} already exists, rewriting!")
torch.save(tensor, path)
def save_tensor_to_internet(tensor, path):
raise NotImplementedError("save_tensor_to_internet is not implemented yet!")
def save_latent(latent, path, exist_handling="overwrite"):
if path.startswith(("http://", "https://", "ftp://", "sftp://")):
save_tensor_to_internet(latent, path)
else:
save_tensor_to_disk(latent, path, exist_handling=exist_handling)
def cache_latents(latents, path, exist_handling="overwrite"):
for i in range(latents.shape[0]):
save_latent(latents[i], path[i], exist_handling=exist_handling)

View File

@ -0,0 +1,595 @@
# Copyright 2024 Vchitect/Latte
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# Modified from Latte
import numbers
# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py
import random
import numpy as np
import torch
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
th, tw = int(round(H * scale_)), int(round(W * scale_))
return torch.nn.functional.interpolate(clip, size=(th, tw), mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def resize_crop_to_fill(clip, target_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = target_size[0], target_size[1]
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, round(w * rh)
clip = resize(clip, (sh, sw), "bilinear")
i = 0
j = int(round(sw - tw) / 2.0)
else:
sh, sw = round(h * rw), tw
clip = resize(clip, (sh, sw), "bilinear")
i = int(round(sh - th) / 2.0)
j = 0
assert i + th <= clip.size(-2) and j + tw <= clip.size(-1)
return crop(clip, i, j, th, tw)
# def rand_crop_h_w(clip, target_size_range, multiples_of: int = 8):
# # NOTE: for some reason, if don't re-import, gives same randint results
# import sys
# del sys.modules["random"]
# import random
# if not _is_tensor_video_clip(clip):
# raise ValueError("clip should be a 4D torch.tensor")
# h, w = clip.size(-2), clip.size(-1)
# # get random target h w
# th = random.randint(target_size_range[0], target_size_range[1])
# tw = random.randint(target_size_range[0], target_size_range[1])
# # ensure that h w are factors of 8
# th = th - th % multiples_of
# tw = tw - tw % multiples_of
# # get random start pos
# i = random.randint(0, h-th) if h > th else 0
# j = random.randint(0, w-tw) if w > tw else 0
# th = th if th < h else h
# tw = tw if tw < w else w
# # print("target size range:",target_size_range)
# # print("original size:", h, w)
# # print("crop size:", th, tw)
# # print(f"crop:{i}-{i+th}, {j}-{j+tw}")
# return (crop(clip, i, j, th, tw), th, tw)
def random_shift_crop(clip):
"""
Slide along the long edge, with the short edge as crop size
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
short_edge = h
else:
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class ResizeCrop:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
clip = resize_crop_to_fill(clip, self.size)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class RandomSizedCrop:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
i, j, h, w = self.get_params(clip)
# self.size = (h, w)
return crop(clip, i, j, h, w)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def get_params(self, clip, multiples_of=8):
h, w = clip.shape[-2:]
# get random target h w
th = random.randint(self.size[0], self.size[1])
tw = random.randint(self.size[0], self.size[1])
# ensure that h w are factors of 8
th = th - th % multiples_of
tw = tw - tw % multiples_of
if h < th:
th = h - h % multiples_of
if w < tw:
tw = w - w % multiples_of
if w == tw and h == th:
return 0, 0, h, w
else:
# get random start pos
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
"""
First use the short side for cropping length,
center crop video, then resize to the specified size
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop_using_short_edge(clip)
clip_center_crop_resize = resize(
clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class UCFCenterCropVideo:
"""
First scale to the specified size in equal proportion to the short edge,
then center cropping
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
"""
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
if __name__ == "__main__":
import os
import numpy as np
import torchvision.io as io
from torchvision import transforms
from torchvision.utils import save_image
vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
trans = transforms.Compose(
[
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(
select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
)

282
tools/datasets/README.md Normal file
View File

@ -0,0 +1,282 @@
# Dataset Management
- [Dataset Management](#dataset-management)
- [Dataset Format](#dataset-format)
- [Dataset to CSV](#dataset-to-csv)
- [Manage datasets](#manage-datasets)
- [Requirement](#requirement)
- [Basic Usage](#basic-usage)
- [Score filtering](#score-filtering)
- [Documentation](#documentation)
- [Transform datasets](#transform-datasets)
- [Resize](#resize)
- [Frame extraction](#frame-extraction)
- [Crop Midjourney 4 grid](#crop-midjourney-4-grid)
- [Analyze datasets](#analyze-datasets)
- [Data Process Pipeline](#data-process-pipeline)
After preparing the raw dataset according to the [instructions](/docs/datasets.md), you can use the following commands to manage the dataset.
## Dataset Format
All dataset should be provided in a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should follow the words below:
- `path`: the relative/absolute path or url to the image or video file. Required.
- `text`: the caption or description of the image or video. Required for training.
- `num_frames`: the number of frames in the video. Required for training.
- `width`: the width of the video frame. Required for dynamic bucket.
- `height`: the height of the video frame. Required for dynamic bucket.
- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket.
- `resolution`: height x width. For analysis.
- `text_len`: the number of tokens in the text. For analysis.
- `aes`: aesthetic score calculated by [asethetic scorer](/tools/aesthetic/README.md). For filtering.
- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering.
- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering.
- `fps`: the frame rate of the video. Optional.
- `cmotion`: the camera motion.
An example ready for training:
```csv
path, text, num_frames, width, height, aspect_ratio
/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625
/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625
/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1
```
We use pandas to manage the `.csv` or `.parquet` files. The following code is for reading and writing files:
```python
df = pd.read_csv(input_path)
df = df.to_csv(output_path, index=False)
# or use parquet, which is smaller
df = pd.read_parquet(input_path)
df = df.to_parquet(output_path, index=False)
```
## Dataset to CSV
As a start point, `convert.py` is used to convert the dataset to a CSV file. You can use the following commands to convert the dataset to a CSV file:
```bash
python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER
# general video folder
python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv
# general image folder
python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv
# imagenet
python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train
# ucf101
python -m tools.datasets.convert ucf101 UCF101_FOLDER --split videos
# vidprom
python -m tools.datasets.convert vidprom VIDPROM_FOLDER --info VidProM_semantic_unique.csv
```
## Manage datasets
Use `datautil` to manage the dataset.
### Requirement
Follow our [installation guide](../../docs/installation.md)'s "Data Dependencies" and "Datasets" section to install the required packages.
<!-- To accelerate processing speed, you can install [pandarallel](https://github.com/nalepae/pandarallel):
```bash
pip install pandarallel
``` -->
<!-- To get image and video information, you need to install [opencv-python](https://github.com/opencv/opencv-python): -->
<!-- ```bash
pip install opencv-python
# If your videos are in av1 codec instead of h264, you need to
# - install ffmpeg first
# - install via conda to support av1 codec
conda install -c conda-forge opencv
``` -->
<!-- Or to get video information, you can install ffmpeg and ffmpeg-python:
```bash
pip install ffmpeg-python
``` -->
<!-- To filter a specific language, you need to install [lingua](https://github.com/pemistahl/lingua-py):
```bash
pip install lingua-language-detector
``` -->
### Basic Usage
You can use the following commands to process the `csv` or `parquet` files. The output file will be saved in the same directory as the input, with different suffixes indicating the processed method.
```bash
# datautil takes multiple CSV files as input and merge them into one CSV file
# output: DATA1+DATA2.csv
python -m tools.datasets.datautil DATA1.csv DATA2.csv
# shard CSV files into multiple CSV files
# output: DATA1_0.csv, DATA1_1.csv, ...
python -m tools.datasets.datautil DATA1.csv --shard 10
# filter frames between 128 and 256, with captions
# output: DATA1_fmin_128_fmax_256.csv
python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256
# Disable parallel processing
python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 --disable-parallel
# Compute num_frames, height, width, fps, aspect_ratio for videos or images
# output: IMG_DATA+VID_DATA_vinfo.csv
python -m tools.datasets.datautil IMG_DATA.csv VID_DATA.csv --video-info
# You can run multiple operations at the same time.
python -m tools.datasets.datautil DATA.csv --video-info --remove-empty-caption --remove-url --lang en
```
### Score filtering
To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands:
```bash
# sort the dataset by aesthetic score
# output: DATA_sort.csv
python -m tools.datasets.datautil DATA.csv --sort aesthetic_score
# View examples of high aesthetic score
head -n 10 DATA_sort.csv
# View examples of low aesthetic score
tail -n 10 DATA_sort.csv
# sort the dataset by clip score
# output: DATA_sort.csv
python -m tools.datasets.datautil DATA.csv --sort clip_score
# filter the dataset by aesthetic score
# output: DATA_aesmin_0.5.csv
python -m tools.datasets.datautil DATA.csv --aesmin 0.5
# filter the dataset by clip score
# output: DATA_matchmin_0.5.csv
python -m tools.datasets.datautil DATA.csv --matchmin 0.5
```
### Documentation
You can also use `python -m tools.datasets.datautil --help` to see usage.
| Args | File suffix | Description |
| --------------------------- | -------------- | ------------------------------------------------------------- |
| `--output OUTPUT` | | Output path |
| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) |
| `--disable-parallel` | | Disable `pandarallel` |
| `--seed SEED` | | Random seed |
| `--shard SHARD` | `_0`,`_1`, ... | Shard the dataset |
| `--sort KEY` | `_sort` | Sort the dataset by KEY |
| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order |
| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset |
| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns |
| `--info` | `_info` | Get the basic information of each video and image (cv2) |
| `--ext` | `_ext` | Remove rows if the file does not exist |
| `--relpath` | `_relpath` | Modify the path to relative path by root given |
| `--abspath` | `_abspath` | Modify the path to absolute path by root given |
| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption |
| `--remove-url` | `_nourl` | Remove rows with url in caption |
| `--lang LANG` | `_lang` | Remove rows with other language |
| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path |
| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption |
| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM |
| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training |
| `--unescape` | `_unescape` | Unescape the caption |
| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption |
| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption |
| `--load-caption EXT` | `_load` | Load the caption from the file |
| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames |
| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum number of frames |
| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width |
| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score |
| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score |
| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score |
## Transform datasets
The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows:
```bash
python -m tools.datasets.transform TRANSFORM_TYPE META.csv ORIGINAL_DATA_FOLDER DATA_FOLDER_TO_SAVE_RESULTS --additional-args
```
### Resize
Sometimes you may need to resize the images or videos to a specific resolution. You can use the following commands to resize the dataset:
```bash
python -m tools.datasets.transform meta.csv /path/to/raw/data /path/to/new/data --length 2160
```
### Frame extraction
To extract frames from videos, you can use the following commands:
```bash
python -m tools.datasets.transform vid_frame_extract meta.csv /path/to/raw/data /path/to/new/data --points 0.1 0.5 0.9
```
### Crop Midjourney 4 grid
Randomly select one of the 4 images in the 4 grid generated by Midjourney.
```bash
python -m tools.datasets.transform img_rand_crop meta.csv /path/to/raw/data /path/to/new/data
```
## Analyze datasets
You can easily get basic information about a `.csv` dataset by using the following commands:
```bash
# examine the first 10 rows of the CSV file
head -n 10 DATA1.csv
# count the number of data in the CSV file (approximately)
wc -l DATA1.csv
```
For the dataset provided in a `.csv` or `.parquet` file, you can easily analyze the dataset using the following commands. Plots will be automatically saved.
```python
pyhton -m tools.datasets.analyze DATA_info.csv
```
## Data Process Pipeline
```bash
# Suppose videos and images under ~/dataset/
# 1. Convert dataset to CSV
python -m tools.datasets.convert video ~/dataset --output meta.csv
# 2. Get video information
python -m tools.datasets.datautil meta.csv --info --fmin 1
# 3. Get caption
# 3.1. generate caption
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava meta_info_fmin1.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video
# merge generated results
python -m tools.datasets.datautil meta_info_fmin1_caption_part*.csv --output meta_caption.csv
# merge caption and info
python -m tools.datasets.datautil meta_info_fmin1.csv --intersection meta_caption.csv --output meta_caption_info.csv
# clean caption
python -m tools.datasets.datautil meta_caption_info.csv --clean-caption --refine-llm-caption --remove-empty-caption --output meta_caption_processed.csv
# 3.2. extract caption
python -m tools.datasets.datautil meta_info_fmin1.csv --load-caption json --remove-empty-caption --clean-caption
# 4. Scoring
# aesthetic scoring
torchrun --standalone --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta_caption_processed.csv
python -m tools.datasets.datautil meta_caption_processed_part*.csv --output meta_caption_processed_aes.csv
# optical flow scoring
torchrun --standalone --nproc_per_node 8 -m tools.scoring.optical_flow.inference meta_caption_processed.csv
# matching scoring
torchrun --standalone --nproc_per_node 8 -m tools.scoring.matching.inference meta_caption_processed.csv
# camera motion
python -m tools.caption.camera_motion_detect meta_caption_processed.csv
```

View File

96
tools/datasets/analyze.py Normal file
View File

@ -0,0 +1,96 @@
import argparse
import os
import matplotlib.pyplot as plt
import pandas as pd
def read_file(input_path):
if input_path.endswith(".csv"):
return pd.read_csv(input_path)
elif input_path.endswith(".parquet"):
return pd.read_parquet(input_path)
else:
raise NotImplementedError(f"Unsupported file format: {input_path}")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="Path to the input dataset")
parser.add_argument("--save-img", type=str, default="samples/infos/", help="Path to save the image")
return parser.parse_args()
def plot_data(data, column, bins, name):
plt.clf()
data.hist(column=column, bins=bins)
os.makedirs(os.path.dirname(name), exist_ok=True)
plt.savefig(name)
print(f"Saved {name}")
def plot_categorical_data(data, column, name):
plt.clf()
data[column].value_counts().plot(kind="bar")
os.makedirs(os.path.dirname(name), exist_ok=True)
plt.savefig(name)
print(f"Saved {name}")
COLUMNS = {
"num_frames": 100,
"resolution": 100,
"text_len": 100,
"aes": 100,
"match": 100,
"flow": 100,
"cmotion": None,
}
def main(args):
data = read_file(args.input)
# === Image Data Info ===
image_index = data["num_frames"] == 1
if image_index.sum() > 0:
print("=== Image Data Info ===")
img_data = data[image_index]
print(f"Number of images: {len(img_data)}")
print(img_data.head())
print(img_data.describe())
if args.save_img:
for column in COLUMNS:
if column in img_data.columns and column not in ["num_frames", "cmotion"]:
if COLUMNS[column] is None:
plot_categorical_data(img_data, column, os.path.join(args.save_img, f"image_{column}.png"))
else:
plot_data(img_data, column, COLUMNS[column], os.path.join(args.save_img, f"image_{column}.png"))
# === Video Data Info ===
if not image_index.all():
print("=== Video Data Info ===")
video_data = data[~image_index]
print(f"Number of videos: {len(video_data)}")
if "num_frames" in video_data.columns:
total_num_frames = video_data["num_frames"].sum()
print(f"Number of frames: {total_num_frames}")
DEFAULT_FPS = 30
total_hours = total_num_frames / DEFAULT_FPS / 3600
print(f"Total hours (30 FPS): {int(total_hours)}")
print(video_data.head())
print(video_data.describe())
if args.save_img:
for column in COLUMNS:
if column in video_data.columns:
if COLUMNS[column] is None:
plot_categorical_data(video_data, column, os.path.join(args.save_img, f"video_{column}.png"))
else:
plot_data(
video_data, column, COLUMNS[column], os.path.join(args.save_img, f"video_{column}.png")
)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -0,0 +1,79 @@
import argparse
import subprocess
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
try:
from pandarallel import pandarallel
PANDA_USE_PARALLEL = True
except ImportError:
PANDA_USE_PARALLEL = False
import shutil
if not shutil.which("ffmpeg"):
raise ImportError("FFmpeg is not installed")
def apply(df, func, **kwargs):
if PANDA_USE_PARALLEL:
return df.parallel_apply(func, **kwargs)
return df.progress_apply(func, **kwargs)
def check_video_integrity(video_path):
# try:
can_open_result = subprocess.run(
["ffmpeg", "-v", "error", "-i", video_path, "-t", "0", "-f", "null", "-"], # open video and capture 0 seconds
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
fast_scan_result = subprocess.run(
["ffmpeg", "-v", "error", "-analyzeduration", "10M", "-probesize", "10M", "-i", video_path, "-f", "null", "-"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if can_open_result.stderr == "" and fast_scan_result.stderr == "":
return True
else:
return False
# except Exception as e:
# return False
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="path to the input dataset")
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
args = parser.parse_args()
if args.disable_parallel:
PANDA_USE_PARALLEL = False
if PANDA_USE_PARALLEL:
if args.num_workers is not None:
pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
else:
pandarallel.initialize(progress_bar=True)
data = pd.read_csv(args.input)
assert "path" in data.columns
data["integrity"] = apply(data["path"], check_video_integrity)
integrity_file_path = args.input.replace(".csv", "_intact.csv")
broken_file_path = args.input.replace(".csv", "_broken.csv")
intact_data = data[data["integrity"] == True].drop(columns=["integrity"])
intact_data.to_csv(integrity_file_path, index=False)
broken_data = data[data["integrity"] == False].drop(columns=["integrity"])
broken_data.to_csv(broken_file_path, index=False)
print(
f"Integrity check completed. Intact videos saved to: {integrity_file_path}, broken videos saved to {broken_file_path}."
)

144
tools/datasets/convert.py Normal file
View File

@ -0,0 +1,144 @@
import argparse
import os
import time
import pandas as pd
from torchvision.datasets import ImageNet
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".m2ts")
def scan_recursively(root):
num = 0
for entry in os.scandir(root):
if entry.is_file():
yield entry
elif entry.is_dir():
num += 1
if num % 100 == 0:
print(f"Scanned {num} directories.")
yield from scan_recursively(entry.path)
def get_filelist(file_path, exts=None):
filelist = []
time_start = time.time()
# == OS Walk ==
# for home, dirs, files in os.walk(file_path):
# for filename in files:
# ext = os.path.splitext(filename)[-1].lower()
# if exts is None or ext in exts:
# filelist.append(os.path.join(home, filename))
# == Scandir ==
obj = scan_recursively(file_path)
for entry in obj:
if entry.is_file():
ext = os.path.splitext(entry.name)[-1].lower()
if exts is None or ext in exts:
filelist.append(entry.path)
time_end = time.time()
print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.")
return filelist
def split_by_capital(name):
# BoxingPunchingBag -> Boxing Punching Bag
new_name = ""
for i in range(len(name)):
if name[i].isupper() and i != 0:
new_name += " "
new_name += name[i]
return new_name
def process_imagenet(root, split):
root = os.path.expanduser(root)
data = ImageNet(root, split=split)
samples = [(path, data.classes[label][0]) for path, label in data.samples]
output = f"imagenet_{split}.csv"
df = pd.DataFrame(samples, columns=["path", "text"])
df.to_csv(output, index=False)
print(f"Saved {len(samples)} samples to {output}.")
def process_ucf101(root, split):
root = os.path.expanduser(root)
video_lists = get_filelist(os.path.join(root, split))
classes = [x.split("/")[-2] for x in video_lists]
classes = [split_by_capital(x) for x in classes]
samples = list(zip(video_lists, classes))
output = f"ucf101_{split}.csv"
df = pd.DataFrame(samples, columns=["path", "text"])
df.to_csv(output, index=False)
print(f"Saved {len(samples)} samples to {output}.")
def process_vidprom(root, info):
root = os.path.expanduser(root)
video_lists = get_filelist(root)
video_set = set(video_lists)
# read info csv
infos = pd.read_csv(info)
abs_path = infos["uuid"].apply(lambda x: os.path.join(root, f"pika-{x}.mp4"))
is_exist = abs_path.apply(lambda x: x in video_set)
df = pd.DataFrame(dict(path=abs_path[is_exist], text=infos["prompt"][is_exist]))
df.to_csv("vidprom.csv", index=False)
print(f"Saved {len(df)} samples to vidprom.csv.")
def process_general_images(root, output):
root = os.path.expanduser(root)
if not os.path.exists(root):
return
path_list = get_filelist(root, IMG_EXTENSIONS)
fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
relpath_list = [os.path.relpath(x, root) for x in path_list]
df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))
os.makedirs(os.path.dirname(output), exist_ok=True)
df.to_csv(output, index=False)
print(f"Saved {len(df)} samples to {output}.")
def process_general_videos(root, output):
root = os.path.expanduser(root)
if not os.path.exists(root):
return
path_list = get_filelist(root, VID_EXTENSIONS)
path_list = list(set(path_list)) # remove duplicates
fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
relpath_list = [os.path.relpath(x, root) for x in path_list]
df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))
os.makedirs(os.path.dirname(output), exist_ok=True)
df.to_csv(output, index=False)
print(f"Saved {len(df)} samples to {output}.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
parser.add_argument("root", type=str)
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--info", type=str, default=None)
parser.add_argument("--output", type=str, default=None, required=True, help="Output path")
args = parser.parse_args()
if args.dataset == "imagenet":
process_imagenet(args.root, args.split)
elif args.dataset == "ucf101":
process_ucf101(args.root, args.split)
elif args.dataset == "vidprom":
process_vidprom(args.root, args.info)
elif args.dataset == "image":
process_general_images(args.root, args.output)
elif args.dataset == "video":
process_general_videos(args.root, args.output)
else:
raise ValueError("Invalid dataset")

14
tools/datasets/csv2txt.py Normal file
View File

@ -0,0 +1,14 @@
import argparse
import pandas as pd
parser = argparse.ArgumentParser(description="Convert CSV file to txt file")
parser.add_argument("csv_file", type=str, help="CSV file to convert")
parser.add_argument("txt_file", type=str, help="TXT file to save")
args = parser.parse_args()
data = pd.read_csv(args.csv_file)
text = data["text"].to_list()
text = "\n".join(text)
with open(args.txt_file, "w") as f:
f.write(text)

1089
tools/datasets/datautil.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,262 @@
# TODO: remove this file before releasing
import argparse
import html
import os
import re
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
try:
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True)
pandas_has_parallel = True
except ImportError:
pandas_has_parallel = False
def apply(df, func, **kwargs):
if pandas_has_parallel:
return df.parallel_apply(func, **kwargs)
return df.progress_apply(func, **kwargs)
def basic_clean(text):
import ftfy
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
BAD_PUNCT_REGEX = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
def clean_caption(caption):
import urllib.parse as ul
from bs4 import BeautifulSoup
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = basic_clean(caption)
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def get_10m_set():
meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv"
meta_10m = pd.read_csv(meta_path_10m)
def process_single_caption(row):
text_list = eval(row["caption"])
clean_list = [clean_caption(x) for x in text_list]
return str(clean_list)
ret = apply(meta_10m, process_single_caption, axis=1)
# ret = meta_10m.progress_apply(process_single_caption, axis=1)
print("==> text processed.")
text_list = []
for x in ret:
text_list += eval(x)
# text_set = text_set.union(set(eval(x)))
text_set = set(text_list)
# meta_10m['caption_new'] = ret
# meta_10m.to_csv('/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m_new-cap.csv')
# video_id_set = set(meta_10m['videoID'])
# id2t = {}
# for idx, row in tqdm(meta_10m.iterrows(), total=len(meta_10m)):
# video_id = row['videoID']
# text_list = eval(row['caption'])
# id2t[video_id] = set(text_list)
print(f"==> Loaded meta_10m from '{meta_path_10m}'")
return text_set
def filter_panda10m_text(meta_path, text_set):
def process_single_row(row):
# path = row['path']
t = row["text"]
# fname = os.path.basename(path)
# video_id = fname[:fname.rindex('_')]
if t not in text_set:
return False
return True
meta = pd.read_csv(meta_path)
ret = apply(meta, process_single_row, axis=1)
# ret = meta.progress_apply(process_single_row, axis=1)
meta = meta[ret]
wo_ext, ext = os.path.splitext(meta_path)
out_path = f"{wo_ext}_filter-10m{ext}"
meta.to_csv(out_path, index=False)
print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
def filter_panda10m_timestamp(meta_path):
meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv"
meta_10m = pd.read_csv(meta_path_10m)
id2t = {}
for idx, row in tqdm(meta_10m.iterrows(), total=len(meta_10m)):
video_id = row["videoID"]
timestamp = eval(row["timestamp"])
timestamp = [str(tuple(x)) for x in timestamp]
id2t[video_id] = timestamp
# video_id_set_10m = set(meta_10m['videoID'])
print(f"==> Loaded meta_10m from '{meta_path_10m}'")
def process_single_row(row):
path = row["path"]
t = row["timestamp"]
fname = os.path.basename(path)
video_id = fname[: fname.rindex("_")]
if video_id not in id2t:
return False
if t not in id2t[video_id]:
return False
return True
# return video_id in video_id_set_10m
meta = pd.read_csv(meta_path)
ret = apply(meta, process_single_row, axis=1)
meta = meta[ret]
wo_ext, ext = os.path.splitext(meta_path)
out_path = f"{wo_ext}_filter-10m{ext}"
meta.to_csv(out_path, index=False)
print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--meta_path", type=str, nargs="+")
parser.add_argument("--num_workers", default=5, type=int)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
text_set = get_10m_set()
for x in args.meta_path:
filter_panda10m_text(x, text_set)

View File

@ -0,0 +1,66 @@
import argparse
import os
import cv2
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
try:
from pandarallel import pandarallel
PANDA_USE_PARALLEL = True
except ImportError:
PANDA_USE_PARALLEL = False
def save_first_frame(video_path, img_dir):
if not os.path.exists(video_path):
print(f"Video not found: {video_path}")
return ""
try:
cap = cv2.VideoCapture(video_path)
success, frame = cap.read()
if success:
video_name = os.path.basename(video_path)
image_name = os.path.splitext(video_name)[0] + "_first_frame.jpg"
image_path = os.path.join(img_dir, image_name)
cv2.imwrite(image_path, frame)
else:
raise ValueError("Video broken.")
cap.release()
return image_path
except Exception as e:
print(f"Save first frame of `{video_path}` failed. {e}")
return ""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="path to the input csv dataset")
parser.add_argument("--img-dir", type=str, help="path to save first frame image")
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
args = parser.parse_args()
if args.disable_parallel:
PANDA_USE_PARALLEL = False
if PANDA_USE_PARALLEL:
if args.num_workers is not None:
pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
else:
pandarallel.initialize(progress_bar=True)
if not os.path.exists(args.img_dir):
os.makedirs(args.img_dir)
data = pd.read_csv(args.input)
data["first_frame_path"] = data["path"].parallel_apply(save_first_frame, img_dir=args.img_dir)
data_filtered = data.loc[data["first_frame_path"] != ""]
output_csv_path = args.input.replace(".csv", "_first-frame.csv")
data_filtered.to_csv(output_csv_path, index=False)
print(f"First frame csv saved to: {output_csv_path}, first frame images saved to {args.img_dir}.")

72
tools/datasets/split.py Normal file
View File

@ -0,0 +1,72 @@
import argparse
from typing import List
import pandas as pd
from mmengine.config import Config
from opensora.datasets.bucket import Bucket
def split_by_bucket(
bucket: Bucket,
input_files: List[str],
output_path: str,
limit: int,
frame_interval: int,
):
print(f"Split {len(input_files)} files into {len(bucket)} buckets")
total_limit = len(bucket) * limit
bucket_cnt = {}
# get all bucket id
for hw_id, d in bucket.ar_criteria.items():
for t_id, v in d.items():
for ar_id in v.keys():
bucket_id = (hw_id, t_id, ar_id)
bucket_cnt[bucket_id] = 0
output_df = None
# split files
for path in input_files:
df = pd.read_csv(path)
if output_df is None:
output_df = pd.DataFrame(columns=df.columns)
for i in range(len(df)):
row = df.iloc[i]
t, h, w = row["num_frames"], row["height"], row["width"]
bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
if bucket_id is None:
continue
if bucket_cnt[bucket_id] < limit:
bucket_cnt[bucket_id] += 1
output_df = pd.concat([output_df, pd.DataFrame([row])], ignore_index=True)
if len(output_df) >= total_limit:
break
if len(output_df) >= total_limit:
break
assert len(output_df) <= total_limit
if len(output_df) == total_limit:
print(f"All buckets are full ({total_limit} samples)")
else:
print(f"Only {len(output_df)} files are used")
output_df.to_csv(output_path, index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, nargs="+")
parser.add_argument("-o", "--output", required=True)
parser.add_argument("-c", "--config", required=True)
parser.add_argument("-l", "--limit", default=200, type=int)
args = parser.parse_args()
assert args.limit > 0
cfg = Config.fromfile(args.config)
bucket_config = cfg.bucket_config
# rewrite bucket_config
for ar, d in bucket_config.items():
for frames, t in d.items():
p, bs = t
if p > 0.0:
p = 1.0
d[frames] = (p, bs)
bucket = Bucket(bucket_config)
split_by_bucket(bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval)

306
tools/datasets/transform.py Normal file
View File

@ -0,0 +1,306 @@
import argparse
import os
import random
import shutil
import subprocess
import cv2
import ffmpeg
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm import tqdm
from .utils import IMG_EXTENSIONS, extract_frames
tqdm.pandas()
USE_PANDARALLEL = True
def apply(df, func, **kwargs):
if USE_PANDARALLEL:
return df.parallel_apply(func, **kwargs)
return df.progress_apply(func, **kwargs)
def get_new_path(path, input_dir, output):
path_new = os.path.join(output, os.path.relpath(path, input_dir))
os.makedirs(os.path.dirname(path_new), exist_ok=True)
return path_new
def resize_longer(path, length, input_dir, output_dir):
path_new = get_new_path(path, input_dir, output_dir)
ext = os.path.splitext(path)[1].lower()
assert ext in IMG_EXTENSIONS
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
if min(h, w) > length:
if h > w:
new_h = length
new_w = int(w / h * length)
else:
new_w = length
new_h = int(h / w * length)
img = cv2.resize(img, (new_w, new_h))
cv2.imwrite(path_new, img)
else:
path_new = ""
return path_new
def resize_shorter(path, length, input_dir, output_dir):
path_new = get_new_path(path, input_dir, output_dir)
if os.path.exists(path_new):
return path_new
ext = os.path.splitext(path)[1].lower()
assert ext in IMG_EXTENSIONS
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
if min(h, w) > length:
if h > w:
new_w = length
new_h = int(h / w * length)
else:
new_h = length
new_w = int(w / h * length)
img = cv2.resize(img, (new_w, new_h))
cv2.imwrite(path_new, img)
else:
path_new = ""
return path_new
def rand_crop(path, input_dir, output):
ext = os.path.splitext(path)[1].lower()
path_new = get_new_path(path, input_dir, output)
assert ext in IMG_EXTENSIONS
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
width, height, _ = img.shape
pos = random.randint(0, 3)
if pos == 0:
img_cropped = img[: width // 2, : height // 2]
elif pos == 1:
img_cropped = img[width // 2 :, : height // 2]
elif pos == 2:
img_cropped = img[: width // 2, height // 2 :]
else:
img_cropped = img[width // 2 :, height // 2 :]
cv2.imwrite(path_new, img_cropped)
else:
path_new = ""
return path_new
def m2ts_to_mp4(row, output_dir):
input_path = row["path"]
output_name = os.path.basename(input_path).replace(".m2ts", ".mp4")
output_path = os.path.join(output_dir, output_name)
# create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path), exist_ok=True)
try:
ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run(
capture_stdout=True
)
row["path"] = output_path
row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
except Exception as e:
print(f"Error converting {input_path} to mp4: {e}")
row["path"] = ""
row["relpath"] = ""
return row
return row
def mkv_to_mp4(row, output_dir):
# str_to_replace and str_to_replace_with account for the different directory structure
input_path = row["path"]
output_name = os.path.basename(input_path).replace(".mkv", ".mp4")
output_path = os.path.join(output_dir, output_name)
# create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path), exist_ok=True)
try:
ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run(
capture_stdout=True
)
row["path"] = output_path
row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
except Exception as e:
print(f"Error converting {input_path} to mp4: {e}")
row["path"] = ""
row["relpath"] = ""
return row
return row
def mp4_to_mp4(row, output_dir):
# str_to_replace and str_to_replace_with account for the different directory structure
input_path = row["path"]
# 检查输入文件是否为.mp4文件
if not input_path.lower().endswith(".mp4"):
print(f"Error: {input_path} is not an .mp4 file.")
row["path"] = ""
row["relpath"] = ""
return row
output_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, output_name)
# create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path), exist_ok=True)
try:
shutil.copy2(input_path, output_path) # 使用shutil复制文件
row["path"] = output_path
row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
except Exception as e:
print(f"Error coy {input_path} to mp4: {e}")
row["path"] = ""
row["relpath"] = ""
return row
return row
def crop_to_square(input_path, output_path):
cmd = (
f"ffmpeg -i {input_path} "
f"-vf \"crop='min(in_w,in_h)':'min(in_w,in_h)':'(in_w-min(in_w,in_h))/2':'(in_h-min(in_w,in_h))/2'\" "
f"-c:v libx264 -an "
f"-map 0:v {output_path}"
)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
stdout, stderr = proc.communicate()
def vid_crop_center(row, input_dir, output_dir):
input_path = row["path"]
relpath = os.path.relpath(input_path, input_dir)
assert not relpath.startswith("..")
output_path = os.path.join(output_dir, relpath)
# create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path), exist_ok=True)
try:
crop_to_square(input_path, output_path)
size = min(row["height"], row["width"])
row["path"] = output_path
row["height"] = size
row["width"] = size
row["aspect_ratio"] = 1.0
row["resolution"] = size**2
except Exception as e:
print(f"Error cropping {input_path} to center: {e}")
row["path"] = ""
return row
def main():
args = parse_args()
global USE_PANDARALLEL
assert args.num_workers is None or not args.disable_parallel
if args.disable_parallel:
USE_PANDARALLEL = False
if args.num_workers is not None:
pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers)
else:
pandarallel.initialize(progress_bar=True)
random.seed(args.seed)
data = pd.read_csv(args.meta_path)
if args.task == "img_rand_crop":
data["path"] = apply(data["path"], lambda x: rand_crop(x, args.input_dir, args.output_dir))
output_csv = args.meta_path.replace(".csv", "_rand_crop.csv")
elif args.task == "img_resize_longer":
data["path"] = apply(data["path"], lambda x: resize_longer(x, args.length, args.input_dir, args.output_dir))
output_csv = args.meta_path.replace(".csv", f"_resize-longer-{args.length}.csv")
elif args.task == "img_resize_shorter":
data["path"] = apply(data["path"], lambda x: resize_shorter(x, args.length, args.input_dir, args.output_dir))
output_csv = args.meta_path.replace(".csv", f"_resize-shorter-{args.length}.csv")
elif args.task == "vid_frame_extract":
points = args.points if args.points is not None else args.points_index
data = pd.DataFrame(np.repeat(data.values, 3, axis=0), columns=data.columns)
num_points = len(points)
data["point"] = np.nan
for i, point in enumerate(points):
if isinstance(point, int):
data.loc[i::num_points, "point"] = point
else:
data.loc[i::num_points, "point"] = data.loc[i::num_points, "num_frames"] * point
data["path"] = apply(
data, lambda x: extract_frames(x["path"], args.input_dir, args.output_dir, x["point"]), axis=1
)
output_csv = args.meta_path.replace(".csv", "_vid_frame_extract.csv")
elif args.task == "m2ts_to_mp4":
print(f"m2ts_to_mp4作业开始{args.output_dir}")
assert args.meta_path.endswith("_m2ts.csv"), "Input file must end with '_m2ts.csv'"
m2ts_to_mp4_partial = lambda x: m2ts_to_mp4(x, args.output_dir)
data = apply(data, m2ts_to_mp4_partial, axis=1)
data = data[data["path"] != ""]
output_csv = args.meta_path.replace("_m2ts.csv", ".csv")
elif args.task == "mkv_to_mp4":
print(f"mkv_to_mp4作业开始{args.output_dir}")
assert args.meta_path.endswith("_mkv.csv"), "Input file must end with '_mkv.csv'"
mkv_to_mp4_partial = lambda x: mkv_to_mp4(x, args.output_dir)
data = apply(data, mkv_to_mp4_partial, axis=1)
data = data[data["path"] != ""]
output_csv = args.meta_path.replace("_mkv.csv", ".csv")
elif args.task == "mp4_to_mp4":
# assert args.meta_path.endswith("meta.csv"), "Input file must end with '_mkv.csv'"
print(f"MP4复制作业开始{args.output_dir}")
mkv_to_mp4_partial = lambda x: mp4_to_mp4(x, args.output_dir)
data = apply(data, mkv_to_mp4_partial, axis=1)
data = data[data["path"] != ""]
output_csv = args.meta_path
elif args.task == "vid_crop_center":
vid_crop_center_partial = lambda x: vid_crop_center(x, args.input_dir, args.output_dir)
data = apply(data, vid_crop_center_partial, axis=1)
data = data[data["path"] != ""]
output_csv = args.meta_path.replace(".csv", "_center-crop.csv")
else:
raise ValueError
data.to_csv(output_csv, index=False)
print(f"Saved to {output_csv}")
raise SystemExit(0)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--task",
type=str,
required=True,
choices=[
"img_resize_longer",
"img_resize_shorter",
"img_rand_crop",
"vid_frame_extract",
"m2ts_to_mp4",
"mkv_to_mp4",
"mp4_to_mp4",
"vid_crop_center",
],
)
parser.add_argument("--meta_path", type=str, required=True)
parser.add_argument("--input_dir", type=str)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--length", type=int, default=1080)
parser.add_argument("--disable-parallel", action="store_true")
parser.add_argument("--num_workers", type=int, default=None)
parser.add_argument("--seed", type=int, default=42, help="seed for random")
parser.add_argument("--points", nargs="+", type=float, default=None)
parser.add_argument("--points_index", nargs="+", type=int, default=None)
args = parser.parse_args()
return args
if __name__ == "__main__":
main()

130
tools/datasets/utils.py Normal file
View File

@ -0,0 +1,130 @@
import os
import cv2
import numpy as np
from PIL import Image
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
def is_video(filename):
ext = os.path.splitext(filename)[-1].lower()
return ext in VID_EXTENSIONS
def extract_frames(
video_path,
frame_inds=None,
points=None,
backend="opencv",
return_length=False,
num_frames=None,
):
"""
Args:
video_path (str): path to video
frame_inds (List[int]): indices of frames to extract
points (List[float]): values within [0, 1); multiply #frames to get frame indices
Return:
List[PIL.Image]
"""
assert backend in ["av", "opencv", "decord"]
assert (frame_inds is None) or (points is None)
if backend == "av":
import av
container = av.open(video_path)
if num_frames is not None:
total_frames = num_frames
else:
total_frames = container.streams.video[0].frames
if points is not None:
frame_inds = [int(p * total_frames) for p in points]
frames = []
for idx in frame_inds:
if idx >= total_frames:
idx = total_frames - 1
target_timestamp = int(idx * av.time_base / container.streams.video[0].average_rate)
container.seek(target_timestamp) # return the nearest key frame, not the precise timestamp!!!
frame = next(container.decode(video=0)).to_image()
frames.append(frame)
if return_length:
return frames, total_frames
return frames
elif backend == "decord":
import decord
container = decord.VideoReader(video_path, num_threads=1)
if num_frames is not None:
total_frames = num_frames
else:
total_frames = len(container)
if points is not None:
frame_inds = [int(p * total_frames) for p in points]
frame_inds = np.array(frame_inds).astype(np.int32)
frame_inds[frame_inds >= total_frames] = total_frames - 1
frames = container.get_batch(frame_inds).asnumpy() # [N, H, W, C]
frames = [Image.fromarray(x) for x in frames]
if return_length:
return frames, total_frames
return frames
elif backend == "opencv":
cap = cv2.VideoCapture(video_path)
if num_frames is not None:
total_frames = num_frames
else:
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if points is not None:
frame_inds = [int(p * total_frames) for p in points]
frames = []
for idx in frame_inds:
if idx >= total_frames:
idx = total_frames - 1
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
# HACK: sometimes OpenCV fails to read frames, return a black frame instead
try:
ret, frame = cap.read()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
except Exception as e:
print(f"[Warning] Error reading frame {idx} from {video_path}: {e}")
# First, try to read the first frame
try:
print(f"[Warning] Try reading first frame.")
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
ret, frame = cap.read()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
# If that fails, return a black frame
except Exception as e:
print(f"[Warning] Error in reading first frame from {video_path}: {e}")
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame = Image.new("RGB", (width, height), (0, 0, 0))
# HACK: if height or width is 0, return a black frame instead
if frame.height == 0 or frame.width == 0:
height = width = 256
frame = Image.new("RGB", (width, height), (0, 0, 0))
frames.append(frame)
if return_length:
return frames, total_frames
return frames
else:
raise ValueError