feat: add opensora/datasets module and tools/datasets
- Add opensora/datasets (aspect, bucket, dataloader, datasets, parallel, pin_memory_cache, read_video, sampler, utils, video_transforms)
- Add tools/datasets pipeline scripts
- Fix .gitignore: scope /datasets to root-level only, whitelist opensora/datasets/

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
916ee2126d
commit
bdeb2870d4
|
|
@ -195,4 +195,5 @@ package.json
|
||||||
exps
|
exps
|
||||||
ckpts
|
ckpts
|
||||||
flash-attention
|
flash-attention
|
||||||
datasets
|
/datasets
|
||||||
|
!opensora/datasets/
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .datasets import TextDataset, VideoTextDataset
|
||||||
|
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
|
||||||
|
|
@ -0,0 +1,151 @@
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Aspect-ratio names, each written as "width:height". get_aspect_ratios_dict()
# iterates this list and also registers every name with its two sides swapped.
# NOTE(review): "9:16" and "5:8" are narrower than tall, unlike the other entries,
# and several numeric notes below did not match their ratio strings — confirm
# whether "16:9"-style entries were intended before changing the data.
ASPECT_RATIO_LD_LIST = [  # width:height
    "2.39:1",  # cinemascope, 2.39
    "2:1",  # rare, 2
    "16:9",  # rare; 16/9 is ~1.78 (previously annotated 1.89)
    "1.85:1",  # american widescreen, 1.85
    "9:16",  # popular; 9/16 is ~0.56 (its swapped form 16:9 is ~1.78)
    "5:8",  # rare; 5/8 is 0.625 (its swapped form 8:5 is 1.6)
    "3:2",  # rare, 1.5
    "4:3",  # classic, 1.33
    "1:1",  # square
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_ratio(name: str) -> float:
    """Return height divided by width for a ratio name written as "width:height"."""
    width_text, height_text = name.split(":")
    return float(height_text) / float(width_text)
|
||||||
|
|
||||||
|
|
||||||
|
def get_aspect_ratios_dict(
    total_pixels: int = 256 * 256, training: bool = True
) -> dict[str, tuple[int, int]]:
    """Map every aspect-ratio name to a (height, width) close to ``total_pixels``.

    Both sides are multiples of the ``AE_SPATIAL_COMPRESSION`` environment
    variable (16 when unset).

    Args:
        total_pixels: target pixel count for each shape.
        training: when True, the shape may be moved by one step in height or
            width if that brings its pixel count closer to ``total_pixels``,
            and a ratio whose shape is already registered is dropped.

    Returns:
        dict from ratio name to (height, width). Each kept name from
        ASPECT_RATIO_LD_LIST also contributes its swapped name with the
        swapped shape; swapped entries are merged last, so they win on a
        name collision.
    """
    step = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    listed_shapes = {}
    swapped_shapes = {}

    for ratio in ASPECT_RATIO_LD_LIST:
        width_part, height_part = ratio.split(":")
        width_over_height = float(width_part) / float(height_part)
        width = int(math.sqrt(total_pixels * width_over_height) // step) * step
        height = int((total_pixels / width) // step) * step

        if training:
            # The four neighbours are all taken around the initial shape.
            start_h, start_w = height, width
            best_error = abs(start_h * start_w - total_pixels)
            neighbours = (
                (start_h - step, start_w),
                (start_h + step, start_w),
                (start_h, start_w - step),
                (start_h, start_w + step),
            )
            for cand_h, cand_w in neighbours:
                error = abs(cand_h * cand_w - total_pixels)
                if error < best_error:
                    height, width, best_error = cand_h, cand_w, error

        # In training mode a shape that is already registered is not added again.
        if training and (height, width) in listed_shapes.values():
            continue

        listed_shapes[ratio] = (height, width)
        swapped_name = ":".join(reversed(ratio.split(":")))
        swapped_shapes[swapped_name] = (width, height)

    listed_shapes.update(swapped_shapes)
    return listed_shapes
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Return the pixel count (height * width) for every aspect-ratio name."""
    pixel_counts = {}
    for name, shape in aspect_ratios_dict.items():
        height, width = shape
        pixel_counts[name] = height * width
    return pixel_counts
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Return, per aspect-ratio name, height * width floor-divided twice by the compression factor.

    The factor is read from the ``AE_SPATIAL_COMPRESSION`` environment
    variable and defaults to 16.
    """
    factor = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    token_counts = {}
    for name, shape in aspect_ratios_dict.items():
        height, width = shape
        token_counts[name] = height * width // factor // factor
    return token_counts
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_pexels_from_name(resolution: str) -> int:
    """Convert a resolution name into a pixel count.

    Only the part before the first underscore is used. "<N>px" gives N * N;
    "<N>p" gives int(N * N / 9 * 16).

    Raises:
        ValueError: if the name ends with neither "px" nor "p".
    """
    resolution = resolution.split("_")[0]

    # "px" must be tested first because such names also end with "p".
    if resolution.endswith("px"):
        side = int(resolution[:-2])
        return side * side

    if resolution.endswith("p"):
        side = int(resolution[:-1])
        return int(side * side / 9 * 16)

    raise ValueError(f"Invalid resolution {resolution}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_resolution_with_aspect_ratio(
    resolution: str,
) -> tuple[int, dict[str, tuple[int, int]]]:
    """Resolve a resolution name into its pixel count and aspect-ratio shapes.

    Args:
        resolution (str): either "{name}" or "{name}_{setting}". ``name`` is a
            value such as "256px" or "360p". ``setting`` is "max" (keep only
            the shape with the most pixels) or "ar<ratio>", e.g. "ar1:1"
            (keep only that aspect ratio).

    Returns:
        tuple[int, dict[str, tuple[int, int]]]: the pixel count for ``name``
        and a dict from aspect-ratio name to (height, width).
    """
    parts = resolution.split("_")
    if len(parts) == 1:
        (resolution,) = parts
        setting = ""
    else:
        resolution, setting = parts
        assert setting == "max" or setting.startswith(
            "ar"
        ), f"Invalid setting {setting}"

    num_pexels = get_num_pexels_from_name(resolution)
    shapes = get_aspect_ratios_dict(num_pexels)

    if setting == "max":

        def _area(name):
            height, width = shapes[name][0], shapes[name][1]
            return height * width

        largest = max(shapes, key=_area)
        return num_pexels, {largest: shapes[largest]}

    if setting.startswith("ar"):
        requested = setting[2:]
        assert (
            requested in shapes
        ), f"Aspect ratio {requested} not found"
        return num_pexels, {requested: shapes[requested]}

    # No setting given: keep every aspect ratio.
    return num_pexels, shapes
|
||||||
|
|
||||||
|
|
||||||
|
def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
    """Return the key of ``ratios`` whose height/width value is nearest to ``height / width``."""
    target = height / width

    def _distance(name):
        return abs(target - get_ratio(name))

    return min(ratios.keys(), key=_distance)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_size(
    resolution: str, ar_ratio: str, training: bool = True
) -> tuple[int, int]:
    """Return the (height, width) registered for ``ar_ratio`` at the given resolution name.

    ``training`` is forwarded to get_aspect_ratios_dict(). An AssertionError
    is raised when ``ar_ratio`` is not among the generated names.
    """
    pixel_budget = get_num_pexels_from_name(resolution)
    shapes = get_aspect_ratios_dict(pixel_budget, training)
    assert ar_ratio in shapes, f"Aspect ratio {ar_ratio} not found"
    return shapes[ar_ratio]
|
||||||
|
|
||||||
|
|
||||||
|
def bucket_to_shapes(bucket_config, batch_size=None):
    """List one (batch, 3, num_frames, height, width) tuple per bucket and aspect ratio.

    Args:
        bucket_config: dict from resolution name to a dict from num_frames to
            a 2-item value whose second item is the batch size.
        batch_size: when not None, replaces the batch size from the config.

    Returns:
        list of shape tuples, in config order and then aspect-ratio order.
    """
    shapes = []
    for resolution, frame_settings in bucket_config.items():
        for num_frames, (_, configured_bs) in frame_settings.items():
            sizes = get_aspect_ratios_dict(get_num_pexels_from_name(resolution))
            effective_bs = configured_bs if batch_size is None else batch_size
            for height, width in sizes.values():
                shapes.append((effective_bs, 3, num_frames, height, width))
    return shapes
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from opensora.utils.logger import log_message
|
||||||
|
|
||||||
|
from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio
|
||||||
|
from .utils import map_target_fps
|
||||||
|
|
||||||
|
|
||||||
|
class Bucket:
    # Lookup tables built from a bucket config; get_bucket_id() assigns a sample
    # to a (resolution name, num_frames, aspect-ratio name) bucket.

    def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]):
        """
        Args:
            bucket_config (dict): A dictionary containing the bucket configuration.
                The dictionary should be in the following format:
                {
                    "bucket_name": {
                        "time": (probability, batch_size),
                        "time": (probability, batch_size),
                        ...
                    },
                    ...
                }

                Or in the following format:
                {
                    "bucket_name": {
                        "time": ((probability, next_probability), batch_size),
                        "time": ((probability, next_probability), batch_size),
                        ...
                    },
                    ...
                }
                The bucket_name should be the name of the bucket, and the time should be the number of frames in the video.
                The probability should be a float between 0 and 1, and the batch_size should be an integer.
                If the probability is a tuple, the second value should be the probability to skip to the next time.
        """

        # For each bucket name: (pixel count, {aspect-ratio name: (height, width)}).
        aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()}
        bucket_probs = OrderedDict()
        bucket_bs = OrderedDict()
        # Bucket names ordered from the largest pixel count to the smallest.
        bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True)

        for key in bucket_names:
            # Frame counts ordered from the largest to the smallest.
            bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
            bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
            bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})

        # Pixel count per bucket name.
        self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names}
        # Frame count per (bucket name, frame count); the value equals the key.
        self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names}
        # (height, width) per (bucket name, frame count, aspect-ratio name).
        self.ar_criteria = {
            k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()}
            for k1 in bucket_names
        }

        bucket_id_cnt = num_bucket = 0
        bucket_id = dict()
        for k1, v1 in bucket_probs.items():
            bucket_id[k1] = dict()
            for k2, _ in v1.items():
                # One running id per (bucket name, frame count) pair.
                bucket_id[k1][k2] = bucket_id_cnt
                bucket_id_cnt += 1
                # num_bucket additionally counts every aspect ratio of the pair.
                num_bucket += len(aspect_ratios[k1][1])

        self.bucket_probs = bucket_probs
        self.bucket_bs = bucket_bs
        self.bucket_id = bucket_id
        self.num_bucket = num_bucket

        log_message("Number of buckets: %s", num_bucket)

    def get_bucket_id(
        self,
        T: int,
        H: int,
        W: int,
        fps: float,
        path: str | None = None,
        seed: int | None = None,
        fps_max: int = 16,
    ) -> tuple[str, int, str] | None:
        """Pick a bucket for a sample with T frames of size H x W.

        Buckets are tried in the order stored in ``self.bucket_probs``
        (largest pixel count first, then largest frame count first), and random
        draws decide whether a matching bucket is accepted.

        Args:
            T: number of frames; it is floor-divided by the sampling interval
                returned by ``map_target_fps(fps, fps_max)`` before matching.
            H: sample height.
            W: sample width.
            fps: sample frame rate, passed to ``map_target_fps``.
            path: not used by this method.
            seed: seed for the random generator created on every call.
            fps_max: passed to ``map_target_fps``.

        Returns:
            (bucket name, frame count, aspect-ratio name), or None when no
            bucket accepts the sample.
        """
        # A sample may have as little as 80% of a bucket's pixel count.
        approx = 0.8
        _, sampling_interval = map_target_fps(fps, fps_max)
        T = T // sampling_interval
        resolution = H * W
        rng = np.random.default_rng(seed)

        # Reference to probabilities and criteria for faster access
        bucket_probs = self.bucket_probs
        hw_criteria = self.hw_criteria
        ar_criteria = self.ar_criteria

        # Start searching for the appropriate bucket
        for hw_id, t_criteria in bucket_probs.items():
            # if resolution is too low, skip
            if resolution < hw_criteria[hw_id] * approx:
                continue

            # if sample is an image
            if T == 1:
                if 1 in t_criteria:
                    # NOTE(review): this comparison assumes the image entry is a plain
                    # number, not the (probability, next_probability) form — confirm.
                    if rng.random() < t_criteria[1]:
                        return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1])
                # Images never fall through to the video frame counts.
                continue

            # Look for suitable t_id for video
            for t_id, prob in t_criteria.items():
                if T >= t_id and t_id != 1:
                    # if prob is a tuple, use the second value as the threshold to skip
                    # to the next t_id
                    if isinstance(prob, tuple):
                        next_hw_prob, next_t_prob = prob
                        if next_t_prob >= 1 or rng.random() <= next_t_prob:
                            continue
                    else:
                        next_hw_prob = prob
                    # Accept this bucket, or give up on this resolution and move
                    # on to the next one in the outer loop.
                    if next_hw_prob >= 1 or rng.random() <= next_hw_prob:
                        ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id])
                        return hw_id, t_id, ar_id
                    else:
                        break

        return None

    def get_thw(self, bucket_idx: tuple[str, int, str]) -> tuple[int, int, int]:
        """Return (frame count, height, width) for a 3-item bucket index."""
        assert len(bucket_idx) == 3
        T = self.t_criteria[bucket_idx[0]][bucket_idx[1]]
        H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]]
        return T, H, W

    def get_prob(self, bucket_idx: tuple[str, int]) -> float:
        """Return the configured probability entry for (bucket name, frame count).

        Only the first two items of ``bucket_idx`` are read. The entry is
        returned as configured, so it is a tuple when the config used the
        (probability, next_probability) form.
        """
        return self.bucket_probs[bucket_idx[0]][bucket_idx[1]]

    def get_batch_size(self, bucket_idx: tuple[str, int]) -> int:
        """Return the configured batch size for (bucket name, frame count).

        Only the first two items of ``bucket_idx`` are read.
        """
        return self.bucket_bs[bucket_idx[0]][bucket_idx[1]]

    def __len__(self) -> int:
        """Return the bucket count accumulated in ``__init__`` (aspect ratios included)."""
        return self.num_bucket
|
||||||
|
|
@ -0,0 +1,402 @@
|
||||||
|
import collections
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as multiprocessing
|
||||||
|
from torch._utils import ExceptionWrapper
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
from torch.utils.data import DataLoader, _utils
|
||||||
|
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
||||||
|
from torch.utils.data.dataloader import (
|
||||||
|
IterDataPipe,
|
||||||
|
MapDataPipe,
|
||||||
|
_BaseDataLoaderIter,
|
||||||
|
_MultiProcessingDataLoaderIter,
|
||||||
|
_sharding_worker_init_fn,
|
||||||
|
_SingleProcessDataLoaderIter,
|
||||||
|
)
|
||||||
|
|
||||||
|
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||||
|
from opensora.registry import DATASETS, build_module
|
||||||
|
from opensora.utils.config import parse_configs
|
||||||
|
from opensora.utils.logger import create_logger
|
||||||
|
from opensora.utils.misc import format_duration
|
||||||
|
from opensora.utils.train import setup_device
|
||||||
|
|
||||||
|
from .datasets import TextDataset, VideoTextDataset
|
||||||
|
from .pin_memory_cache import PinMemoryCache
|
||||||
|
from .sampler import DistributedSampler, VariableVideoBatchSampler
|
||||||
|
|
||||||
|
|
||||||
|
def _pin_memory_loop(
|
||||||
|
in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
|
||||||
|
):
|
||||||
|
# This setting is thread local, and prevents the copy in pin_memory from
|
||||||
|
# consuming all CPU cores.
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
torch.cuda.set_device(device_id)
|
||||||
|
elif device == "xpu":
|
||||||
|
torch.xpu.set_device(device_id) # type: ignore[attr-defined]
|
||||||
|
elif device == torch._C._get_privateuse1_backend_name():
|
||||||
|
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
|
||||||
|
custom_device_mod.set_device(device_id)
|
||||||
|
|
||||||
|
def do_one_step():
|
||||||
|
try:
|
||||||
|
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
||||||
|
except queue.Empty:
|
||||||
|
return
|
||||||
|
idx, data = r
|
||||||
|
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
|
||||||
|
try:
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
if pin_memory_key in data:
|
||||||
|
val = data[pin_memory_key]
|
||||||
|
pin_memory_value = pin_memory_cache.get(val)
|
||||||
|
pin_memory_value.copy_(val)
|
||||||
|
data[pin_memory_key] = pin_memory_value
|
||||||
|
except Exception:
|
||||||
|
data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
|
||||||
|
r = (idx, data)
|
||||||
|
while not done_event.is_set():
|
||||||
|
try:
|
||||||
|
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
|
||||||
|
break
|
||||||
|
except queue.Full:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
|
||||||
|
# logic of this function.
|
||||||
|
while not done_event.is_set():
|
||||||
|
# Make sure that we don't preserve any object from one iteration
|
||||||
|
# to the next
|
||||||
|
do_one_step()
|
||||||
|
|
||||||
|
|
||||||
|
class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
    # Multi-worker iterator whose pin-memory thread runs this module's
    # _pin_memory_loop with a PinMemoryCache instead of the parent's own setup.

    # Batch key whose value the pin-memory thread copies into a cached buffer.
    pin_memory_key: str = "video"

    def __init__(self, loader):
        """Start the worker processes and, when pin_memory is on, the caching pin-memory thread.

        The parent class's ``__init__`` is deliberately not called; only
        ``_BaseDataLoaderIter.__init__`` runs, and the worker start-up is done
        here so that the custom ``_pin_memory_loop`` can be used.

        Args:
            loader: the DataLoader this iterator serves; its prefetch_factor,
                multiprocessing_context and worker_init_fn are read here.
        """
        _BaseDataLoaderIter.__init__(self, loader)
        self.pin_memory_cache = PinMemoryCache()

        self._prefetch_factor = loader.prefetch_factor

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Additional worker init function will take care of sharding in MP and Distributed
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            self._worker_init_fn = functools.partial(
                _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
            )

        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            # NOTE(review): "above" presumably refers to the shutdown-logic note in
            # PyTorch's own dataloader.py, which is not part of this file — confirm.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(
                    self._dataset_kind,
                    self._dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self._auto_collation,
                    self._collate_fn,
                    self._drop_last,
                    self._base_seed,
                    self._worker_init_fn,
                    i,
                    self._num_workers,
                    self._persistent_workers,
                    self._shared_seed,
                ),
            )
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
            if self._pin_memory_device == "xpu":
                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
                custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
                current_device = custom_device_mod.current_device()
            else:
                current_device = torch.cuda.current_device()  # choose cuda for default
            # The thread reads worker results and writes to self._data_queue; the
            # cache and key are the two extra arguments this module's loop takes.
            pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(
                    self._worker_result_queue,
                    self._data_queue,
                    current_device,
                    self._pin_memory_thread_done_event,
                    self._pin_memory_device,
                    self.pin_memory_cache,
                    self.pin_memory_key,
                ),
            )
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            # Without pinning, batches are consumed straight from the worker result queue.
            self._data_queue = self._worker_result_queue  # type: ignore[assignment]

        # In some rare cases, persistent workers (daemonic processes)
        # would be terminated before `__del__` of iterator is invoked
        # when main process exits
        # It would cause failure when pin_memory_thread tries to read
        # corrupted data from worker_result_queue
        # atexit is used to shutdown thread and child processes in the
        # right sequence before main process exits
        if self._persistent_workers and self._pin_memory:
            import atexit

            for w in self._workers:
                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def remove_cache(self, output_tensor: torch.Tensor) -> None:
        """Forward ``output_tensor`` to ``self.pin_memory_cache.remove``."""
        self.pin_memory_cache.remove(output_tensor)

    def get_cache_info(self) -> str:
        """Return ``str(self.pin_memory_cache)``."""
        return str(self.pin_memory_cache)
|
||||||
|
|
||||||
|
|
||||||
|
class DataloaderForVideo(DataLoader):
    """DataLoader that uses _MultiProcessingDataLoaderIterForVideo when workers are enabled."""

    def _get_iterator(self) -> "_BaseDataLoaderIter":
        """Return the single-process iterator for 0 workers, otherwise the video iterator."""
        if self.num_workers != 0:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIterForVideo(self)
        return _SingleProcessDataLoaderIter(self)
|
||||||
|
|
||||||
|
|
||||||
|
# Deterministic dataloader
|
||||||
|
def get_seed_worker(seed):
    """Build a ``worker_init_fn`` that seeds NumPy, torch and ``random`` with ``seed``.

    The returned function ignores its ``worker_id`` argument, so every worker
    receives the same seed. It does nothing when ``seed`` is None.
    """

    def seed_worker(worker_id):
        if seed is None:
            return
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    return seed_worker
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: ProcessGroup | None = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    prefetch_factor=None,
    cache_pin_memory=False,
    num_groups=1,
    **kwargs,
):
    """Build a dataloader, and its sampler, for a video-text or text dataset.

    Args:
        dataset: a ``VideoTextDataset`` or ``TextDataset`` instance.
        batch_size: used only for a ``TextDataset`` without a process group.
        shuffle, seed, drop_last: forwarded to the sampler, or to the
            DataLoader when no sampler is created.
        pin_memory, num_workers, prefetch_factor: forwarded to the DataLoader.
        process_group: supplies ``size()`` and ``rank()`` for the samplers.
            It is dereferenced unconditionally for a ``VideoTextDataset``.
        bucket_config, num_bucket_build_workers, num_groups: forwarded to
            ``VariableVideoBatchSampler``.
        cache_pin_memory: selects ``DataloaderForVideo`` instead of
            ``DataLoader`` for a ``VideoTextDataset``.
        **kwargs: extra keyword arguments for the DataLoader.

    Returns:
        ``(dataloader, sampler)``; the sampler is None for a ``TextDataset``
        without a process group.

    Raises:
        ValueError: if ``dataset`` is neither supported type.
    """
    loader_kwargs = kwargs.copy()
    init_fn = get_seed_worker(seed)

    if isinstance(dataset, VideoTextDataset):
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
            num_groups=num_groups,
        )
        loader_cls = DataloaderForVideo if cache_pin_memory else DataLoader
        video_loader = loader_cls(
            dataset,
            batch_sampler=batch_sampler,
            worker_init_fn=init_fn,
            pin_memory=pin_memory,
            num_workers=num_workers,
            collate_fn=collate_fn_default,
            prefetch_factor=prefetch_factor,
            **loader_kwargs,
        )
        return video_loader, batch_sampler

    if not isinstance(dataset, TextDataset):
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")

    if process_group is None:
        local_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            worker_init_fn=init_fn,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
            **loader_kwargs,
        )
        return local_loader, None

    sampler = DistributedSampler(
        dataset,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
    )
    # NOTE(review): batch_size is not forwarded in this branch, so the
    # DataLoader default applies — confirm that this is intended.
    distributed_loader = DataLoader(
        dataset,
        sampler=sampler,
        worker_init_fn=init_fn,
        pin_memory=pin_memory,
        num_workers=num_workers,
        collate_fn=collate_fn_default,
        prefetch_factor=prefetch_factor,
        **loader_kwargs,
    )
    return distributed_loader, sampler
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_default(batch):
    """Collate samples with ``default_collate`` after dropping None entries.

    When the first sample has an int "mask", the "mask" and "text" entries are
    removed from every sample before collation (the sample dicts are mutated).
    The masks are returned as a list and the texts are concatenated along
    dim 1.

    Raises:
        AssertionError: if no sample remains after dropping None entries.
    """
    samples = [sample for sample in batch if sample is not None]
    assert len(samples) > 0, "batch is empty"

    # HACK: for loading text features
    first = samples[0]
    has_int_mask = "mask" in first and isinstance(first["mask"], int)
    if has_int_mask:
        masks = [sample.pop("mask") for sample in samples]
        texts = torch.cat([sample.pop("text") for sample in samples], dim=1)

    collated = torch.utils.data.default_collate(samples)

    if has_int_mask:
        collated["mask"] = masks
        collated["text"] = texts
    return collated
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn_batch(batch):
    """
    Used only with BatchDistributedSampler

    Drops None entries, applies ``default_collate`` and then squeezes dim 0 of
    every resulting tensor, i.e. the dimension added by the stacking inside
    ``default_collate``.
    """
    kept = [sample for sample in batch if sample is not None]
    collated = torch.utils.data.default_collate(kept)

    if isinstance(collated, collections.abc.Mapping):
        # Updated in place: only existing keys are reassigned.
        for key in collated:
            value = collated[key]
            if isinstance(value, torch.Tensor):
                collated[key] = value.squeeze(0)
        return collated

    if isinstance(collated, collections.abc.Sequence):
        squeezed = []
        for entry in collated:
            squeezed.append(entry.squeeze(0) if isinstance(entry, torch.Tensor) else entry)
        return squeezed

    if isinstance(collated, torch.Tensor):
        return collated.squeeze(0)

    raise TypeError
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Builds the dataset and dataloader from a config and logs an estimate of
    # the steps per epoch and the training time.
    #
    # Environment variables read below:
    #   NUM_MACHINES: number of machines (default 28); the GPU count is taken as 8 per machine
    #   TIME_PER_STEP: time per step in seconds (default 20)

    # Example usage:
    # torchrun --nproc_per_node 1 -m opensora.datasets.dataloader configs/diffusion/train/video_cond.py
    cfg = parse_configs()
    setup_device()
    logger = create_logger()

    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)

    # == build dataloader ==
    common_args = {
        "dataset": dataset,
        "batch_size": cfg.get("batch_size", None),
        "num_workers": cfg.get("num_workers", 4),
        "seed": cfg.get("seed", 1024),
        "shuffle": True,
        "drop_last": True,
        "pin_memory": True,
        "process_group": get_data_parallel_group(),
        "prefetch_factor": cfg.get("prefetch_factor", None),
    }
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **common_args,
    )

    # == report the estimate ==
    total_steps = len(dataloader)
    gpu_count = int(os.environ.get("NUM_MACHINES", 28)) * 8
    steps_per_gpu = total_steps // gpu_count
    logger.info("Number of GPUs: %d", gpu_count)
    logger.info("Number of steps per epoch: %d", steps_per_gpu)

    seconds_per_step = int(os.environ.get("TIME_PER_STEP", 20))
    training_seconds = steps_per_gpu * seconds_per_step
    logger.info("Time per step: %s", format_duration(seconds_per_step))
    logger.info("Time for training: %s", format_duration(training_seconds))
|
||||||
|
|
@ -0,0 +1,315 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from PIL import ImageFile
|
||||||
|
from torchvision.datasets.folder import pil_loader
|
||||||
|
|
||||||
|
from opensora.registry import DATASETS
|
||||||
|
|
||||||
|
from .read_video import read_video
|
||||||
|
from .utils import get_transforms_image, get_transforms_video, is_img, map_target_fps, read_file, temporal_random_crop
|
||||||
|
|
||||||
|
# Let PIL load image files whose data is truncated instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# NOTE(review): not referenced in the visible part of this file — confirm where it is consumed.
VALID_KEYS = ("neg", "path")
# Divisor used by EfficientParquet: rows_per_shard = ceil(total_rows / K).
K = 10000
|
||||||
|
|
||||||
|
|
||||||
|
class Iloc:
    """Indexer returned by ``EfficientParquet.iloc``; ``obj[index]`` yields an ``Item`` for that row."""

    def __init__(self, data, sharded_folder, sharded_folders, rows_per_shard):
        self.data, self.sharded_folder = data, sharded_folder
        self.sharded_folders, self.rows_per_shard = sharded_folders, rows_per_shard

    def __getitem__(self, index):
        """Return the ``Item`` view of row ``index``."""
        return Item(
            index=index,
            data=self.data,
            sharded_folder=self.sharded_folder,
            sharded_folders=self.sharded_folders,
            rows_per_shard=self.rows_per_shard,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class Item:
    """View of one dataset row.

    Columns present in ``data`` are served from it directly. Any other key is
    read from the parquet shard that holds the row, after checking that the
    shard's "path" value matches the one in ``data``.
    """

    def __init__(self, index, data, sharded_folder, sharded_folders, rows_per_shard):
        self.index = index
        self.data = data
        self.sharded_folder = sharded_folder
        self.sharded_folders = sharded_folders
        self.rows_per_shard = rows_per_shard

    def _shard_location(self):
        """Return (shard file path, row position inside that shard) for this row."""
        shard_number, row_in_shard = divmod(self.index, self.rows_per_shard)
        shard_file = os.path.join(self.sharded_folder, self.sharded_folders[shard_number])
        return shard_file, row_in_shard

    def _read_checked_shard(self, shard_file, row_in_shard):
        """Load the shard and assert that its "path" entry matches this row."""
        shard = pd.read_parquet(shard_file, engine="fastparquet")
        shard_path = shard["path"].iloc[row_in_shard]
        assert shard_path == self.data["path"].iloc[self.index]
        return shard

    def __getitem__(self, key):
        """Return the value of ``key`` for this row.

        A failure while reading or checking the shard is printed and re-raised.
        """
        if key in self.data.columns:
            return self.data[key].iloc[self.index]

        shard_file, row_in_shard = self._shard_location()
        try:
            shard = self._read_checked_shard(shard_file, row_in_shard)
        except Exception as e:
            print(f"Error reading {shard_file}: {e}")
            raise
        return shard[key].iloc[row_in_shard]

    def to_dict(self):
        """Return the row as a dict: the ``data`` columns plus the shard columns.

        If the shard cannot be read or checked, the error is printed and
        "text" is set to an empty string instead.
        """
        merged = dict(self.data.iloc[self.index].to_dict())
        shard_file, row_in_shard = self._shard_location()
        try:
            shard = self._read_checked_shard(shard_file, row_in_shard)
            merged.update(shard.iloc[row_in_shard].to_dict())
        except Exception as e:
            print(f"Error reading {shard_file}: {e}")
            merged["text"] = ""
        return merged
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientParquet:
    """Pairs an in-memory frame with a folder of shard files that hold further columns.

    ``rows_per_shard`` is ``ceil(len(df) / K)``, and the entries of
    ``sharded_folder`` are kept in sorted order. Row access goes through
    ``iloc``.
    """

    def __init__(self, df, sharded_folder):
        row_count = len(df)
        self.data = df
        self.total_rows = row_count
        self.rows_per_shard = (row_count + K - 1) // K
        self.sharded_folder = sharded_folder
        assert os.path.exists(sharded_folder), f"Sharded folder {sharded_folder} does not exist."
        self.sharded_folders = sorted(os.listdir(sharded_folder))

    def __len__(self):
        """Return the number of rows in the frame."""
        return self.total_rows

    @property
    def iloc(self):
        """Return a new ``Iloc`` indexer over this object's rows."""
        return Iloc(
            self.data,
            self.sharded_folder,
            self.sharded_folders,
            self.rows_per_shard,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module("text")
class TextDataset(torch.utils.data.Dataset):
    """
    Dataset for text data

    Each row of the table read from ``data_path`` becomes a dict holding the
    sampling interval derived from the row's fps, the caption (optionally
    suffixed with FPS / motion-score hints), the tokenized caption when
    ``tokenize_fn`` is given, and every remaining column listed in
    ``VALID_KEYS``.
    """

    def __init__(
        self,
        data_path: str = None,
        tokenize_fn: callable = None,
        fps_max: int = 16,
        vmaf: bool = False,
        memory_efficient: bool = False,
        **kwargs,
    ):
        """
        Args:
            data_path: path of the metadata file passed to ``read_file``.
            tokenize_fn: optional callable applied to the final caption; it
                must return a mapping whose values support ``squeeze(0)``.
            fps_max: fps cap forwarded to ``map_target_fps``; ``None`` is
                stored as 999999999.
            vmaf: when True, append the rounded ``score_vmafmotion`` column
                to the caption if it is present and not NaN.
            memory_efficient: forwarded to ``read_file`` and checked by
                ``to_efficient``.
            **kwargs: accepted and ignored by this class.
        """
        self.data_path = data_path
        self.data = read_file(data_path, memory_efficient=memory_efficient)
        self.memory_efficient = memory_efficient
        self.tokenize_fn = tokenize_fn
        self.vmaf = vmaf
        self.fps_max = fps_max if fps_max is not None else 999999999

    def to_efficient(self):
        """Swap ``self.data`` for an ``EfficientParquet`` when memory_efficient is set.

        The shard folder is the data file's path with everything after the
        first dot of the *file name* removed. Only the file name is split, so
        dots in directory components (``./data/x.parquet``, ``/v1.2/x.csv``)
        no longer truncate the path. The original table stays reachable as
        ``self._data``.
        """
        if self.memory_efficient:
            folder, filename = os.path.split(self.data_path)
            addition_data_path = os.path.join(folder, filename.split(".")[0])
            self._data = self.data
            self.data = EfficientParquet(self._data, addition_data_path)

    def getitem(self, index: int) -> dict:
        """Build the sample dict for row ``index``."""
        ret = dict()
        sample = self.data.iloc[index].to_dict()
        sample_fps = sample.get("fps", np.nan)
        new_fps, sampling_interval = map_target_fps(sample_fps, self.fps_max)
        ret.update({"sampling_interval": sampling_interval})

        if "text" in sample:
            ret["text"] = sample.pop("text")
            postfixs = []
            # the FPS hint is only added when an fps cap below 999 is in effect
            if new_fps != 0 and self.fps_max < 999:
                postfixs.append(f"{new_fps} FPS")
            if self.vmaf and "score_vmafmotion" in sample and not np.isnan(sample["score_vmafmotion"]):
                postfixs.append(f"{int(sample['score_vmafmotion'] + 0.5)} motion score")
            postfix = " " + ", ".join(postfixs) + "." if postfixs else ""
            ret["text"] = ret["text"] + postfix
            if self.tokenize_fn is not None:
                ret.update({k: v.squeeze(0) for k, v in self.tokenize_fn(ret["text"]).items()})

        if "ref" in sample:  # i2v & v2v reference
            ret["ref"] = sample.pop("ref")

        # name of the generated sample
        if "name" in sample:  # sample name (`dataset_idx`)
            ret["name"] = sample.pop("name")
        else:
            ret["index"] = index  # use index for name

        # pass through the remaining whitelisted columns
        valid_sample = {k: v for k, v in sample.items() if k in VALID_KEYS}
        ret.update(valid_sample)
        return ret

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data)
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module("video_text")
class VideoTextDataset(TextDataset):
    """Text dataset that additionally loads the image or video of each row.

    Items are addressed with a string ``"index-num_frames-height-width"`` so
    the sampler can pass the target shape along with the row index.
    """

    def __init__(
        self,
        transform_name: str = None,
        bucket_class: str = "Bucket",
        rand_sample_interval: int = None,  # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
        **kwargs,
    ):
        """
        Args:
            transform_name: name passed to ``get_transforms_image`` /
                ``get_transforms_video``.
            bucket_class: stored as-is on the instance.
            rand_sample_interval: when set, the sampling interval is drawn at
                random per video instead of using the one derived from fps.
            **kwargs: forwarded to ``TextDataset``.
        """
        super().__init__(**kwargs)
        self.transform_name = transform_name
        self.bucket_class = bucket_class
        self.rand_sample_interval = rand_sample_interval

    def get_image(self, index: int, height: int, width: int) -> dict:
        """Load the row's image and return it as a one-frame video."""
        sample = self.data.iloc[index]
        path = sample["path"]
        # loading
        image = pil_loader(path)

        # transform
        transform = get_transforms_image(self.transform_name, (height, width))
        image = transform(image)

        # CHW -> CTHW
        video = image.unsqueeze(1)

        return {"video": video}

    def get_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
        """Load the row's video, crop ``num_frames`` frames in time and transform them."""
        sample = self.data.iloc[index]
        path = sample["path"]

        # loading
        vframes, vinfo = read_video(path, backend="av")

        if self.rand_sample_interval is not None:
            # randomly sample from 1 - self.rand_sample_interval
            # Clamp to at least 1: for clips shorter than num_frames the
            # quotient is 0 and random.randint(1, 0) raises an opaque
            # "empty range" ValueError.
            video_allowed_max = max(1, min(len(vframes) // num_frames, self.rand_sample_interval))
            sampling_interval = random.randint(1, video_allowed_max)

        # Sampling video frames
        video = temporal_random_crop(vframes, num_frames, sampling_interval)

        # copy the crop so the full decoded clip can be released
        video = video.clone()
        del vframes

        # transform
        transform = get_transforms_video(self.transform_name, (height, width))
        video = transform(video)  # T C H W
        video = video.permute(1, 0, 2, 3)

        ret = {"video": video}

        return ret

    def get_image_or_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
        """Dispatch to ``get_image`` or ``get_video`` based on the row's path."""
        sample = self.data.iloc[index]
        path = sample["path"]

        if is_img(path):
            return self.get_image(index, height, width)
        return self.get_video(index, num_frames, height, width, sampling_interval)

    def getitem(self, index: str) -> dict:
        """Return text fields plus ``"video"``; ``None`` if the media cannot be loaded."""
        # a hack to pass in the (time, height, width) info from sampler
        index, num_frames, height, width = [int(val) for val in index.split("-")]
        ret = dict()
        ret.update(super().getitem(index))
        try:
            ret.update(self.get_image_or_video(index, num_frames, height, width, ret["sampling_interval"]))
        except Exception as e:
            path = self.data.iloc[index]["path"]
            print(f"video {path}: {e}")
            return None
        return ret

    def __getitem__(self, index):
        return self.getitem(index)
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module("cached_video_text")
class CachedVideoTextDataset(VideoTextDataset):
    """Video-text dataset that can also return precomputed tensors loaded from disk.

    The rows are expected to carry ``latents_path``, ``text_t5_path`` and
    ``text_clip_path`` columns pointing at files readable by ``torch.load``.
    """

    def __init__(
        self,
        transform_name: str = None,
        bucket_class: str = "Bucket",
        rand_sample_interval: int = None,  # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
        cached_video: bool = False,
        cached_text: bool = False,
        return_latents_path: bool = False,
        load_original_video: bool = True,
        **kwargs,
    ):
        """
        Args:
            transform_name, bucket_class, rand_sample_interval: as in
                ``VideoTextDataset``.
            cached_video: load ``latents_path`` into ``"video_latents"``.
            cached_text: load ``text_t5_path`` / ``text_clip_path`` into
                ``"text_t5"`` / ``"text_clip"``.
            return_latents_path: also return the three paths themselves.
            load_original_video: when True, include everything the parent
                ``getitem`` returns.
            **kwargs: forwarded to ``VideoTextDataset``.
        """
        super().__init__(
            transform_name=transform_name,
            bucket_class=bucket_class,
            rand_sample_interval=rand_sample_interval,
            **kwargs,
        )
        self.cached_video = cached_video
        self.cached_text = cached_text
        self.return_latents_path = return_latents_path
        self.load_original_video = load_original_video

    def get_latents(self, path):
        """Load a tensor file on CPU; on failure print the error and return a (1, 1, 1, 1) zero tensor."""
        try:
            latents = torch.load(path, map_location=torch.device("cpu"))
        except Exception as e:
            print(f"Error loading latents from {path}: {e}")
            # plain zeros: same placeholder as before without drawing from the global RNG
            return torch.zeros(1, 1, 1, 1)
        return latents

    def get_conditioning_latents(self, index: int) -> dict:
        """Collect the cached tensors and/or their paths for row ``index``."""
        sample = self.data.iloc[index]
        latents_path = sample["latents_path"]
        text_t5_path = sample["text_t5_path"]
        text_clip_path = sample["text_clip_path"]
        res = dict()
        if self.cached_video:
            latents = self.get_latents(latents_path)
            res["video_latents"] = latents
        if self.cached_text:
            text_t5 = self.get_latents(text_t5_path)
            text_clip = self.get_latents(text_clip_path)
            res["text_t5"] = text_t5
            res["text_clip"] = text_clip
        if self.return_latents_path:
            res["latents_path"] = latents_path
            res["text_t5_path"] = text_t5_path
            res["text_clip_path"] = text_clip_path
        return res

    def getitem(self, index: str) -> dict:
        """Return the sample for ``"index-num_frames-height-width"``, or ``None`` on failure."""
        # a hack to pass in the (time, height, width) info from sampler
        real_index, _num_frames, _height, _width = [int(val) for val in index.split("-")]
        ret = dict()
        if self.load_original_video:
            original = super().getitem(index)
            # The parent returns None when the media cannot be loaded; pass
            # that on instead of crashing in dict.update(None).
            if original is None:
                return None
            ret.update(original)
        try:
            ret.update(self.get_conditioning_latents(real_index))
        except Exception as e:
            path = self.data.iloc[real_index]["path"]
            print(f"video {path}: {e}")
            return None
        return ret

    def __getitem__(self, index):
        return self.getitem(index)
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
import multiprocessing
|
||||||
|
from itertools import count
|
||||||
|
from multiprocessing.managers import SyncManager
|
||||||
|
from typing import Any, Callable, Dict, Tuple, Type, cast
|
||||||
|
|
||||||
|
import dill
|
||||||
|
import pandarallel
|
||||||
|
import pandas as pd
|
||||||
|
from pandarallel.data_types import DataType
|
||||||
|
from pandarallel.progress_bars import ProgressBarsType, get_progress_bars, progress_wrapper
|
||||||
|
from pandarallel.utils import WorkerStatus
|
||||||
|
|
||||||
|
# Multiprocessing context ("fork" start method) used below to create the manager and the worker pool.
CONTEXT = multiprocessing.get_context("fork")
|
||||||
|
# Module-level list holding the data chunks of the call in progress; each worker reads its chunk as TMP[worker_index].
TMP = []
|
||||||
|
|
||||||
|
|
||||||
|
class WrapWorkFunctionForPipe:
    """Callable run inside a worker: executes the work function on that worker's chunk.

    The chunk is looked up in the module-level ``TMP`` list by worker index,
    and the final status is reported on the master/worker queue.
    """

    def __init__(
        self,
        work_function: Callable[[Any, Callable, tuple, Dict[str, Any], Dict[str, Any]], Any],
    ) -> None:
        self.work_function = work_function

    def __call__(
        self,
        progress_bars_type: ProgressBarsType,
        worker_index: int,
        master_workers_queue: multiprocessing.Queue,
        dilled_user_defined_function: bytes,
        user_defined_function_args: tuple,
        user_defined_function_kwargs: Dict[str, Any],
        extra: Dict[str, Any],
    ) -> Any:
        """Run the work function on ``TMP[worker_index]`` and return its result.

        A ``(worker_index, status, None)`` message is put on the queue with
        ``WorkerStatus.Success`` on completion, or ``WorkerStatus.Error``
        before the exception is re-raised.
        """
        try:
            chunk = TMP[worker_index]
            chunk_size = len(chunk)
            plain_function: Callable = dill.loads(dilled_user_defined_function)
            tracked_function = progress_wrapper(plain_function, master_workers_queue, worker_index, chunk_size)

            # progress is reported from inside the user function only for these two modes
            in_function_modes = (
                ProgressBarsType.InUserDefinedFunction,
                ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns,
            )
            if progress_bars_type in in_function_modes:
                chosen_function = tracked_function
            else:
                chosen_function = plain_function

            results = self.work_function(
                chunk,
                chosen_function,
                user_defined_function_args,
                user_defined_function_kwargs,
                extra,
            )
            master_workers_queue.put((worker_index, WorkerStatus.Success, None))
            return results
        except BaseException:
            # report any failure, including interrupts, then propagate it
            master_workers_queue.put((worker_index, WorkerStatus.Error, None))
            raise
|
||||||
|
|
||||||
|
|
||||||
|
def parallelize_with_pipe(
    nb_requested_workers: int,
    data_type: Type[DataType],
    progress_bars_type: ProgressBarsType,
):
    """Build the function that runs a user function over ``data`` in a worker pool.

    Args:
        nb_requested_workers: worker count handed to ``data_type.get_chunks``;
            the pool is sized to the number of chunks actually produced.
        data_type: provides ``get_chunks``, ``work``, ``get_work_extra``,
            ``get_reduce_extra`` and ``reduce``.
        progress_bars_type: selects whether and how progress is displayed.

    Returns:
        ``closure(data, user_defined_function, *args, **kwargs)``, which
        returns ``data_type.reduce`` applied to the workers' results.
    """

    def closure(
        data: Any,
        user_defined_function: Callable,
        *user_defined_function_args: tuple,
        **user_defined_function_kwargs: Dict[str, Any],
    ):
        wrapped_work_function = WrapWorkFunctionForPipe(data_type.work)
        dilled_user_defined_function = dill.dumps(user_defined_function)
        manager: SyncManager = CONTEXT.Manager()
        master_workers_queue = manager.Queue()

        chunks = list(
            data_type.get_chunks(
                nb_requested_workers,
                data,
                user_defined_function_kwargs=user_defined_function_kwargs,
            )
        )
        # Workers come from a "fork" context and read their chunk as
        # TMP[worker_index], so the chunks are published in the module-level
        # list before the pool is created.
        TMP.extend(chunks)

        try:
            nb_workers = len(chunks)

            multiplicator_factor = (
                len(cast(pd.DataFrame, data).columns)
                if progress_bars_type == ProgressBarsType.InUserDefinedFunctionMultiplyByNumberOfColumns
                else 1
            )

            progresses_length = [len(chunk_) * multiplicator_factor for chunk_ in chunks]

            work_extra = data_type.get_work_extra(data)
            reduce_extra = data_type.get_reduce_extra(data, user_defined_function_kwargs)

            show_progress_bars = progress_bars_type != ProgressBarsType.No

            progress_bars = get_progress_bars(progresses_length, show_progress_bars)
            progresses = [0] * nb_workers
            workers_status = [WorkerStatus.Running] * nb_workers

            work_args_list = [
                (
                    progress_bars_type,
                    worker_index,
                    master_workers_queue,
                    dilled_user_defined_function,
                    user_defined_function_args,
                    user_defined_function_kwargs,
                    {
                        **work_extra,
                        **{
                            "master_workers_queue": master_workers_queue,
                            "show_progress_bars": show_progress_bars,
                            "worker_index": worker_index,
                        },
                    },
                )
                for worker_index in range(nb_workers)
            ]

            pool = CONTEXT.Pool(nb_workers)
            results_promise = pool.starmap_async(wrapped_work_function, work_args_list)
            pool.close()

            generation = count()

            # consume status messages until no worker is still running
            while any((worker_status == WorkerStatus.Running for worker_status in workers_status)):
                message: Tuple[int, WorkerStatus, Any] = master_workers_queue.get()
                worker_index, worker_status, payload = message
                workers_status[worker_index] = worker_status

                if worker_status == WorkerStatus.Success:
                    progresses[worker_index] = progresses_length[worker_index]
                    progress_bars.update(progresses)
                elif worker_status == WorkerStatus.Running:
                    progress = cast(int, payload)
                    progresses[worker_index] = progress

                    # refresh the display once every nb_workers progress messages
                    if next(generation) % nb_workers == 0:
                        progress_bars.update(progresses)
                elif worker_status == WorkerStatus.Error:
                    progress_bars.set_error(worker_index)

            results = results_promise.get()
        finally:
            # Always drop the published chunks, including when a worker
            # failed; stale entries would otherwise be picked up by index on
            # the next call.
            TMP.clear()

        return data_type.reduce(results, reduce_extra)

    return closure
|
||||||
|
|
||||||
|
|
||||||
|
# Replace pandarallel's pipe-based worker wrapper with the version defined in this module.
pandarallel.core.WrapWorkFunctionForPipe = WrapWorkFunctionForPipe
|
||||||
|
# Replace pandarallel's pipe-based parallelization entry point with the version defined in this module.
pandarallel.core.parallelize_with_pipe = parallelize_with_pipe
|
||||||
|
# Rebind the module-level name from the package to its `pandarallel` attribute, so importers of this module get the patched object.
pandarallel = pandarallel.pandarallel
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
import threading
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class PinMemoryCache:
    """Pool of reusable pinned CPU buffers.

    ``get`` hands out a view of a free buffer (allocating a new one when none
    fits) and ``remove`` marks the buffer behind a view as free again. Views
    are tracked by ``id()``, so a view must be passed to ``remove`` exactly as
    it was returned by ``get``.
    """

    # when set, every buffer is allocated with this dtype instead of the requested tensor's dtype
    force_dtype: Optional[torch.dtype] = None
    # lower bound on the element count of newly allocated buffers
    min_cache_numel: int = 0
    # element counts of buffers allocated up front (only used when force_dtype is set)
    pre_alloc_numels: List[int] = []

    def __init__(self):
        self.cache: Dict[int, torch.Tensor] = {}
        self.output_to_cache: Dict[int, int] = {}
        self.cache_to_output: Dict[int, int] = {}
        self.lock = threading.Lock()
        self.total_cnt = 0
        self.hit_cnt = 0

        if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None:
            for n in self.pre_alloc_numels:
                cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
                with self.lock:
                    self.cache[id(cache_tensor)] = cache_tensor

    def get(self, tensor: torch.Tensor) -> torch.Tensor:
        """Receive a cpu tensor and return the corresponding pinned tensor. Note that this only manage memory allocation, doesn't copy content.

        Args:
            tensor (torch.Tensor): The tensor to be pinned.

        Returns:
            torch.Tensor: The pinned tensor.
        """
        dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype
        numel = tensor.numel()
        with self.lock:
            # counted under the lock so concurrent callers cannot lose updates
            self.total_cnt += 1
            # find free cache
            for cache_id, cache_tensor in self.cache.items():
                if cache_id in self.cache_to_output:
                    continue
                # A buffer is only reusable if it is large enough AND has the
                # dtype a fresh allocation would get; otherwise the caller
                # would receive a pinned tensor of an unexpected dtype.
                if cache_tensor.dtype != dtype or cache_tensor.numel() < numel:
                    continue
                target_cache_tensor = cache_tensor[:numel].view(tensor.shape)
                out_id = id(target_cache_tensor)
                self.output_to_cache[out_id] = cache_id
                self.cache_to_output[cache_id] = out_id
                self.hit_cnt += 1
                return target_cache_tensor
        # no free cache, create a new one
        cache_numel = max(numel, self.min_cache_numel)
        cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True)
        target_cache_tensor = cache_tensor[:numel].view(tensor.shape)
        out_id = id(target_cache_tensor)
        with self.lock:
            self.cache[id(cache_tensor)] = cache_tensor
            self.output_to_cache[out_id] = id(cache_tensor)
            self.cache_to_output[id(cache_tensor)] = out_id
        return target_cache_tensor

    def remove(self, output_tensor: torch.Tensor) -> None:
        """Release corresponding cache tensor.

        Args:
            output_tensor (torch.Tensor): The tensor to be released.

        Raises:
            ValueError: If the tensor was not handed out by ``get``.
        """
        out_id = id(output_tensor)
        with self.lock:
            if out_id not in self.output_to_cache:
                raise ValueError("Tensor not found in cache.")
            cache_id = self.output_to_cache.pop(out_id)
            del self.cache_to_output[cache_id]

    def __str__(self):
        with self.lock:
            num_cached = len(self.cache)
            num_used = len(self.output_to_cache)
            total_cache_size = sum([v.numel() * v.element_size() for v in self.cache.values()])
            # no request yet: report a hit rate of 0 instead of dividing by zero
            hit_rate = self.hit_cnt / self.total_cnt if self.total_cnt else 0.0
        return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={hit_rate:.2f})"
|
||||||
|
|
@ -0,0 +1,257 @@
|
||||||
|
import gc
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
|
import av
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision import get_video_backend
|
||||||
|
from torchvision.io.video import _check_av_available
|
||||||
|
|
||||||
|
# Frame count assumed by read_video_av when the container reports 0 frames for its video stream.
MAX_NUM_FRAMES = 2500
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_av(
|
||||||
|
filename: str,
|
||||||
|
start_pts: float | Fraction = 0,
|
||||||
|
end_pts: float | Fraction | None = None,
|
||||||
|
pts_unit: str = "pts",
|
||||||
|
output_format: str = "THWC",
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, dict]:
|
||||||
|
"""
|
||||||
|
Reads a video from a file, returning both the video frames and the audio frames
|
||||||
|
|
||||||
|
This method is modified from torchvision.io.video.read_video, with the following changes:
|
||||||
|
|
||||||
|
1. will not extract audio frames and return empty for aframes
|
||||||
|
2. remove checks and only support pyav
|
||||||
|
3. add container.close() and gc.collect() to avoid thread leakage
|
||||||
|
4. try our best to avoid memory leak
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): path to the video file
|
||||||
|
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||||
|
The start presentation time of the video
|
||||||
|
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||||
|
The end presentation time
|
||||||
|
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
|
||||||
|
either 'pts' or 'sec'. Defaults to 'pts'.
|
||||||
|
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
|
||||||
|
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
|
||||||
|
info (dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
|
||||||
|
"""
|
||||||
|
# format
|
||||||
|
output_format = output_format.upper()
|
||||||
|
if output_format not in ("THWC", "TCHW"):
|
||||||
|
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
|
||||||
|
# file existence
|
||||||
|
if not os.path.exists(filename):
|
||||||
|
raise RuntimeError(f"File not found: {filename}")
|
||||||
|
# backend check
|
||||||
|
assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
|
||||||
|
_check_av_available()
|
||||||
|
# end_pts check
|
||||||
|
if end_pts is None:
|
||||||
|
end_pts = float("inf")
|
||||||
|
if end_pts < start_pts:
|
||||||
|
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
|
||||||
|
|
||||||
|
# == get video info ==
|
||||||
|
info = {}
|
||||||
|
# TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU)
|
||||||
|
container = av.open(filename, metadata_errors="ignore")
|
||||||
|
# fps
|
||||||
|
video_fps = container.streams.video[0].average_rate
|
||||||
|
# guard against potentially corrupted files
|
||||||
|
if video_fps is not None:
|
||||||
|
info["video_fps"] = float(video_fps)
|
||||||
|
iter_video = container.decode(**{"video": 0})
|
||||||
|
frame = next(iter_video).to_rgb().to_ndarray()
|
||||||
|
height, width = frame.shape[:2]
|
||||||
|
total_frames = container.streams.video[0].frames
|
||||||
|
if total_frames == 0:
|
||||||
|
total_frames = MAX_NUM_FRAMES
|
||||||
|
warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
|
||||||
|
container.close()
|
||||||
|
del container
|
||||||
|
|
||||||
|
# HACK: must create before iterating stream
|
||||||
|
# use np.zeros will not actually allocate memory
|
||||||
|
# use np.ones will lead to a little memory leak
|
||||||
|
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
# == read ==
|
||||||
|
try:
|
||||||
|
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
|
||||||
|
container = av.open(filename, metadata_errors="ignore")
|
||||||
|
assert container.streams.video is not None
|
||||||
|
video_frames = _read_from_stream(
|
||||||
|
video_frames,
|
||||||
|
container,
|
||||||
|
start_pts,
|
||||||
|
end_pts,
|
||||||
|
pts_unit,
|
||||||
|
container.streams.video[0],
|
||||||
|
{"video": 0},
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
except av.AVError as e:
|
||||||
|
print(f"[Warning] Error while reading video {filename}: {e}")
|
||||||
|
|
||||||
|
vframes = torch.from_numpy(video_frames).clone()
|
||||||
|
del video_frames
|
||||||
|
if output_format == "TCHW":
|
||||||
|
# [T,H,W,C] --> [T,C,H,W]
|
||||||
|
vframes = vframes.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
aframes = torch.empty((1, 0), dtype=torch.float32)
|
||||||
|
return vframes, aframes, info
|
||||||
|
|
||||||
|
|
||||||
|
def _read_from_stream(
|
||||||
|
video_frames,
|
||||||
|
container: "av.container.Container",
|
||||||
|
start_offset: float,
|
||||||
|
end_offset: float,
|
||||||
|
pts_unit: str,
|
||||||
|
stream: "av.stream.Stream",
|
||||||
|
stream_name: dict[str, int | tuple[int, ...] | list[int] | None],
|
||||||
|
filename: str | None = None,
|
||||||
|
) -> list["av.frame.Frame"]:
|
||||||
|
if pts_unit == "sec":
|
||||||
|
# TODO: we should change all of this from ground up to simply take
|
||||||
|
# sec and convert to MS in C++
|
||||||
|
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
|
||||||
|
if end_offset != float("inf"):
|
||||||
|
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
|
||||||
|
else:
|
||||||
|
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
|
||||||
|
|
||||||
|
should_buffer = True
|
||||||
|
max_buffer_size = 5
|
||||||
|
if stream.type == "video":
|
||||||
|
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
|
||||||
|
# so need to buffer some extra frames to sort everything
|
||||||
|
# properly
|
||||||
|
extradata = stream.codec_context.extradata
|
||||||
|
# overly complicated way of finding if `divx_packed` is set, following
|
||||||
|
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
|
||||||
|
if extradata and b"DivX" in extradata:
|
||||||
|
# can't use regex directly because of some weird characters sometimes...
|
||||||
|
pos = extradata.find(b"DivX")
|
||||||
|
d = extradata[pos:]
|
||||||
|
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
|
||||||
|
if o is None:
|
||||||
|
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
|
||||||
|
if o is not None:
|
||||||
|
should_buffer = o.group(3) == b"p"
|
||||||
|
seek_offset = start_offset
|
||||||
|
# some files don't seek to the right location, so better be safe here
|
||||||
|
seek_offset = max(seek_offset - 1, 0)
|
||||||
|
if should_buffer:
|
||||||
|
# FIXME this is kind of a hack, but we will jump to the previous keyframe
|
||||||
|
# so this will be safe
|
||||||
|
seek_offset = max(seek_offset - max_buffer_size, 0)
|
||||||
|
try:
|
||||||
|
# TODO check if stream needs to always be the video stream here or not
|
||||||
|
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
|
||||||
|
except av.AVError as e:
|
||||||
|
print(f"[Warning] Error while seeking video {filename}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# == main ==
|
||||||
|
buffer_count = 0
|
||||||
|
frames_pts = []
|
||||||
|
cnt = 0
|
||||||
|
try:
|
||||||
|
for _idx, frame in enumerate(container.decode(**stream_name)):
|
||||||
|
frames_pts.append(frame.pts)
|
||||||
|
video_frames[cnt] = frame.to_rgb().to_ndarray()
|
||||||
|
cnt += 1
|
||||||
|
if cnt >= len(video_frames):
|
||||||
|
break
|
||||||
|
if frame.pts >= end_offset:
|
||||||
|
if should_buffer and buffer_count < max_buffer_size:
|
||||||
|
buffer_count += 1
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
except av.AVError as e:
|
||||||
|
print(f"[Warning] Error while reading video {filename}: {e}")
|
||||||
|
|
||||||
|
# garbage collection for thread leakage
|
||||||
|
container.close()
|
||||||
|
del container
|
||||||
|
# NOTE: manually garbage collect to close pyav threads
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# ensure that the results are sorted wrt the pts
|
||||||
|
# NOTE: here we assert frames_pts is sorted
|
||||||
|
start_ptr = 0
|
||||||
|
end_ptr = cnt
|
||||||
|
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
|
||||||
|
start_ptr += 1
|
||||||
|
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
|
||||||
|
end_ptr -= 1
|
||||||
|
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
|
||||||
|
# if there is no frame that exactly matches the pts of start_offset
|
||||||
|
# add the last frame smaller than start_offset, to guarantee that
|
||||||
|
# we will have all the necessary data. This is most useful for audio
|
||||||
|
if start_ptr > 0:
|
||||||
|
start_ptr -= 1
|
||||||
|
result = video_frames[start_ptr:end_ptr].copy()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_cv2(video_path):
    """Read every frame of a video with OpenCV.

    Args:
        video_path: path handed to ``cv2.VideoCapture``.

    Returns:
        frames: tensor of shape [T, C=3, H, W] in RGB channel order.
        vinfo: dict with ``"video_fps"`` as reported by the capture.

    Raises:
        ValueError: if the video cannot be opened or yields no frame.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Unable to open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        vinfo = {
            "video_fps": fps,
        }

        frames = []
        while True:
            # Read a frame from the video
            ret, frame = cap.read()

            # If frame is not read correctly, break the loop
            if not ret:
                break

            frames.append(frame[:, :, ::-1])  # BGR to RGB
            # No cv2.waitKey / destroyAllWindows here: this loader never opens
            # a window, so polling the keyboard only added a per-frame delay
            # and tied the reader to OpenCV's GUI module.
    finally:
        # Release the capture on every path, including errors
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be decoded from video: {video_path}")

    frames = np.stack(frames)
    frames = torch.from_numpy(frames)  # [T, H, W, C=3]
    frames = frames.permute(0, 3, 1, 2)
    return frames, vinfo
|
||||||
|
|
||||||
|
|
||||||
|
def read_video(video_path, backend="av"):
    """Read a video with the selected backend.

    Args:
        video_path: path of the video file.
        backend: ``"av"`` (default) uses ``read_video_av`` with seconds as the
            pts unit and "TCHW" output; ``"cv2"`` uses ``read_video_cv2``.

    Returns:
        vframes: the frames returned by the backend.
        vinfo: the metadata dict returned by the backend.

    Raises:
        ValueError: if ``backend`` is neither ``"av"`` nor ``"cv2"``.
    """
    if backend == "cv2":
        vframes, vinfo = read_video_cv2(video_path)
    elif backend == "av":
        vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
    else:
        # name the offending value so the failure is diagnosable
        raise ValueError(f"Unsupported video backend {backend!r}; expected 'av' or 'cv2'.")

    return vframes, vinfo
|
||||||
|
|
@ -0,0 +1,393 @@
|
||||||
|
from collections import OrderedDict, defaultdict
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.utils.data import Dataset, DistributedSampler
|
||||||
|
|
||||||
|
from opensora.utils.logger import log_message
|
||||||
|
from opensora.utils.misc import format_numel_str
|
||||||
|
|
||||||
|
from .aspect import get_num_pexels_from_name
|
||||||
|
from .bucket import Bucket
|
||||||
|
from .datasets import VideoTextDataset
|
||||||
|
from .parallel import pandarallel
|
||||||
|
from .utils import sync_object_across_devices
|
||||||
|
|
||||||
|
|
||||||
|
# use pandarallel to accelerate bucket processing
|
||||||
|
# NOTE: pandarallel should only access local variables
|
||||||
|
def apply(data, method=None, seed=None, num_bucket=None, fps_max=16):
    """Call ``method`` with the fields of one metadata row.

    ``method`` receives, in order: num_frames, height, width, fps, path, a
    per-row seed computed as ``seed + data["id"] * num_bucket``, and
    ``fps_max``. Its return value is passed through unchanged.
    """
    row_fields = ("num_frames", "height", "width", "fps", "path")
    return method(
        *[data[field] for field in row_fields],
        seed + data["id"] * num_bucket,
        fps_max,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class StatefulDistributedSampler(DistributedSampler):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
num_replicas: int | None = None,
|
||||||
|
rank: int | None = None,
|
||||||
|
shuffle: bool = True,
|
||||||
|
seed: int = 0,
|
||||||
|
drop_last: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
||||||
|
self.start_index: int = 0
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator:
|
||||||
|
iterator = super().__iter__()
|
||||||
|
indices = list(iterator)
|
||||||
|
indices = indices[self.start_index :]
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.num_samples - self.start_index
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.start_index = 0
|
||||||
|
|
||||||
|
def state_dict(self, step) -> dict:
|
||||||
|
return {"start_index": step}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: dict) -> None:
|
||||||
|
self.__dict__.update(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableVideoBatchSampler(DistributedSampler):
    """Distributed batch sampler that draws each micro-batch from a single bucket.

    Samples are assigned to buckets by ``Bucket.get_bucket_id``; every iteration
    step hands one bucket's micro-batch to each rank. Yielded items are lists of
    strings of the form ``"{index}-{t}-{h}-{w}"``.
    """

    def __init__(
        self,
        dataset: VideoTextDataset,
        bucket_config: dict,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
        num_bucket_build_workers: int = 1,
        num_groups: int = 1,
    ) -> None:
        super().__init__(
            dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
        )
        self.dataset = dataset
        assert dataset.bucket_class == "Bucket", "Only support Bucket class for now"
        self.bucket = Bucket(bucket_config)
        self.verbose = verbose
        # number of micro-batches (summed over all ranks) already handed out this epoch
        self.last_micro_batch_access_index = 0
        self.num_bucket_build_workers = num_bucket_build_workers
        self._cached_bucket_sample_dict = None
        self._cached_num_total_batch = None
        # only used as a divisor in __len__
        self.num_groups = num_groups

        # buckets are built on rank 0 only (see group_by_bucket), so only rank 0 needs the worker pool
        if dist.get_rank() == 0:
            pandarallel.initialize(
                nb_workers=self.num_bucket_build_workers,
                progress_bar=False,
                verbose=0,
                use_memory_fs=False,
            )

    def __iter__(self) -> Iterator[list[int]]:
        # NOTE(review): the yielded lists contain formatted strings, not ints (see the end of the loop below).
        bucket_sample_dict, _ = self.group_by_bucket()
        # drop the cache so the next call to group_by_bucket rebuilds the buckets
        self.clear_cache()

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        bucket_micro_batch_count = OrderedDict()
        bucket_last_consumed = OrderedDict()

        # process the samples
        for bucket_id, data_list in bucket_sample_dict.items():
            # handle droplast
            bs_per_gpu = self.bucket.get_batch_size(bucket_id)
            remainder = len(data_list) % bs_per_gpu

            if remainder > 0:
                if not self.drop_last:
                    # if there is remainder, we pad to make it divisible
                    # NOTE(review): if the bucket holds fewer than half a batch, this slice cannot supply
                    # enough padding and the bucket ends up with zero micro-batches below.
                    data_list += data_list[: bs_per_gpu - remainder]
                else:
                    # we just drop the remainder to make it divisible
                    data_list = data_list[:-remainder]
            bucket_sample_dict[bucket_id] = data_list

            # handle shuffle
            if self.shuffle:
                data_indices = torch.randperm(len(data_list), generator=g).tolist()
                data_list = [data_list[i] for i in data_indices]
                bucket_sample_dict[bucket_id] = data_list

            # compute how many micro-batches each bucket has
            num_micro_batches = len(data_list) // bs_per_gpu
            bucket_micro_batch_count[bucket_id] = num_micro_batches

        # compute the bucket access order
        # each bucket may have more than one batch of data
        # thus bucket_id may appear more than 1 time
        bucket_id_access_order = []
        for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
            bucket_id_access_order.extend([bucket_id] * num_micro_batch)

        # randomize the access order
        if self.shuffle:
            bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
            bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices]

        # make the number of bucket accesses divisible by dp size
        remainder = len(bucket_id_access_order) % self.num_replicas
        if remainder > 0:
            if self.drop_last:
                bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
            else:
                bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]

        # prepare each batch from its bucket
        # according to the predefined bucket access order
        num_iters = len(bucket_id_access_order) // self.num_replicas
        start_iter_idx = self.last_micro_batch_access_index // self.num_replicas

        # re-compute the micro-batch consumption
        # this is useful when resuming from a state dict with a different number of GPUs
        self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
        # replay the accesses that happened before the resume point to rebuild per-bucket offsets
        for i in range(self.last_micro_batch_access_index):
            bucket_id = bucket_id_access_order[i]
            bucket_bs = self.bucket.get_batch_size(bucket_id)
            if bucket_id in bucket_last_consumed:
                bucket_last_consumed[bucket_id] += bucket_bs
            else:
                bucket_last_consumed[bucket_id] = bucket_bs

        for i in range(start_iter_idx, num_iters):
            # one access per rank for this step
            bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
            self.last_micro_batch_access_index += self.num_replicas

            # compute the data samples consumed by each access
            bucket_access_boundaries = []
            for bucket_id in bucket_access_list:
                bucket_bs = self.bucket.get_batch_size(bucket_id)
                last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
                bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs])

                # update consumption
                if bucket_id in bucket_last_consumed:
                    bucket_last_consumed[bucket_id] += bucket_bs
                else:
                    bucket_last_consumed[bucket_id] = bucket_bs

            # compute the range of data accessed by each GPU
            bucket_id = bucket_access_list[self.rank]
            boundary = bucket_access_boundaries[self.rank]
            cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]]

            # encode t, h, w into the sample index
            real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
            cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
            yield cur_micro_batch

        # epoch exhausted: start from the first micro-batch next time
        self.reset()

    def __len__(self) -> int:
        return self.get_num_batch() // self.num_groups

    def get_num_batch(self) -> int:
        """Return the total batch count reported by ``group_by_bucket``."""
        _, num_total_batch = self.group_by_bucket()
        return num_total_batch

    def clear_cache(self):
        """Invalidate the cached bucket assignment so it is rebuilt on next use."""
        self._cached_bucket_sample_dict = None
        self._cached_num_total_batch = 0

    def group_by_bucket(self) -> tuple[dict, int]:
        """
        Group the dataset samples into buckets.
        This method will set `self._cached_bucket_sample_dict` to the bucket sample dict.

        Returns:
            tuple: (a dictionary with bucket id as key and a list of sample indices as value,
            the total number of batches as computed by `print_bucket_info`)
        """
        if self._cached_bucket_sample_dict is not None:
            return self._cached_bucket_sample_dict, self._cached_num_total_batch

        # use pandarallel to accelerate bucket processing
        log_message("Building buckets using %d workers...", self.num_bucket_build_workers)
        bucket_ids = None
        # only rank 0 computes the assignment; it is broadcast to the other ranks below
        if dist.get_rank() == 0:
            data = self.dataset.data.copy(deep=True)
            data["id"] = data.index
            bucket_ids = data.parallel_apply(
                apply,
                axis=1,
                method=self.bucket.get_bucket_id,
                seed=self.seed + self.epoch,
                num_bucket=self.bucket.num_bucket,
                fps_max=self.dataset.fps_max,
            )
        dist.barrier()
        bucket_ids = sync_object_across_devices(bucket_ids)
        dist.barrier()

        # group by bucket
        # each data sample is put into a bucket with a similar image/video size
        bucket_sample_dict = defaultdict(list)
        bucket_ids_np = np.array(bucket_ids)
        # `!= None` (not `is not None`) is deliberate: it is evaluated element-wise on the array
        valid_indices = np.where(bucket_ids_np != None)[0]
        for i in valid_indices:
            bucket_sample_dict[bucket_ids_np[i]].append(i)

        # cache the bucket sample dict
        self._cached_bucket_sample_dict = bucket_sample_dict

        # num total batch
        num_total_batch = self.print_bucket_info(bucket_sample_dict)
        self._cached_num_total_batch = num_total_batch

        return bucket_sample_dict, num_total_batch

    def print_bucket_info(self, bucket_sample_dict: dict) -> int:
        """Collect per-bucket statistics, log them on rank 0 when verbose, and return the total batch count.

        Bucket keys are indexed as ``k[0]`` (compared against "256px"/"768px"),
        ``k[1]`` (``1`` is treated as an image bucket) and ``k[-1]`` (reported as aspect ratio).
        """
        # collect statistics
        num_total_samples = num_total_batch = 0
        num_total_img_samples = num_total_vid_samples = 0
        num_total_img_batch = num_total_vid_batch = 0
        num_total_vid_batch_256 = num_total_vid_batch_768 = 0
        # both dicts map key -> [sample count, batch count]
        num_aspect_dict = defaultdict(lambda: [0, 0])
        num_hwt_dict = defaultdict(lambda: [0, 0])
        for k, v in bucket_sample_dict.items():
            size = len(v)
            # NOTE(review): __iter__ passes the full bucket id to get_batch_size, here the last
            # element is stripped — confirm both forms are accepted by Bucket.
            num_batch = size // self.bucket.get_batch_size(k[:-1])

            num_total_samples += size
            num_total_batch += num_batch

            if k[1] == 1:
                num_total_img_samples += size
                num_total_img_batch += num_batch
            else:
                if k[0] == "256px":
                    num_total_vid_batch_256 += num_batch
                elif k[0] == "768px":
                    num_total_vid_batch_768 += num_batch
                num_total_vid_samples += size
                num_total_vid_batch += num_batch

            num_aspect_dict[k[-1]][0] += size
            num_aspect_dict[k[-1]][1] += num_batch
            num_hwt_dict[k[:-1]][0] += size
            num_hwt_dict[k[:-1]][1] += num_batch

        # sort
        num_aspect_dict = dict(sorted(num_aspect_dict.items(), key=lambda x: x[0]))
        num_hwt_dict = dict(
            sorted(num_hwt_dict.items(), key=lambda x: (get_num_pexels_from_name(x[0][0]), x[0][1]), reverse=True)
        )
        num_hwt_img_dict = {k: v for k, v in num_hwt_dict.items() if k[1] == 1}
        num_hwt_vid_dict = {k: v for k, v in num_hwt_dict.items() if k[1] > 1}

        # log
        if dist.get_rank() == 0 and self.verbose:
            log_message("Bucket Info:")
            log_message("Bucket [#sample, #batch] by aspect ratio:")
            for k, v in num_aspect_dict.items():
                log_message("(%s): #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
            log_message("===== Image Info =====")
            log_message("Image Bucket by HxWxT:")
            for k, v in num_hwt_img_dict.items():
                log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
            log_message("--------------------------------")
            log_message(
                "#image sample: %s, #image batch: %s",
                format_numel_str(num_total_img_samples),
                format_numel_str(num_total_img_batch),
            )
            log_message("===== Video Info =====")
            log_message("Video Bucket by HxWxT:")
            for k, v in num_hwt_vid_dict.items():
                log_message("%s: #sample: %s, #batch: %s", k, format_numel_str(v[0]), format_numel_str(v[1]))
            log_message("--------------------------------")
            log_message(
                "#video sample: %s, #video batch: %s",
                format_numel_str(num_total_vid_samples),
                format_numel_str(num_total_vid_batch),
            )
            log_message("===== Summary =====")
            log_message("#non-empty buckets: %s", len(bucket_sample_dict))
            log_message(
                "Img/Vid sample ratio: %.2f",
                num_total_img_samples / num_total_vid_samples if num_total_vid_samples > 0 else 0,
            )
            log_message(
                "Img/Vid batch ratio: %.2f", num_total_img_batch / num_total_vid_batch if num_total_vid_batch > 0 else 0
            )
            log_message(
                "vid batch 256: %s, vid batch 768: %s", format_numel_str(num_total_vid_batch_256), format_numel_str(num_total_vid_batch_768)
            )
            log_message(
                "Vid batch ratio (256px/768px): %.2f", num_total_vid_batch_256 / num_total_vid_batch_768 if num_total_vid_batch_768 > 0 else 0
            )
            log_message(
                "#training sample: %s, #training batch: %s",
                format_numel_str(num_total_samples),
                format_numel_str(num_total_batch),
            )
        return num_total_batch

    def reset(self):
        """Restart iteration from the first micro-batch."""
        self.last_micro_batch_access_index = 0

    def set_step(self, start_step: int):
        """Resume iteration after ``start_step`` steps (each step consumes ``num_replicas`` micro-batches)."""
        self.last_micro_batch_access_index = start_step * self.num_replicas

    def state_dict(self, num_steps: int) -> dict:
        # the last_micro_batch_access_index in the __iter__ is often
        # not accurate during multi-workers and data prefetching
        # thus, we need the user to pass the actual steps which have been executed
        # to calculate the correct last_micro_batch_access_index
        return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}

    def load_state_dict(self, state_dict: dict) -> None:
        # every key of the state dict becomes (or overwrites) an instance attribute
        self.__dict__.update(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDistributedSampler(DistributedSampler):
    """
    Used with BatchDataset;
    Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then

           | buffer {i}          | buffer {i+1}
    ------ | ------------------- | -------------------
    rank 0 | 0, 1, 2, 3, 4,      | 5, 6, 7, 8, 9
    rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
    rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
    """

    def __init__(self, dataset: Dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        self.start_index = 0

    def __iter__(self):
        # Each rank owns a contiguous run of whole buffers.
        buffers_per_rank = self.dataset.num_buffers // self.num_replicas
        samples_per_rank = self.dataset.len_buffer * buffers_per_rank

        # Skip the first ``start_index`` samples of this rank's run, then shift into its slice.
        local_positions = np.arange(self.start_index, samples_per_rank)
        rank_offset = self.rank * samples_per_rank
        return iter((local_positions + rank_offset).tolist())

    def reset(self):
        """Restart from the first sample of this rank's slice."""
        self.start_index = 0

    def state_dict(self, step) -> dict:
        """Return the resume state, storing the caller-supplied ``step`` as ``start_index``."""
        return {"start_index": step}

    def load_state_dict(self, state_dict: dict):
        """Resume one position after the ``start_index`` recorded in ``state_dict``."""
        saved_position = state_dict["start_index"]
        self.start_index = saved_position + 1
|
||||||
|
|
@ -0,0 +1,419 @@
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
||||||
|
from torchvision.io import write_video
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
|
from . import video_transforms
|
||||||
|
from .read_video import read_video
|
||||||
|
|
||||||
|
# dask is an optional dependency; SUPPORT_DASK records whether it could be imported.
try:
    import dask.dataframe as dd

    SUPPORT_DASK = True
except Exception:
    # Any import-time failure disables the dask path. Unlike a bare ``except:``,
    # this still lets KeyboardInterrupt / SystemExit propagate.
    SUPPORT_DASK = False
|
||||||
|
|
||||||
|
# File suffixes (lower-case, with leading dot) that are treated as videos.
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# Case-insensitive pattern for an absolute URL; it must match the whole string.
regex = re.compile(
    r"^(?:http|ftp)s?://"  # scheme: http://, https://, ftp:// or ftps://
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # ...or ip
    r"(?::\d+)?"  # optional port
    r"(?:/?|[/?]\S+)$",  # optional path / query made of non-whitespace characters
    re.IGNORECASE,
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_img(path):
    """Return True when the lower-cased suffix of ``path`` is listed in ``IMG_EXTENSIONS``."""
    _, suffix = os.path.splitext(path)
    return suffix.lower() in IMG_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_vid(path):
    """Return True when the lower-cased suffix of ``path`` is listed in ``VID_EXTENSIONS``."""
    _, suffix = os.path.splitext(path)
    return suffix.lower() in VID_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_url(url):
    """Return True when ``url`` matches the module-level URL pattern ``regex``."""
    matched = regex.match(url)
    return matched is not None
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(input_path, memory_efficient=False):
    """Load a metadata table from a ``.csv`` or ``.parquet`` file.

    Args:
        input_path: Path whose suffix selects the reader.
        memory_efficient: For parquet only, restrict loading to a fixed list of
            columns. Asserted to be false for CSV input.

    Returns:
        The loaded dataframe. Parquet files go through dask (then ``.compute()``)
        when ``SUPPORT_DASK`` is true, otherwise through pandas.

    Raises:
        NotImplementedError: If the suffix is neither ``.csv`` nor ``.parquet``.
    """
    if input_path.endswith(".csv"):
        assert not memory_efficient, "Memory efficient mode is not supported for CSV files"
        return pd.read_csv(input_path)

    if not input_path.endswith(".parquet"):
        raise NotImplementedError(f"Unsupported file format: {input_path}")

    wanted_columns = (
        ["path", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution"] if memory_efficient else None
    )
    if SUPPORT_DASK:
        return dd.read_parquet(input_path, columns=wanted_columns).compute()
    return pd.read_parquet(input_path, columns=wanted_columns)
|
||||||
|
|
||||||
|
|
||||||
|
def download_url(input_path):
    """Download ``input_path`` into the local ``cache`` directory.

    The file is stored under the URL's basename, overwriting any existing file
    of that name.

    Args:
        input_path: URL to fetch.

    Returns:
        Path of the downloaded file inside ``cache``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, base_name)

    response = requests.get(input_path)
    # Fail loudly on HTTP errors instead of caching an error page as media.
    response.raise_for_status()

    # Binary mode takes no ``encoding``; passing one makes ``open`` raise ValueError.
    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def temporal_random_crop(
    vframes: torch.Tensor, num_frames: int, frame_interval: int, return_frame_indices: bool = False
) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
    """Pick ``num_frames`` evenly spaced frames from a randomly placed temporal window.

    Args:
        vframes: Frames indexed along their first axis.
        num_frames: Number of frames to return.
        frame_interval: The window requested from ``TemporalRandomCrop`` is
            ``num_frames * frame_interval`` frames long.
        return_frame_indices: Also return the selected frame indices.

    Returns:
        The selected frames, or ``(frames, indices)`` when
        ``return_frame_indices`` is true.
    """
    window_length = num_frames * frame_interval
    window_sampler = video_transforms.TemporalRandomCrop(window_length)
    first, last = window_sampler(len(vframes))

    assert (
        last - first >= num_frames
    ), f"Not enough frames to sample, {last} - {first} < {num_frames}"

    # ``last`` is exclusive, so the final pickable frame is ``last - 1``.
    picked = np.linspace(first, last - 1, num_frames, dtype=int)
    clip = vframes[picked]
    return (clip, picked) if return_frame_indices else clip
|
||||||
|
|
||||||
|
|
||||||
|
def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the video preprocessing pipeline selected by ``name``.

    Every pipeline is ``ToTensorVideo`` -> one spatial op -> ``Normalize`` with
    mean/std 0.5 (in place). The spatial op is ``UCFCenterCropVideo`` for
    ``"center"``, ``ResizeCrop`` for ``"resize_crop"`` and ``RandomSizedCrop``
    for ``"rand_size_crop"``.

    Args:
        name: Pipeline name, or ``None`` for no transform.
        image_size: Size handed to the spatial op; asserted square for ``"center"``.

    Returns:
        A ``transforms.Compose`` instance, or ``None`` when ``name`` is ``None``.

    Raises:
        NotImplementedError: For an unknown ``name``.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        spatial_op = video_transforms.UCFCenterCropVideo(image_size[0])
    elif name == "resize_crop":
        spatial_op = video_transforms.ResizeCrop(image_size)
    elif name == "rand_size_crop":
        spatial_op = video_transforms.RandomSizedCrop(image_size)
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    pipeline = [
        video_transforms.ToTensorVideo(),  # TCHW
        spatial_op,
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(pipeline)
|
||||||
|
|
||||||
|
|
||||||
|
def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the image preprocessing pipeline selected by ``name``.

    Every pipeline is one PIL crop step -> ``ToTensor`` -> ``Normalize`` with
    mean/std 0.5 (in place). The crop step calls ``center_crop_arr`` for
    ``"center"``, ``resize_crop_to_fill`` for ``"resize_crop"`` and
    ``rand_size_crop_arr`` for ``"rand_size_crop"``.

    Args:
        name: Pipeline name, or ``None`` for no transform.
        image_size: Size handed to the crop helper; asserted square for ``"center"``.

    Returns:
        A ``transforms.Compose`` instance, or ``None`` when ``name`` is ``None``.

    Raises:
        NotImplementedError: For an unknown ``name``.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"

        def _crop_step(pil_image):
            return center_crop_arr(pil_image, image_size[0])

    elif name == "resize_crop":

        def _crop_step(pil_image):
            return resize_crop_to_fill(pil_image, image_size)

    elif name == "rand_size_crop":

        def _crop_step(pil_image):
            return rand_size_crop_arr(pil_image, image_size)

    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    pipeline = [
        transforms.Lambda(_crop_step),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(pipeline)
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    """Load an image and repeat it ``num_frames`` times to form a static clip.

    Args:
        path: Image file passed to ``pil_loader``.
        transform: Callable applied to the PIL image; built with
            ``get_transforms_image`` when ``None``.
        transform_name: Pipeline name used when ``transform`` is ``None``.
        num_frames: Number of copies of the transformed image.
        image_size: Size used when ``transform`` is ``None``.

    Returns:
        The transformed image stacked ``num_frames`` times along a new leading
        axis, with the first two axes then swapped.
    """
    pil_img = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    frame = transform(pil_img)

    stacked = frame.unsqueeze(0).repeat(num_frames, 1, 1, 1)
    return stacked.permute(1, 0, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    """Decode a video with torchvision, transform it and swap its first two axes.

    Args:
        path: Video file to decode (requested in ``TCHW`` layout).
        transform: Callable applied to the decoded frames; built with
            ``get_transforms_video`` when ``None``.
        transform_name: Pipeline name used when ``transform`` is ``None``.
        image_size: Size used when ``transform`` is ``None``.

    Returns:
        The transformed frames with axes 0 and 1 exchanged.
    """
    # The audio frames and stream info are decoded but not used.
    frames, _audio, _info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    transformed = transform(frames)  # T C H W
    return transformed.permute(1, 0, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def read_from_path(path, image_size, transform_name="center"):
    """Load an image or video from a local path or URL, dispatching on the file suffix.

    URLs are first downloaded with ``download_url``. Suffixes listed in
    ``VID_EXTENSIONS`` go to ``read_video_from_path``; any other suffix must be
    in ``IMG_EXTENSIONS`` (asserted) and goes to ``read_image_from_path``.
    """
    if is_url(path):
        path = download_url(path)

    suffix = os.path.splitext(path)[-1].lower()
    if suffix in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size, transform_name=transform_name)

    assert suffix in IMG_EXTENSIONS, f"Unsupported file format: {suffix}"
    return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
|
||||||
|
|
||||||
|
|
||||||
|
def save_sample(
    x,
    save_path=None,
    fps=8,
    normalize=True,
    value_range=(-1, 1),
    force_video=False,
    verbose=True,
    crf=23,
):
    """Write a sample to disk as a PNG image or an H.264 MP4 video.

    Args:
        x (Tensor): shape [C, T, H, W]. It is not modified.
        save_path (str): Output path WITHOUT extension; ``.png`` or ``.mp4`` is appended.
        fps: Frame rate of the written video.
        normalize: Clamp to ``value_range`` and rescale to [0, 1] before writing.
        value_range: ``(low, high)`` range of the incoming values.
        force_video: Write a video even when ``T == 1``.
        verbose: Print the output path after saving.
        crf: Constant rate factor handed to the video encoder.

    Returns:
        str: The path written, including the appended extension.

    Raises:
        TypeError: If ``save_path`` is ``None``.
    """
    assert x.ndim == 4
    if save_path is None:
        # Same exception type as the former ``None += str`` failure, with a usable message.
        raise TypeError("save_path must be a path string (without extension), got None")

    if not force_video and x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        # The first op on each path is out-of-place so the caller's tensor is left
        # untouched; the in-place ops that follow only touch the fresh copy.
        if normalize:
            low, high = value_range
            frames = x.clamp(min=low, max=high)
            frames.sub_(low).div_(max(high - low, 1e-5))
            frames.mul_(255)
        else:
            frames = x.mul(255)

        # Round to nearest, then move to [T, H, W, C] uint8 on the CPU for the writer.
        frames = frames.add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, frames, fps=fps, video_codec="h264", options={"crf": str(crf)})
    if verbose:
        print(f"Saved to {save_path}")
    return save_path
|
||||||
|
|
||||||
|
|
||||||
|
def center_crop_arr(pil_image, image_size):
    """
    Resize so the shorter side equals ``image_size``, then take the central square.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Halve with a box filter while the shorter side is still at least twice the target.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # One bicubic resize brings the shorter side to the target length.
    ratio = image_size / min(*pil_image.size)
    scaled = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top : top + image_size, left : left + image_size])
|
||||||
|
|
||||||
|
|
||||||
|
def rand_size_crop_arr(pil_image, image_size):
    """
    Randomly crop an image to a random height and width.

    Height and width are drawn independently from
    ``[image_size[0], image_size[1]]`` and rounded down to a multiple of 8.
    The crop origin is uniform over the positions that keep the crop inside
    the image; if the image is smaller than the target along an axis, the
    crop starts at 0 on that axis.

    Args:
        pil_image: Source PIL image.
        image_size: ``(min_side, max_side)`` range for the crop size.

    Returns:
        The cropped PIL image.
    """
    arr = np.array(pil_image)

    # get random target h w
    height = random.randint(image_size[0], image_size[1])
    width = random.randint(image_size[0], image_size[1])
    # ensure that h w are multiples of 8
    height = height - height % 8
    width = width - width % 8

    # get random start pos
    h_start = random.randint(0, max(len(arr) - height, 0))
    # Bound the horizontal offset by the crop WIDTH. Using the height here let the
    # crop run past the right edge (or skewed the offset range) whenever height != width.
    w_start = random.randint(0, max(len(arr[0]) - width, 0))

    # crop
    return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width])
|
||||||
|
|
||||||
|
|
||||||
|
def resize_crop_to_fill(pil_image, image_size):
    """Scale the image until it covers ``image_size`` (height, width), then center-crop to it.

    The larger of the two axis scale factors is used, so one axis matches the
    target exactly and the other overshoots; the overshoot is trimmed equally
    from both sides.
    """
    src_w, src_h = pil_image.size  # PIL is (W, H)
    target_h, target_w = image_size
    scale_h, scale_w = target_h / src_h, target_w / src_w

    if scale_h > scale_w:
        # Height is the limiting axis: match it, crop the surplus width.
        new_h, new_w = target_h, round(src_w * scale_h)
        top, left = 0, int(round((new_w - target_w) / 2.0))
    else:
        # Width is the limiting axis: match it, crop the surplus height.
        new_h, new_w = round(src_h * scale_w), target_w
        top, left = int(round((new_h - target_h) / 2.0)), 0

    resized = pil_image.resize((new_w, new_h), Image.BICUBIC)
    pixels = np.array(resized)
    assert top + target_h <= pixels.shape[0] and left + target_w <= pixels.shape[1]
    return Image.fromarray(pixels[top : top + target_h, left : left + target_w])
|
||||||
|
|
||||||
|
|
||||||
|
def map_target_fps(
    fps: float,
    max_fps: float,
) -> tuple[float, int]:
    """
    Reduce ``fps`` by an integer frame stride so that it does not exceed ``max_fps``.

    Args:
        fps (float): Original fps; NaN is mapped to ``(0, 1)``.
        max_fps (float): Maximum fps.

    Returns:
        tuple[float, int]: ``(new_fps, sampling_interval)``. When ``fps`` is
        already below ``max_fps`` it is returned unchanged with interval 1;
        otherwise the interval is ``ceil(fps / max_fps)`` and the new fps is
        ``floor(fps / interval)``.
    """
    if math.isnan(fps):
        return 0, 1
    if fps < max_fps:
        return fps, 1

    stride = math.ceil(fps / max_fps)
    return math.floor(fps / stride), stride
|
||||||
|
|
||||||
|
|
||||||
|
def sync_object_across_devices(obj: Any, rank: int = 0):
    """
    Broadcast a picklable object from one rank to every rank.

    Uses ``dist.broadcast_object_list`` with ``device="cuda"``.

    Parameters:
        obj (Any): Object to synchronize; only the value held by the source rank is kept.
        rank (int): Source rank of the broadcast. Default is 0.

    Returns:
        The object as held by the source rank.

    Note: torch.distributed must be initialized and CUDA available before calling this.
    """
    # broadcast_object_list fills the list in place on receiving ranks.
    payload = [obj]
    dist.broadcast_object_list(payload, src=rank, device="cuda")
    return payload[0]
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_image_by_path(path: str, height: int, width: int):
    """
    Rescales an image to the specified height and width and saves it back to the original path.

    Best effort: any failure is printed and swallowed, leaving the file as it was.

    Args:
        path (str): The file path of the image.
        height (int): The target height of the image.
        width (int): The target width of the image.
    """
    try:
        # read image
        image = Image.open(path)

        # check if image is valid
        if image is None:
            raise ValueError("The image is invalid or empty.")

        # resize image; torchvision's Resize expects its size as (height, width)
        resize_transform = transforms.Resize((height, width))
        resized_image = resize_transform(image)

        # save resized image back to the original path
        resized_image.save(path)

    except Exception as e:
        print(f"Error rescaling image: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_video_by_path(path: str, height: int, width: int):
    """
    Rescales an MP4 video (without audio) to the specified height and width.

    The result overwrites the original file, re-encoded as H.264 at the source
    frame rate truncated to an integer. Best effort: any failure is printed and
    swallowed.

    Args:
        path (str): The file path of the video.
        height (int): The target height of the video.
        width (int): The target width of the video.
    """
    try:
        # Read video and metadata
        video, info = read_video(path, backend="av")

        # Check if video is valid
        if video is None or video.size(0) == 0:
            raise ValueError("The video is invalid or empty.")

        # Resize frame by frame; torchvision's Resize expects its size as (height, width)
        resize_transform = transforms.Compose([transforms.Resize((height, width))])
        resized_video = torch.stack([resize_transform(frame) for frame in video])

        # Save resized video back to the original path
        # (moves axis 1 to the end before writing — presumably T,C,H,W -> T,H,W,C; verify against read_video)
        resized_video = resized_video.permute(0, 2, 3, 1)
        write_video(path, resized_video, fps=int(info["video_fps"]), video_codec="h264")
    except Exception as e:
        print(f"Error rescaling video: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def save_tensor_to_disk(tensor, path, exist_handling="overwrite"):
    """Save ``tensor`` to ``path`` with ``torch.save``.

    Args:
        tensor: Object to serialize.
        path: Destination file.
        exist_handling: What to do when ``path`` already exists: ``"ignore"``
            returns without writing, ``"raise"`` raises ``UserWarning``, and any
            other value overwrites the file.
    """
    already_there = os.path.exists(path)
    if already_there and exist_handling == "ignore":
        return
    if already_there and exist_handling == "raise":
        raise UserWarning(f"File {path} already exists, rewriting!")
    torch.save(tensor, path)
|
||||||
|
|
||||||
|
|
||||||
|
def save_tensor_to_internet(tensor, path):
    """Placeholder for remote saving; always raises ``NotImplementedError``."""
    message = "save_tensor_to_internet is not implemented yet!"
    raise NotImplementedError(message)
|
||||||
|
|
||||||
|
|
||||||
|
def save_latent(latent, path, exist_handling="overwrite"):
    """Save one latent, choosing the backend from the prefix of ``path``.

    Paths starting with ``http://``, ``https://``, ``ftp://`` or ``sftp://`` go
    to ``save_tensor_to_internet``; everything else is written locally with
    ``save_tensor_to_disk`` using ``exist_handling``.
    """
    remote_prefixes = ("http://", "https://", "ftp://", "sftp://")
    if not path.startswith(remote_prefixes):
        save_tensor_to_disk(latent, path, exist_handling=exist_handling)
        return
    save_tensor_to_internet(latent, path)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_latents(latents, path, exist_handling="overwrite"):
    """Save each entry along the first axis of ``latents`` to the matching entry of ``path``.

    Args:
        latents: Batch of latents; ``latents.shape[0]`` gives the item count.
        path: Indexable collection of destinations, one per latent.
        exist_handling: Forwarded to ``save_latent`` for every item.
    """
    batch_size = latents.shape[0]
    for item_idx in range(batch_size):
        save_latent(latents[item_idx], path[item_idx], exist_handling=exist_handling)
|
||||||
|
|
@ -0,0 +1,595 @@
|
||||||
|
# Copyright 2024 Vchitect/Latte
|
||||||
|
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
# Modified from Latte
|
||||||
|
|
||||||
|
import numbers
|
||||||
|
|
||||||
|
# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tensor_video_clip(clip):
|
||||||
|
if not torch.is_tensor(clip):
|
||||||
|
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
||||||
|
|
||||||
|
if not clip.ndimension() == 4:
|
||||||
|
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def crop(clip, i, j, h, w):
    """Slice a spatial window out of the last two dimensions of ``clip``.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W).
        i (int): First row of the window.
        j (int): First column of the window.
        h (int): Number of rows to keep.
        w (int): Number of columns to keep.

    Returns:
        The slice ``clip[..., i:i + h, j:j + w]``.

    Raises:
        ValueError: If ``clip`` is not 4-dimensional.
    """
    rank = len(clip.size())
    if rank != 4:
        raise ValueError("clip should be a 4D tensor")
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return clip[..., rows, cols]
|
||||||
|
|
||||||
|
|
||||||
|
def resize(clip, target_size, interpolation_mode):
    """Resize ``clip`` to ``target_size`` with ``torch.nn.functional.interpolate``.

    Args:
        clip: Tensor passed to ``interpolate``.
        target_size: Two-element ``(height, width)`` passed as ``size``.
        interpolation_mode: Passed as ``mode``. ``align_corners=False`` is
            always supplied alongside it.

    Raises:
        ValueError: If ``target_size`` does not have exactly two elements.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    interpolate = torch.nn.functional.interpolate
    resized = interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
    return resized
|
||||||
|
|
||||||
|
|
||||||
|
def resize_scale(clip, target_size, interpolation_mode):
    """Resize ``clip`` so its shorter spatial edge equals ``target_size[0]``.

    Both spatial dimensions are multiplied by the same factor
    ``target_size[0] / min(H, W)`` and rounded, so the aspect ratio is kept
    up to rounding. Only the first element of ``target_size`` is used.

    Args:
        clip: Tensor whose last two dimensions are height and width.
        target_size: Two-element sequence; element 0 is the short-edge target.
        interpolation_mode: Passed as ``mode`` to ``interpolate``
            (``align_corners=False`` is always supplied).

    Raises:
        ValueError: If ``target_size`` does not have exactly two elements.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    height = clip.size(-2)
    width = clip.size(-1)
    factor = target_size[0] / min(height, width)
    new_size = (int(round(height * factor)), int(round(width * factor)))
    return torch.nn.functional.interpolate(clip, size=new_size, mode=interpolation_mode, align_corners=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """Crop a spatial window from ``clip`` and resize the result.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W).
        i (int): Row of the window's upper-left corner.
        j (int): Column of the window's upper-left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): Height and width of the resized clip.
        interpolation_mode (str): Mode forwarded to ``resize``.

    Returns:
        clip (torch.tensor): Cropped then resized clip.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    window = crop(clip, i, j, h, w)
    return resize(window, size, interpolation_mode)
|
||||||
|
|
||||||
|
|
||||||
|
def center_crop(clip, crop_size):
    """Crop a centred ``crop_size`` window out of ``clip``.

    Args:
        clip (torch.tensor): 4D clip whose last two dimensions are H and W.
        crop_size: Pair ``(height, width)`` of the window.

    Returns:
        The centred window as produced by ``crop``.

    Raises:
        ValueError: If the clip is smaller than ``crop_size`` in either
            spatial dimension.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    src_h = clip.size(-2)
    src_w = clip.size(-1)
    out_h, out_w = crop_size
    if src_h < out_h or src_w < out_w:
        raise ValueError("height and width must be no smaller than crop_size")

    # Split the surplus evenly between the two sides of each axis.
    top = int(round((src_h - out_h) / 2.0))
    left = int(round((src_w - out_w) / 2.0))
    return crop(clip, top, left, out_h, out_w)
|
||||||
|
|
||||||
|
|
||||||
|
def center_crop_using_short_edge(clip):
    """Crop the centred square whose side equals the clip's shorter edge.

    Args:
        clip (torch.tensor): 4D clip whose last two dimensions are H and W.

    Returns:
        A square clip of side ``min(H, W)`` taken from the centre of the
        longer axis.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    src_h, src_w = clip.size(-2), clip.size(-1)
    side = src_h if src_h < src_w else src_w
    # Along the short axis the surplus is zero, so its offset evaluates to 0.
    top = int(round((src_h - side) / 2.0))
    left = int(round((src_w - side) / 2.0))
    return crop(clip, top, left, side, side)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_crop_to_fill(clip, target_size):
    """Resize ``clip`` to cover ``target_size``, then crop the overflow.

    The clip is scaled bilinearly by the larger of the two axis ratios, so
    one axis matches the target exactly and the other is at least as large.
    The surplus on that axis is then cropped around the centre.

    Args:
        clip (torch.tensor): 4D clip whose last two dimensions are H and W.
        target_size: Indexable ``(height, width)`` of the output.

    Returns:
        Clip whose last two dimensions are ``target_size[0]`` and
        ``target_size[1]``.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    src_h, src_w = clip.size(-2), clip.size(-1)
    out_h, out_w = target_size[0], target_size[1]
    ratio_h = out_h / src_h
    ratio_w = out_w / src_w

    # NOTE(review): the offsets round the surplus first and then halve it
    # (truncating), unlike center_crop which rounds the halved value.
    if ratio_h > ratio_w:
        scaled = (out_h, round(src_w * ratio_h))
        top = 0
        left = int(round(scaled[1] - out_w) / 2.0)
    else:
        scaled = (round(src_h * ratio_w), out_w)
        top = int(round(scaled[0] - out_h) / 2.0)
        left = 0

    clip = resize(clip, scaled, "bilinear")
    assert top + out_h <= clip.size(-2) and left + out_w <= clip.size(-1)
    return crop(clip, top, left, out_h, out_w)
|
||||||
|
|
||||||
|
|
||||||
|
# def rand_crop_h_w(clip, target_size_range, multiples_of: int = 8):
|
||||||
|
# # NOTE: for some reason, if don't re-import, gives same randint results
|
||||||
|
# import sys
|
||||||
|
|
||||||
|
# del sys.modules["random"]
|
||||||
|
# import random
|
||||||
|
|
||||||
|
# if not _is_tensor_video_clip(clip):
|
||||||
|
# raise ValueError("clip should be a 4D torch.tensor")
|
||||||
|
# h, w = clip.size(-2), clip.size(-1)
|
||||||
|
|
||||||
|
# # get random target h w
|
||||||
|
# th = random.randint(target_size_range[0], target_size_range[1])
|
||||||
|
# tw = random.randint(target_size_range[0], target_size_range[1])
|
||||||
|
|
||||||
|
# # ensure that h w are factors of 8
|
||||||
|
# th = th - th % multiples_of
|
||||||
|
# tw = tw - tw % multiples_of
|
||||||
|
|
||||||
|
# # get random start pos
|
||||||
|
# i = random.randint(0, h-th) if h > th else 0
|
||||||
|
# j = random.randint(0, w-tw) if w > tw else 0
|
||||||
|
|
||||||
|
# th = th if th < h else h
|
||||||
|
# tw = tw if tw < w else w
|
||||||
|
|
||||||
|
# # print("target size range:",target_size_range)
|
||||||
|
# # print("original size:", h, w)
|
||||||
|
# # print("crop size:", th, tw)
|
||||||
|
# # print(f"crop:{i}-{i+th}, {j}-{j+tw}")
|
||||||
|
|
||||||
|
# return (crop(clip, i, j, th, tw), th, tw)
|
||||||
|
|
||||||
|
|
||||||
|
def random_shift_crop(clip):
    """Crop a square of side ``min(H, W)`` at a random position.

    The square spans the whole short edge, so only the offset along the
    long edge actually varies. Offsets are drawn with ``torch.randint``
    (row offset first, then column offset).

    Args:
        clip (torch.tensor): 4D clip whose last two dimensions are H and W.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    src_h, src_w = clip.size(-2), clip.size(-1)
    side = min(src_h, src_w)

    # randint's upper bound is exclusive, hence the +1.
    top = torch.randint(0, src_h - side + 1, size=(1,)).item()
    left = torch.randint(0, src_w - side + 1, size=(1,)).item()
    return crop(clip, top, left, side, side)
|
||||||
|
|
||||||
|
|
||||||
|
def to_tensor(clip):
    """Convert a ``uint8`` clip to float and scale it by ``1 / 255``.

    The dimensions are left in their original order.

    Args:
        clip (torch.tensor, dtype=torch.uint8): 4D clip.

    Return:
        clip (torch.tensor, dtype=torch.float): Same shape as the input.

    Raises:
        TypeError: If ``clip`` is not a ``uint8`` tensor.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    as_float = clip.float()
    return as_float / 255.0
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(clip, mean, std, inplace=False):
    """Subtract ``mean`` and divide by ``std`` along the first dimension.

    ``mean`` and ``std`` are converted to 1-D tensors and reshaped to
    ``(N, 1, 1, 1)``, so they broadcast against dimension 0 of ``clip``.

    NOTE(review): this means per-channel statistics are only applied
    correctly when the channel axis comes first, i.e. (C, T, H, W); other
    helpers in this file document (T, C, H, W) -- confirm the layout with
    callers.

    Args:
        clip (torch.tensor): 4D video clip to be normalized.
        mean (tuple): Per-entry mean for dimension 0 of ``clip``.
        std (tuple): Per-entry standard deviation for dimension 0 of ``clip``.
        inplace (bool): When False (default) a clone is normalized and the
            input is left untouched.

    Returns:
        normalized clip (torch.tensor): Same shape as the input.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)
    target.sub_(mean_t[:, None, None, None])
    target.div_(std_t[:, None, None, None])
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def hflip(clip):
    """Mirror ``clip`` along its last (width) dimension.

    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W).

    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W).
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    mirrored = clip.flip(-1)
    return mirrored
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeCrop:
    """Transform wrapper around ``resize_crop_to_fill``.

    Args:
        size: Target size. A number is expanded to ``(int(size), int(size))``;
            any other value is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        """Return ``clip`` resized and cropped to ``self.size``."""
        return resize_crop_to_fill(clip, self.size)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(size={self.size})"
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSizedCrop:
    """Crop a window of random size and random position from a clip.

    Args:
        size: Inclusive ``(low, high)`` range from which the crop height and
            width are each drawn. A number is expanded to
            ``(int(size), int(size))``; any other value is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        """Return a randomly sized, randomly placed crop of ``clip``."""
        top, left, height, width = self.get_params(clip)
        return crop(clip, top, left, height, width)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(size={self.size})"

    def get_params(self, clip, multiples_of=8):
        """Draw crop parameters ``(i, j, h, w)`` for ``clip``.

        Height and width are drawn independently from ``self.size`` and
        rounded down to a multiple of ``multiples_of``. A drawn dimension
        larger than the clip is replaced by the clip's own dimension rounded
        down to a multiple of ``multiples_of``. If the result equals the
        clip's spatial size, the full frame is returned; otherwise the
        offsets are drawn uniformly.
        """
        src_h, src_w = clip.shape[-2:]
        low, high = self.size[0], self.size[1]

        # Draw the target size (height first, then width).
        out_h = random.randint(low, high)
        out_w = random.randint(low, high)
        out_h -= out_h % multiples_of
        out_w -= out_w % multiples_of

        # Clamp to the clip while keeping the multiple-of constraint.
        if src_h < out_h:
            out_h = src_h - src_h % multiples_of
        if src_w < out_w:
            out_w = src_w - src_w % multiples_of

        if src_w == out_w and src_h == out_h:
            return 0, 0, src_h, src_w

        # Draw the start position (row first, then column).
        top = random.randint(0, src_h - out_h)
        left = random.randint(0, src_w - out_w)
        return top, left, out_h, out_w
|
||||||
|
|
||||||
|
|
||||||
|
class RandomCropVideo:
    """Crop a fixed-size window from a random position of a clip.

    Args:
        size: Crop size ``(height, width)``. A number is expanded to
            ``(int(size), int(size))``; any other value is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        top, left, height, width = self.get_params(clip)
        return crop(clip, top, left, height, width)

    def get_params(self, clip):
        """Draw crop parameters ``(i, j, h, w)`` for ``clip``.

        Raises:
            ValueError: If the crop size exceeds the clip in either spatial
                dimension.
        """
        src_h, src_w = clip.shape[-2:]
        out_h, out_w = self.size

        if src_h < out_h or src_w < out_w:
            raise ValueError(f"Required crop size {(out_h, out_w)} is larger than input image size {(src_h, src_w)}")

        if src_w == out_w and src_h == out_h:
            return 0, 0, src_h, src_w

        # randint's upper bound is exclusive, hence the +1.
        top = torch.randint(0, src_h - out_h + 1, size=(1,)).item()
        left = torch.randint(0, src_w - out_w + 1, size=(1,)).item()
        return top, left, out_h, out_w

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(size={self.size})"
|
||||||
|
|
||||||
|
|
||||||
|
class CenterCropResizeVideo:
    """Center-crop a clip to a square on its short edge, then resize it.

    Args:
        size: Output size. A 2-tuple is used as ``(height, width)``; any
            non-tuple value is duplicated into ``(size, size)``.
        interpolation_mode (str): Mode forwarded to ``resize``.

    Raises:
        ValueError: If ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped then resized video clip whose last
                two dimensions are ``self.size``.
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        return resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
||||||
|
|
||||||
|
|
||||||
|
class UCFCenterCropVideo:
    """Scale a clip so its short edge matches the target, then center-crop.

    Args:
        size: Output size. A 2-tuple is used as ``(height, width)``; any
            non-tuple value is duplicated into ``(size, size)``.
        interpolation_mode (str): Mode forwarded to ``resize_scale``.

    Raises:
        ValueError: If ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return center_crop(clip_resize, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
||||||
|
|
||||||
|
|
||||||
|
class KineticsRandomCropResizeVideo:
    """Randomly crop a short-edge square from a clip, then resize it.

    The square slides along the long edge (see ``random_shift_crop``) and is
    then resized to the desired size.

    Args:
        size: Output size. A 2-tuple is used as ``(height, width)``; any
            non-tuple value is duplicated into ``(size, size)``.
        interpolation_mode (str): Mode forwarded to ``resize``.

    Raises:
        ValueError: If ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """Return a random short-edge square of ``clip`` resized to ``self.size``."""
        clip_random_crop = random_shift_crop(clip)
        return resize(clip_random_crop, self.size, self.interpolation_mode)

    def __repr__(self) -> str:
        # Same format as the other crop/resize transforms in this module.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
||||||
|
|
||||||
|
|
||||||
|
class CenterCropVideo:
    """Center-crop a clip to a fixed size.

    Args:
        size: Crop size. A 2-tuple is used as ``(height, width)``; any
            non-tuple value is duplicated into ``(size, size)``.
        interpolation_mode (str): Stored and shown in ``repr``; the crop
            itself does not interpolate.

    Raises:
        ValueError: If ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        return center_crop(clip, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeVideo:
    """Transform wrapper around ``normalize``.

    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        normalized = normalize(clip, self.mean, self.std, self.inplace)
        return normalized

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
||||||
|
|
||||||
|
|
||||||
|
class ToTensorVideo:
    """Transform wrapper around ``to_tensor``.

    Converts a ``uint8`` clip to float and divides its values by 255.0.
    """

    def __init__(self):
        # Stateless transform: nothing to configure.
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        converted = to_tensor(clip)
        return converted

    def __repr__(self) -> str:
        return type(self).__name__
|
||||||
|
|
||||||
|
|
||||||
|
class RandomHorizontalFlipVideo:
    """Horizontally flip a clip with a given probability.

    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W); flipped via ``hflip``
                when the random draw falls below ``p``, otherwise the input
                itself.
        """
        should_flip = random.random() < self.p
        return hflip(clip) if should_flip else clip

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p})"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# --------------------- Sampling ---------------------------
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
class TemporalRandomCrop:
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Pick a random window of at most ``self.size`` frames.

        Args:
            total_frames (int): Number of frames available.

        Returns:
            tuple(int, int): ``(begin_index, end_index)`` with ``end_index``
            exclusive. When ``total_frames <= self.size`` the window is
            ``(0, total_frames)``.
        """
        # random.randint includes both endpoints, so the last valid start is
        # total_frames - size; this lets the window reach the final frame.
        last_start = max(0, total_frames - self.size)
        begin_index = random.randint(0, last_start)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: read a sample clip, sample frames from it, run them
    # through a transform pipeline, then write the result as a video and as
    # individual PNG frames.
    import os

    import numpy as np
    import torchvision.io as io
    from torchvision import transforms
    from torchvision.utils import save_image

    vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")

    pipeline = [
        ToTensorVideo(),
        RandomHorizontalFlipVideo(),
        UCFCenterCropVideo(512),
        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    trans = transforms.Compose(pipeline)

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    assert end_frame_ind - start_frame_ind >= target_video_len
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
    print(frame_indice)

    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)

    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)

    # Undo the [-1, 1] normalization and go back to uint8 for writing.
    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    frames_thwc = select_vframes_trans_int.permute(0, 2, 3, 1)
    print(frames_thwc.shape)

    io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)

    for i in range(target_video_len):
        frame_path = os.path.join("./test000", "%04d.png" % i)
        save_image(select_vframes_trans[i], frame_path, normalize=True, value_range=(-1, 1))
|
||||||
|
|
@ -0,0 +1,282 @@
|
||||||
|
# Dataset Management
|
||||||
|
|
||||||
|
- [Dataset Management](#dataset-management)
|
||||||
|
- [Dataset Format](#dataset-format)
|
||||||
|
- [Dataset to CSV](#dataset-to-csv)
|
||||||
|
- [Manage datasets](#manage-datasets)
|
||||||
|
- [Requirement](#requirement)
|
||||||
|
- [Basic Usage](#basic-usage)
|
||||||
|
- [Score filtering](#score-filtering)
|
||||||
|
- [Documentation](#documentation)
|
||||||
|
- [Transform datasets](#transform-datasets)
|
||||||
|
- [Resize](#resize)
|
||||||
|
- [Frame extraction](#frame-extraction)
|
||||||
|
- [Crop Midjourney 4 grid](#crop-midjourney-4-grid)
|
||||||
|
- [Analyze datasets](#analyze-datasets)
|
||||||
|
- [Data Process Pipeline](#data-process-pipeline)
|
||||||
|
|
||||||
|
After preparing the raw dataset according to the [instructions](/docs/datasets.md), you can use the following commands to manage the dataset.
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
All datasets should be provided as a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should use the names below:
|
||||||
|
|
||||||
|
- `path`: the relative/absolute path or url to the image or video file. Required.
|
||||||
|
- `text`: the caption or description of the image or video. Required for training.
|
||||||
|
- `num_frames`: the number of frames in the video. Required for training.
|
||||||
|
- `width`: the width of the video frame. Required for dynamic bucket.
|
||||||
|
- `height`: the height of the video frame. Required for dynamic bucket.
|
||||||
|
- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket.
|
||||||
|
- `resolution`: height x width. For analysis.
|
||||||
|
- `text_len`: the number of tokens in the text. For analysis.
|
||||||
|
- `aes`: aesthetic score calculated by [aesthetic scorer](/tools/aesthetic/README.md). For filtering.
|
||||||
|
- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering.
|
||||||
|
- `match`: matching score of an image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering.
|
||||||
|
- `fps`: the frame rate of the video. Optional.
|
||||||
|
- `cmotion`: the camera motion.
|
||||||
|
|
||||||
|
An example ready for training:
|
||||||
|
|
||||||
|
```csv
|
||||||
|
path, text, num_frames, width, height, aspect_ratio
|
||||||
|
/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625
|
||||||
|
/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625
|
||||||
|
/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1
|
||||||
|
```
|
||||||
|
|
||||||
|
We use pandas to manage the `.csv` or `.parquet` files. The following code is for reading and writing files:
|
||||||
|
|
||||||
|
```python
|
||||||
|
df = pd.read_csv(input_path)
|
||||||
|
df.to_csv(output_path, index=False)
|
||||||
|
# or use parquet, which is smaller
|
||||||
|
df = pd.read_parquet(input_path)
|
||||||
|
df.to_parquet(output_path, index=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset to CSV
|
||||||
|
|
||||||
|
As a starting point, `convert.py` is used to convert the dataset to a CSV file. You can use the following commands to convert the dataset to a CSV file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER
|
||||||
|
|
||||||
|
# general video folder
|
||||||
|
python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv
|
||||||
|
# general image folder
|
||||||
|
python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv
|
||||||
|
# imagenet
|
||||||
|
python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train
|
||||||
|
# ucf101
|
||||||
|
python -m tools.datasets.convert ucf101 UCF101_FOLDER --split videos
|
||||||
|
# vidprom
|
||||||
|
python -m tools.datasets.convert vidprom VIDPROM_FOLDER --info VidProM_semantic_unique.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manage datasets
|
||||||
|
|
||||||
|
Use `datautil` to manage the dataset.
|
||||||
|
|
||||||
|
### Requirement
|
||||||
|
|
||||||
|
Follow the "Data Dependencies" and "Datasets" sections of our [installation guide](../../docs/installation.md) to install the required packages.
|
||||||
|
<!-- To accelerate processing speed, you can install [pandarallel](https://github.com/nalepae/pandarallel):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install pandarallel
|
||||||
|
``` -->
|
||||||
|
|
||||||
|
<!-- To get image and video information, you need to install [opencv-python](https://github.com/opencv/opencv-python): -->
|
||||||
|
|
||||||
|
<!-- ```bash
|
||||||
|
pip install opencv-python
|
||||||
|
# If your videos are in av1 codec instead of h264, you need to
|
||||||
|
# - install ffmpeg first
|
||||||
|
# - install via conda to support av1 codec
|
||||||
|
conda install -c conda-forge opencv
|
||||||
|
``` -->
|
||||||
|
|
||||||
|
<!-- Or to get video information, you can install ffmpeg and ffmpeg-python:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install ffmpeg-python
|
||||||
|
``` -->
|
||||||
|
|
||||||
|
<!-- To filter a specific language, you need to install [lingua](https://github.com/pemistahl/lingua-py):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install lingua-language-detector
|
||||||
|
``` -->
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
You can use the following commands to process the `csv` or `parquet` files. The output file will be saved in the same directory as the input, with different suffixes indicating the processed method.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# datautil takes multiple CSV files as input and merges them into one CSV file
|
||||||
|
# output: DATA1+DATA2.csv
|
||||||
|
python -m tools.datasets.datautil DATA1.csv DATA2.csv
|
||||||
|
|
||||||
|
# shard CSV files into multiple CSV files
|
||||||
|
# output: DATA1_0.csv, DATA1_1.csv, ...
|
||||||
|
python -m tools.datasets.datautil DATA1.csv --shard 10
|
||||||
|
|
||||||
|
# filter frames between 128 and 256, with captions
|
||||||
|
# output: DATA_fmin_128_fmax_256.csv
|
||||||
|
python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256
|
||||||
|
|
||||||
|
# Disable parallel processing
|
||||||
|
python -m tools.datasets.datautil DATA.csv --fmin 128 --fmax 256 --disable-parallel
|
||||||
|
|
||||||
|
# Compute num_frames, height, width, fps, aspect_ratio for videos or images
|
||||||
|
# output: IMG_DATA+VID_DATA_vinfo.csv
|
||||||
|
python -m tools.datasets.datautil IMG_DATA.csv VID_DATA.csv --video-info
|
||||||
|
|
||||||
|
# You can run multiple operations at the same time.
|
||||||
|
python -m tools.datasets.datautil DATA.csv --video-info --remove-empty-caption --remove-url --lang en
|
||||||
|
```
|
||||||
|
|
||||||
|
### Score filtering
|
||||||
|
|
||||||
|
To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# sort the dataset by aesthetic score
|
||||||
|
# output: DATA_sort.csv
|
||||||
|
python -m tools.datasets.datautil DATA.csv --sort aesthetic_score
|
||||||
|
# View examples of high aesthetic score
|
||||||
|
head -n 10 DATA_sort.csv
|
||||||
|
# View examples of low aesthetic score
|
||||||
|
tail -n 10 DATA_sort.csv
|
||||||
|
|
||||||
|
# sort the dataset by clip score
|
||||||
|
# output: DATA_sort.csv
|
||||||
|
python -m tools.datasets.datautil DATA.csv --sort clip_score
|
||||||
|
|
||||||
|
# filter the dataset by aesthetic score
|
||||||
|
# output: DATA_aesmin_0.5.csv
|
||||||
|
python -m tools.datasets.datautil DATA.csv --aesmin 0.5
|
||||||
|
# filter the dataset by clip score
|
||||||
|
# output: DATA_matchmin_0.5.csv
|
||||||
|
python -m tools.datasets.datautil DATA.csv --matchmin 0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
You can also use `python -m tools.datasets.datautil --help` to see usage.
|
||||||
|
|
||||||
|
| Args | File suffix | Description |
|
||||||
|
| --------------------------- | -------------- | ------------------------------------------------------------- |
|
||||||
|
| `--output OUTPUT` | | Output path |
|
||||||
|
| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) |
|
||||||
|
| `--disable-parallel` | | Disable `pandarallel` |
|
||||||
|
| `--seed SEED` | | Random seed |
|
||||||
|
| `--shard SHARD` | `_0`,`_1`, ... | Shard the dataset |
|
||||||
|
| `--sort KEY` | `_sort` | Sort the dataset by KEY |
|
||||||
|
| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order |
|
||||||
|
| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset |
|
||||||
|
| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns |
|
||||||
|
| `--info` | `_info` | Get the basic information of each video and image (cv2) |
|
||||||
|
| `--ext` | `_ext` | Remove rows if the file does not exist |
|
||||||
|
| `--relpath` | `_relpath` | Modify the path to relative path by root given |
|
||||||
|
| `--abspath` | `_abspath` | Modify the path to absolute path by root given |
|
||||||
|
| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption |
|
||||||
|
| `--remove-url` | `_nourl` | Remove rows with url in caption |
|
||||||
|
| `--lang LANG` | `_lang` | Remove rows with other language |
|
||||||
|
| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path |
|
||||||
|
| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption |
|
||||||
|
| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM |
|
||||||
|
| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training |
|
||||||
|
| `--unescape` | `_unescape` | Unescape the caption |
|
||||||
|
| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption |
|
||||||
|
| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption |
|
||||||
|
| `--load-caption EXT` | `_load` | Load the caption from the file |
|
||||||
|
| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames |
|
||||||
|
| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum number of frames |
|
||||||
|
| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width |
|
||||||
|
| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score |
|
||||||
|
| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score |
|
||||||
|
| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score |
|
||||||
|
|
||||||
|
## Transform datasets
|
||||||
|
|
||||||
|
The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m tools.datasets.transform TRANSFORM_TYPE META.csv ORIGINAL_DATA_FOLDER DATA_FOLDER_TO_SAVE_RESULTS --additional-args
|
||||||
|
```
|
||||||
|
|
||||||
|
### Resize
|
||||||
|
|
||||||
|
Sometimes you may need to resize the images or videos to a specific resolution. You can use the following commands to resize the dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m tools.datasets.transform meta.csv /path/to/raw/data /path/to/new/data --length 2160
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frame extraction
|
||||||
|
|
||||||
|
To extract frames from videos, you can use the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m tools.datasets.transform vid_frame_extract meta.csv /path/to/raw/data /path/to/new/data --points 0.1 0.5 0.9
|
||||||
|
```
|
||||||
|
|
||||||
|
### Crop Midjourney 4 grid
|
||||||
|
|
||||||
|
Randomly select one of the 4 images in the 4 grid generated by Midjourney.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m tools.datasets.transform img_rand_crop meta.csv /path/to/raw/data /path/to/new/data
|
||||||
|
```
|
||||||
|
|
||||||
|
## Analyze datasets
|
||||||
|
|
||||||
|
You can easily get basic information about a `.csv` dataset by using the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# examine the first 10 rows of the CSV file
|
||||||
|
head -n 10 DATA1.csv
|
||||||
|
# count the number of rows in the CSV file (approximately; `wc -l` counts lines, including the header)
|
||||||
|
wc -l DATA1.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
For the dataset provided in a `.csv` or `.parquet` file, you can easily analyze the dataset using the following commands. Plots will be automatically saved.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m tools.datasets.analyze DATA_info.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
## Data Process Pipeline
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Suppose the videos and images are stored under ~/dataset/
|
||||||
|
# 1. Convert dataset to CSV
|
||||||
|
python -m tools.datasets.convert video ~/dataset --output meta.csv
|
||||||
|
|
||||||
|
# 2. Get video information
|
||||||
|
python -m tools.datasets.datautil meta.csv --info --fmin 1
|
||||||
|
|
||||||
|
# 3. Get caption
|
||||||
|
# 3.1. generate caption
|
||||||
|
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava meta_info_fmin1.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video
|
||||||
|
# merge generated results
|
||||||
|
python -m tools.datasets.datautil meta_info_fmin1_caption_part*.csv --output meta_caption.csv
|
||||||
|
# merge caption and info
|
||||||
|
python -m tools.datasets.datautil meta_info_fmin1.csv --intersection meta_caption.csv --output meta_caption_info.csv
|
||||||
|
# clean caption
|
||||||
|
python -m tools.datasets.datautil meta_caption_info.csv --clean-caption --refine-llm-caption --remove-empty-caption --output meta_caption_processed.csv
|
||||||
|
# 3.2. extract caption
|
||||||
|
python -m tools.datasets.datautil meta_info_fmin1.csv --load-caption json --remove-empty-caption --clean-caption
|
||||||
|
|
||||||
|
# 4. Scoring
|
||||||
|
# aesthetic scoring
|
||||||
|
torchrun --standalone --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta_caption_processed.csv
|
||||||
|
python -m tools.datasets.datautil meta_caption_processed_part*.csv --output meta_caption_processed_aes.csv
|
||||||
|
# optical flow scoring
|
||||||
|
torchrun --standalone --nproc_per_node 8 -m tools.scoring.optical_flow.inference meta_caption_processed.csv
|
||||||
|
# matching scoring
|
||||||
|
torchrun --standalone --nproc_per_node 8 -m tools.scoring.matching.inference meta_caption_processed.csv
|
||||||
|
# camera motion
|
||||||
|
python -m tools.caption.camera_motion_detect meta_caption_processed.csv
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(input_path):
    """Load a tabular dataset from ``input_path`` into a pandas DataFrame.

    The loader is picked from the path suffix: ``.csv`` is read with
    ``pd.read_csv`` and ``.parquet`` with ``pd.read_parquet``.

    Raises:
        NotImplementedError: if the path ends with neither suffix.
    """
    readers = ((".csv", pd.read_csv), (".parquet", pd.read_parquet))
    for suffix, reader in readers:
        if input_path.endswith(suffix):
            return reader(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command line: a positional dataset path plus ``--save-img``.

    Returns:
        The ``argparse.Namespace`` with ``input`` and ``save_img`` attributes.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("input", type=str, help="Path to the input dataset")
    cli.add_argument(
        "--save-img",
        type=str,
        default="samples/infos/",
        help="Path to save the image",
    )
    namespace = cli.parse_args()
    return namespace
|
||||||
|
|
||||||
|
|
||||||
|
def plot_data(data, column, bins, name):
    """Draw a histogram of ``data[column]`` and save the figure to ``name``.

    Args:
        data: DataFrame holding the column to plot.
        column: name of the column passed to ``DataFrame.hist``.
        bins: number of histogram bins.
        name: output image path; its parent directory is created if needed.
    """
    plt.clf()
    data.hist(column=column, bins=bins)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(name)
    print(f"Saved {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_categorical_data(data, column, name):
    """Draw a bar chart of the value counts of ``data[column]`` and save it.

    Args:
        data: DataFrame holding the categorical column.
        column: name of the column whose ``value_counts()`` are plotted.
        name: output image path; its parent directory is created if needed.
    """
    plt.clf()
    data[column].value_counts().plot(kind="bar")
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(name)
    print(f"Saved {name}")
|
||||||
|
|
||||||
|
|
||||||
|
# Columns that `main` plots when they are present in the dataset. The value is
# the number of histogram bins; None marks a column drawn as a bar chart of
# value counts instead.
COLUMNS = dict.fromkeys(
    ("num_frames", "resolution", "text_len", "aes", "match", "flow"),
    100,
)
COLUMNS["cmotion"] = None
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Print summary statistics for the dataset at ``args.input`` and save plots.

    Rows whose ``num_frames`` equals 1 are reported as images, all other rows
    as videos. When ``args.save_img`` is non-empty, one plot per column that is
    listed in ``COLUMNS`` and present in the data is written under that
    directory as ``image_<column>.png`` / ``video_<column>.png``.
    """
    data = read_file(args.input)
    is_image = data["num_frames"] == 1

    def save_plots(subset, prefix, skipped=()):
        # One figure per known column that exists in ``subset``.
        for column, bins in COLUMNS.items():
            if column in skipped or column not in subset.columns:
                continue
            target = os.path.join(args.save_img, f"{prefix}_{column}.png")
            if bins is None:
                plot_categorical_data(subset, column, target)
            else:
                plot_data(subset, column, bins, target)

    # === Image Data Info ===
    if is_image.sum() > 0:
        print("=== Image Data Info ===")
        img_data = data[is_image]
        print(f"Number of images: {len(img_data)}")
        print(img_data.head())
        print(img_data.describe())
        if args.save_img:
            # Frame count and camera motion are not plotted for images.
            save_plots(img_data, "image", skipped=("num_frames", "cmotion"))

    # === Video Data Info ===
    if not is_image.all():
        print("=== Video Data Info ===")
        video_data = data[~is_image]
        print(f"Number of videos: {len(video_data)}")
        if "num_frames" in video_data.columns:
            total_num_frames = video_data["num_frames"].sum()
            print(f"Number of frames: {total_num_frames}")
            DEFAULT_FPS = 30  # duration estimate assumes every video is 30 FPS
            total_hours = total_num_frames / DEFAULT_FPS / 3600
            print(f"Total hours (30 FPS): {int(total_hours)}")
        print(video_data.head())
        print(video_data.describe())
        if args.save_img:
            save_plots(video_data, "video")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: parse the command line, then run the analysis.
if __name__ == "__main__":
    main(parse_args())
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
import argparse
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
tqdm.pandas()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pandarallel import pandarallel
|
||||||
|
|
||||||
|
PANDA_USE_PARALLEL = True
|
||||||
|
except ImportError:
|
||||||
|
PANDA_USE_PARALLEL = False
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if not shutil.which("ffmpeg"):
|
||||||
|
raise ImportError("FFmpeg is not installed")
|
||||||
|
|
||||||
|
|
||||||
|
def apply(df, func, **kwargs):
    """Apply ``func`` over ``df`` and return the result.

    Uses ``parallel_apply`` when ``PANDA_USE_PARALLEL`` is true and tqdm's
    ``progress_apply`` otherwise; ``kwargs`` are forwarded unchanged.
    """
    runner = df.parallel_apply if PANDA_USE_PARALLEL else df.progress_apply
    return runner(func, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def check_video_integrity(video_path):
    """Return True when two ffmpeg runs on ``video_path`` print nothing to stderr.

    Both commands run with ``-v error`` and write to the ``null`` muxer; the
    video is reported as intact only if neither run produced stderr output.
    Exceptions from ``subprocess.run`` are not caught here.
    """
    commands = (
        # open video and capture 0 seconds
        ["ffmpeg", "-v", "error", "-i", video_path, "-t", "0", "-f", "null", "-"],
        # second pass with enlarged analyze/probe limits and no duration cap
        ["ffmpeg", "-v", "error", "-analyzeduration", "10M", "-probesize", "10M", "-i", video_path, "-f", "null", "-"],
    )
    # Both commands are always executed, even if the first one already failed.
    outcomes = [
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        for cmd in commands
    ]
    return all(outcome.stderr == "" for outcome in outcomes)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: check every video listed in the input CSV and split the
# rows into "<input>_intact" and "<input>_broken" files next to the input.
if __name__ == "__main__":
    import os  # local import: only needed to build the output file names

    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="path to the input dataset")
    parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
    parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
    args = parser.parse_args()

    if args.disable_parallel:
        PANDA_USE_PARALLEL = False
    if PANDA_USE_PARALLEL:
        if args.num_workers is not None:
            pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
        else:
            pandarallel.initialize(progress_bar=True)

    data = pd.read_csv(args.input)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "path" not in data.columns:
        raise ValueError(f"'{args.input}' has no 'path' column")
    data["integrity"] = apply(data["path"], check_video_integrity)

    # Derive the output names from the real extension only. str.replace would
    # rewrite every ".csv" in the path and, for an input without ".csv",
    # return the input path itself and overwrite it.
    stem, ext = os.path.splitext(args.input)
    integrity_file_path = f"{stem}_intact{ext}"
    broken_file_path = f"{stem}_broken{ext}"

    is_intact = data["integrity"].astype(bool)
    intact_data = data[is_intact].drop(columns=["integrity"])
    intact_data.to_csv(integrity_file_path, index=False)
    broken_data = data[~is_intact].drop(columns=["integrity"])
    broken_data.to_csv(broken_file_path, index=False)

    print(
        f"Integrity check completed. Intact videos saved to: {integrity_file_path}, broken videos saved to {broken_file_path}."
    )
|
||||||
|
|
@ -0,0 +1,144 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from torchvision.datasets import ImageNet
|
||||||
|
|
||||||
|
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||||
|
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv", ".m2ts")
|
||||||
|
|
||||||
|
|
||||||
|
def scan_recursively(root):
    """Yield an ``os.DirEntry`` for every file found beneath ``root``.

    Sub-directories are descended into as they are met. A progress line is
    printed each time the number of sub-directories seen directly inside the
    current ``root`` reaches a multiple of 100; the counter is per call, not
    global across the recursion.
    """
    dirs_seen = 0
    for item in os.scandir(root):
        if item.is_file():
            yield item
            continue
        if not item.is_dir():
            continue
        dirs_seen += 1
        if dirs_seen % 100 == 0:
            print(f"Scanned {dirs_seen} directories.")
        yield from scan_recursively(item.path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_filelist(file_path, exts=None):
    """Return the paths of all files under ``file_path``, optionally filtered.

    Args:
        file_path: directory to scan recursively.
        exts: container of lower-case extensions (with the leading dot) to
            keep, or None to keep every file.

    Returns:
        A list of file paths; the scan duration is printed as a side effect.
    """
    started = time.time()
    matched = [
        item.path
        for item in scan_recursively(file_path)
        if item.is_file() and (exts is None or os.path.splitext(item.name)[-1].lower() in exts)
    ]
    elapsed = time.time() - started
    print(f"Scanned {len(matched)} files in {elapsed:.2f} seconds.")
    return matched
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_capital(name):
    """Insert a space before every upper-case letter except the first character.

    Example: ``BoxingPunchingBag`` -> ``Boxing Punching Bag``.
    """
    pieces = []
    for position, char in enumerate(name):
        if position != 0 and char.isupper():
            pieces.append(" ")
        pieces.append(char)
    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def process_imagenet(root, split):
    """Write ``imagenet_<split>.csv`` with a ``path`` and ``text`` column.

    The dataset is opened through torchvision's ``ImageNet``; ``text`` is the
    first element of the class entry for each sample's label.
    """
    dataset = ImageNet(os.path.expanduser(root), split=split)
    rows = []
    for path, label in dataset.samples:
        rows.append((path, dataset.classes[label][0]))
    output = f"imagenet_{split}.csv"

    pd.DataFrame(rows, columns=["path", "text"]).to_csv(output, index=False)
    print(f"Saved {len(rows)} samples to {output}.")
|
||||||
|
|
||||||
|
|
||||||
|
def process_ucf101(root, split):
    """Write ``ucf101_<split>.csv`` with a ``path`` and ``text`` column.

    Every file under ``<root>/<split>`` becomes one row. ``text`` is the name
    of the directory that directly contains the file, with CamelCase split
    into words by ``split_by_capital``.
    """
    root = os.path.expanduser(root)
    video_lists = get_filelist(os.path.join(root, split))
    # Take the parent directory name with os.path rather than splitting on
    # "/", so paths that use another separator are handled as well.
    classes = [split_by_capital(os.path.basename(os.path.dirname(x))) for x in video_lists]
    samples = list(zip(video_lists, classes))
    output = f"ucf101_{split}.csv"

    df = pd.DataFrame(samples, columns=["path", "text"])
    df.to_csv(output, index=False)
    print(f"Saved {len(samples)} samples to {output}.")
|
||||||
|
|
||||||
|
|
||||||
|
def process_vidprom(root, info):
    """Write ``vidprom.csv`` pairing existing ``pika-<uuid>.mp4`` files with prompts.

    Args:
        root: directory that holds the video files.
        info: CSV file with ``uuid`` and ``prompt`` columns.

    Only rows whose expected video path was found under ``root`` are kept.
    """
    root = os.path.expanduser(root)
    available = set(get_filelist(root))
    # read info csv
    infos = pd.read_csv(info)
    expected = infos["uuid"].apply(lambda uuid: os.path.join(root, f"pika-{uuid}.mp4"))
    found = expected.apply(lambda candidate: candidate in available)
    df = pd.DataFrame({"path": expected[found], "text": infos["prompt"][found]})
    df.to_csv("vidprom.csv", index=False)
    print(f"Saved {len(df)} samples to vidprom.csv.")
|
||||||
|
|
||||||
|
|
||||||
|
def process_general_images(root, output):
    """Index every image under ``root`` into a CSV file at ``output``.

    The CSV has the columns ``path`` (as returned by the scan), ``id`` (file
    name without extension) and ``relpath`` (path relative to ``root``).
    Nothing is written when ``root`` does not exist.
    """
    root = os.path.expanduser(root)
    if not os.path.exists(root):
        return
    path_list = get_filelist(root, IMG_EXTENSIONS)
    fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
    relpath_list = [os.path.relpath(x, root) for x in path_list]
    df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))

    # `--output meta.csv` has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output, index=False)
    print(f"Saved {len(df)} samples to {output}.")
|
||||||
|
|
||||||
|
|
||||||
|
def process_general_videos(root, output):
    """Index every video under ``root`` into a CSV file at ``output``.

    The CSV has the columns ``path`` (as returned by the scan), ``id`` (file
    name without extension) and ``relpath`` (path relative to ``root``).
    Nothing is written when ``root`` does not exist.
    """
    root = os.path.expanduser(root)
    if not os.path.exists(root):
        return
    path_list = get_filelist(root, VID_EXTENSIONS)
    # Remove duplicates; sorting makes the row order reproducible, which a
    # plain list(set(...)) does not guarantee between runs.
    path_list = sorted(set(path_list))
    fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
    relpath_list = [os.path.relpath(x, root) for x in path_list]
    df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))

    # `--output meta.csv` has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output, index=False)
    print(f"Saved {len(df)} samples to {output}.")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: pick the converter that matches the `dataset` argument.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
    parser.add_argument("root", type=str)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--info", type=str, default=None)
    parser.add_argument("--output", type=str, default=None, required=True, help="Output path")
    args = parser.parse_args()

    # Only the generic image/video converters use --output; the other three
    # write to fixed file names in the working directory.
    handlers = {
        "imagenet": lambda: process_imagenet(args.root, args.split),
        "ucf101": lambda: process_ucf101(args.root, args.split),
        "vidprom": lambda: process_vidprom(args.root, args.info),
        "image": lambda: process_general_images(args.root, args.output),
        "video": lambda: process_general_videos(args.root, args.output),
    }
    handler = handlers.get(args.dataset)
    if handler is None:
        raise ValueError("Invalid dataset")
    handler()
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Dump the `text` column of a CSV file to a plain-text file, one entry per line.
import argparse

import pandas as pd

parser = argparse.ArgumentParser(description="Convert CSV file to txt file")
parser.add_argument("csv_file", type=str, help="CSV file to convert")
parser.add_argument("txt_file", type=str, help="TXT file to save")
args = parser.parse_args()

data = pd.read_csv(args.csv_file)
# Empty cells are read as float NaN, which str.join rejects with a TypeError;
# write them as empty lines and coerce everything else to str.
text = "\n".join(data["text"].fillna("").astype(str))
# Explicit encoding so non-ASCII captions do not depend on the platform default.
with open(args.txt_file, "w", encoding="utf-8") as f:
    f.write(text)
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,262 @@
|
||||||
|
# TODO: remove this file before releasing
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
tqdm.pandas()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pandarallel import pandarallel
|
||||||
|
|
||||||
|
pandarallel.initialize(progress_bar=True)
|
||||||
|
pandas_has_parallel = True
|
||||||
|
except ImportError:
|
||||||
|
pandas_has_parallel = False
|
||||||
|
|
||||||
|
|
||||||
|
def apply(df, func, **kwargs):
    """Apply ``func`` over ``df`` and return the result.

    Uses ``parallel_apply`` when ``pandas_has_parallel`` is true and tqdm's
    ``progress_apply`` otherwise; ``kwargs`` are forwarded unchanged.
    """
    runner = df.parallel_apply if pandas_has_parallel else df.progress_apply
    return runner(func, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def basic_clean(text):
    """Run ``ftfy.fix_text`` on ``text``, unescape HTML entities twice, and strip it."""
    import ftfy

    repaired = ftfy.fix_text(text)
    # Two passes so that doubly-escaped entities are decoded as well.
    unescaped_once = html.unescape(repaired)
    unescaped_twice = html.unescape(unescaped_once)
    return unescaped_twice.strip()
|
||||||
|
|
||||||
|
|
||||||
|
# One or more consecutive "bad" punctuation characters: the listed symbols,
# all bracket kinds, "|", backslash, "/" and "*".
# Written as a single raw string: the previous non-raw pieces ("\)", "\/", ...)
# relied on invalid escape sequences, which newer Python versions warn about.
# The compiled pattern text is unchanged.
BAD_PUNCT_REGEX = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}")
|
||||||
|
|
||||||
|
|
||||||
|
def clean_caption(caption):
    """Normalize a raw caption string with a fixed sequence of regex rewrites.

    The caption is URL-unquoted, lower-cased, stripped of URLs, HTML markup,
    @nicknames, several CJK ranges, ids, file names and boilerplate phrases,
    and has its dashes, quotes and whitespace normalized.

    Fixes relative to the previous revision:
      * the ``&quot;`` / ``&amp`` patterns are restored; they had been reduced
        to a bare quote (a syntax error) and a bare ``&``;
      * the result of the intermediate ``strip()`` is now kept, so the
        ``^`` / ``$`` anchored rules that follow see the trimmed text.

    Args:
        caption: any object; it is converted with ``str()`` first.

    Returns:
        The cleaned, stripped caption string.
    """
    import urllib.parse as ul

    from bs4 import BeautifulSoup

    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize all quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(BAD_PUNCT_REGEX, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = basic_clean(caption)

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # str.strip() returns a new string; the result used to be discarded, which
    # left the anchored rules below working on untrimmed text.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_10m_set():
    """Return the set of cleaned captions found in the Panda-70M 10M training meta.

    The meta CSV is read from a hard-coded path. Each ``caption`` cell is
    evaluated into a list of strings and every string is passed through
    ``clean_caption``.
    """
    meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv"
    meta_10m = pd.read_csv(meta_path_10m)

    def process_single_caption(row):
        # NOTE(review): eval() executes whatever the CSV cell contains; only
        # safe for trusted meta files. ast.literal_eval would be the safer
        # parser if the cells are plain list literals.
        return str([clean_caption(text) for text in eval(row["caption"])])

    ret = apply(meta_10m, process_single_caption, axis=1)
    print("==> text processed.")

    text_set = set()
    for serialized in ret:
        # Each result was serialized with str() above; parse it back to a list.
        text_set.update(eval(serialized))

    print(f"==> Loaded meta_10m from '{meta_path_10m}'")
    return text_set
|
||||||
|
|
||||||
|
|
||||||
|
def filter_panda10m_text(meta_path, text_set):
    """Keep only the rows of ``meta_path`` whose ``text`` is in ``text_set``.

    The filtered table is written next to the input as
    ``<name>_filter-10m<ext>``.
    """

    def process_single_row(row):
        return row["text"] in text_set

    meta = pd.read_csv(meta_path)
    keep = apply(meta, process_single_row, axis=1)

    meta = meta[keep]
    stem, suffix = os.path.splitext(meta_path)
    out_path = f"{stem}_filter-10m{suffix}"
    meta.to_csv(out_path, index=False)
    print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def filter_panda10m_timestamp(meta_path):
    """Keep the rows of ``meta_path`` whose (video id, timestamp) is in the 10M meta.

    The video id is the file name of ``path`` up to its last underscore. A row
    is kept when that id exists in the 10M meta and the row's ``timestamp``
    equals one of the stringified timestamp tuples recorded for it. The result
    is written next to the input as ``<name>_filter-10m<ext>``.
    """
    meta_path_10m = "/mnt/hdd/data/Panda-70M/raw/meta/train/panda70m_training_10m.csv"
    meta_10m = pd.read_csv(meta_path_10m)

    id2t = {}
    for _, record in tqdm(meta_10m.iterrows(), total=len(meta_10m)):
        # NOTE(review): eval() executes whatever the CSV cell contains; only
        # safe for trusted meta files.
        spans = eval(record["timestamp"])
        id2t[record["videoID"]] = [str(tuple(span)) for span in spans]

    print(f"==> Loaded meta_10m from '{meta_path_10m}'")

    def process_single_row(row):
        fname = os.path.basename(row["path"])
        # rindex raises ValueError when the file name has no underscore.
        video_id = fname[: fname.rindex("_")]
        return video_id in id2t and row["timestamp"] in id2t[video_id]

    meta = pd.read_csv(meta_path)
    keep = apply(meta, process_single_row, axis=1)

    meta = meta[keep]
    stem, suffix = os.path.splitext(meta_path)
    out_path = f"{stem}_filter-10m{suffix}"
    meta.to_csv(out_path, index=False)
    print(f"New meta (shape={meta.shape}) saved to '{out_path}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command line: ``--meta_path`` (one or more) and ``--num_workers``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--meta_path", type=str, nargs="+")
    cli.add_argument("--num_workers", default=5, type=int)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: build the caption set once, then filter each meta file.
if __name__ == "__main__":
    args = parse_args()

    text_set = get_10m_set()
    for meta_file in args.meta_path:
        filter_panda10m_text(meta_file, text_set)
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
tqdm.pandas()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pandarallel import pandarallel
|
||||||
|
|
||||||
|
PANDA_USE_PARALLEL = True
|
||||||
|
except ImportError:
|
||||||
|
PANDA_USE_PARALLEL = False
|
||||||
|
|
||||||
|
|
||||||
|
def save_first_frame(video_path, img_dir):
    """Save the first frame of ``video_path`` as a JPEG inside ``img_dir``.

    Args:
        video_path: path of the video to read.
        img_dir: existing directory that receives ``<video name>_first_frame.jpg``.

    Returns:
        The path of the written image, or ``""`` when the video is missing,
        cannot be read, or the image cannot be written (a message is printed).
    """
    if not os.path.exists(video_path):
        print(f"Video not found: {video_path}")
        return ""

    try:
        cap = cv2.VideoCapture(video_path)
        try:
            success, frame = cap.read()
        finally:
            # Always release the capture; it used to stay open when the read failed.
            cap.release()
        if not success:
            raise ValueError("Video broken.")

        video_name = os.path.basename(video_path)
        image_name = os.path.splitext(video_name)[0] + "_first_frame.jpg"
        image_path = os.path.join(img_dir, image_name)

        # cv2.imwrite reports failure through its return value, so check it
        # instead of returning a path to an image that was never written.
        if not cv2.imwrite(image_path, frame):
            raise ValueError(f"Could not write {image_path}.")
        return image_path
    except Exception as e:
        print(f"Save first frame of `{video_path}` failed. {e}")
        return ""
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: extract the first frame of every video listed in the
# input CSV and write a CSV that keeps only the rows whose frame was saved.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="path to the input csv dataset")
    # Required: without it the script previously failed later with a TypeError.
    parser.add_argument("--img-dir", type=str, required=True, help="path to save first frame image")
    parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
    parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
    args = parser.parse_args()

    if args.disable_parallel:
        PANDA_USE_PARALLEL = False
    if PANDA_USE_PARALLEL:
        if args.num_workers is not None:
            pandarallel.initialize(nb_workers=args.num_workers, progress_bar=True)
        else:
            pandarallel.initialize(progress_bar=True)

    os.makedirs(args.img_dir, exist_ok=True)

    data = pd.read_csv(args.input)

    # Honour --disable-parallel and a missing pandarallel install: the old
    # code called parallel_apply unconditionally.
    if PANDA_USE_PARALLEL:
        run_apply = data["path"].parallel_apply
    else:
        run_apply = data["path"].progress_apply
    data["first_frame_path"] = run_apply(save_first_frame, img_dir=args.img_dir)

    data_filtered = data.loc[data["first_frame_path"] != ""]
    # Build the output name from the real extension only; str.replace would
    # rewrite every ".csv" in the path and could return the input path itself.
    stem, ext = os.path.splitext(args.input)
    output_csv_path = f"{stem}_first-frame{ext}"
    data_filtered.to_csv(output_csv_path, index=False)
    print(f"First frame csv saved to: {output_csv_path}, first frame images saved to {args.img_dir}.")
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
import argparse
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from mmengine.config import Config
|
||||||
|
|
||||||
|
from opensora.datasets.bucket import Bucket
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_bucket(
    bucket: Bucket,
    input_files: List[str],
    output_path: str,
    limit: int,
    frame_interval: int,
):
    """Select up to ``limit`` rows per bucket from the input CSVs and save them.

    Rows are visited in file order. A row is kept while the bucket returned by
    ``bucket.get_bucket_id`` for its ``num_frames``/``height``/``width`` holds
    fewer than ``limit`` rows. Scanning stops once ``len(bucket) * limit`` rows
    are collected.

    Kept rows are gathered per file and concatenated once at the end, instead
    of calling ``pd.concat`` for every row (quadratic in the number of kept
    rows). An empty ``input_files`` now writes an empty CSV instead of crashing.

    Args:
        bucket: bucket definition; must provide ``ar_criteria``,
            ``get_bucket_id`` and ``len()``.
        input_files: CSV files with ``num_frames``, ``height`` and ``width`` columns.
        output_path: destination CSV.
        limit: maximum number of rows per bucket.
        frame_interval: forwarded to ``bucket.get_bucket_id``.
    """
    print(f"Split {len(input_files)} files into {len(bucket)} buckets")
    total_limit = len(bucket) * limit
    # get all bucket id
    bucket_cnt = {
        (hw_id, t_id, ar_id): 0
        for hw_id, d in bucket.ar_criteria.items()
        for t_id, v in d.items()
        for ar_id in v.keys()
    }

    header = None  # zero-row frame with the first file's columns
    selected = []  # one frame of kept rows per contributing file
    num_selected = 0
    # split files
    for path in input_files:
        df = pd.read_csv(path)
        if header is None:
            header = df.iloc[:0]
        keep = []
        for i in range(len(df)):
            row = df.iloc[i]
            t, h, w = row["num_frames"], row["height"], row["width"]
            bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
            if bucket_id is None:
                continue
            if bucket_cnt[bucket_id] < limit:
                bucket_cnt[bucket_id] += 1
                keep.append(i)
                num_selected += 1
                if num_selected >= total_limit:
                    break
        if keep:
            selected.append(df.iloc[keep])
        if num_selected >= total_limit:
            break

    if selected:
        output_df = pd.concat(selected, ignore_index=True)
    elif header is not None:
        output_df = header
    else:
        output_df = pd.DataFrame()

    assert len(output_df) <= total_limit
    if len(output_df) == total_limit:
        print(f"All buckets are full ({total_limit} samples)")
    else:
        print(f"Only {len(output_df)} files are used")
    output_df.to_csv(output_path, index=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: load the bucket config, force every enabled bucket to
# probability 1.0, and split the input CSVs by bucket.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+")
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-c", "--config", required=True)
    parser.add_argument("-l", "--limit", default=200, type=int)
    args = parser.parse_args()
    # Report a bad --limit as a normal CLI error; an `assert` would be
    # stripped under `python -O` and gives no usage message.
    if args.limit <= 0:
        parser.error("--limit must be a positive integer")

    cfg = Config.fromfile(args.config)
    bucket_config = cfg.bucket_config
    # rewrite bucket_config
    for ar, d in bucket_config.items():
        for frames, t in d.items():
            p, bs = t
            # Any bucket with a non-zero probability is always accepted here.
            if p > 0.0:
                p = 1.0
            d[frames] = (p, bs)
    bucket = Bucket(bucket_config)
    split_by_bucket(bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval)
|
||||||
|
|
@ -0,0 +1,306 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import ffmpeg
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from pandarallel import pandarallel
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .utils import IMG_EXTENSIONS, extract_frames
|
||||||
|
|
||||||
|
# Enable tqdm's pandas integration; `apply` below falls back to
# `progress_apply` when parallel execution is disabled.
tqdm.pandas()
# Module-wide switch read by `apply`; `main` sets it to False when
# `--disable-parallel` is passed.
USE_PANDARALLEL = True
|
||||||
|
|
||||||
|
|
||||||
|
def apply(df, func, **kwargs):
    """Run ``func`` over ``df`` using the configured apply flavour.

    Uses ``df.parallel_apply`` when the module flag ``USE_PANDARALLEL`` is
    true, otherwise ``df.progress_apply``. Extra keyword arguments are
    forwarded unchanged to whichever method is chosen.
    """
    runner = df.parallel_apply if USE_PANDARALLEL else df.progress_apply
    return runner(func, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_path(path, input_dir, output):
    """Mirror ``path`` (taken relative to ``input_dir``) under ``output``.

    The parent directory of the resulting path is created if it does not
    exist yet.

    Returns:
        str: the mirrored path below ``output``.
    """
    relative = os.path.relpath(path, input_dir)
    target = os.path.join(output, relative)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def resize_longer(path, length, input_dir, output_dir):
    """Downscale an image so that its longer side is at most ``length``.

    The result is written to the location of ``path`` mirrored from
    ``input_dir`` into ``output_dir``. Images whose longer side is already
    within ``length`` are written unchanged (no upscaling).

    Args:
        path: path of the source image; its extension must be in
            ``IMG_EXTENSIONS``.
        length: maximum size, in pixels, of the longer side.
        input_dir: root directory that ``path`` is relative to.
        output_dir: root directory for the output image.

    Returns:
        str: the output path, or ``""`` if the image could not be read.
    """
    path_new = get_new_path(path, input_dir, output_dir)
    ext = os.path.splitext(path)[1].lower()
    assert ext in IMG_EXTENSIONS
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        return ""

    h, w = img.shape[:2]
    # The trigger must look at the LONGER side; testing the shorter side would
    # leave e.g. a 4000x500 image untouched for length=1080.
    if max(h, w) > length:
        if h > w:
            new_h = length
            new_w = max(1, int(w / h * length))  # never collapse to 0 pixels
        else:
            new_w = length
            new_h = max(1, int(h / w * length))
        # cv2.resize expects the size as (width, height).
        img = cv2.resize(img, (new_w, new_h))
    cv2.imwrite(path_new, img)
    return path_new
|
||||||
|
|
||||||
|
|
||||||
|
def resize_shorter(path, length, input_dir, output_dir):
    """Downscale an image so that its shorter side equals ``length``.

    The output location mirrors ``path`` from ``input_dir`` into
    ``output_dir``. If that file already exists it is returned immediately
    without touching the source. Images whose shorter side is not larger
    than ``length`` are written unchanged.

    Returns:
        str: the output path, or ``""`` if the image could not be read.
    """
    path_new = get_new_path(path, input_dir, output_dir)
    if os.path.exists(path_new):
        return path_new

    suffix = os.path.splitext(path)[1].lower()
    assert suffix in IMG_EXTENSIONS

    img = cv2.imread(path)
    if img is None:
        return ""

    rows, cols = img.shape[:2]
    if min(rows, cols) > length:
        # Target size is given as (width, height), which is what cv2.resize expects.
        if rows > cols:
            target_size = (length, int(rows / cols * length))
        else:
            target_size = (int(cols / rows * length), length)
        img = cv2.resize(img, target_size)
    cv2.imwrite(path_new, img)
    return path_new
|
||||||
|
|
||||||
|
|
||||||
|
def rand_crop(path, input_dir, output):
    """Write a randomly chosen quadrant of an image.

    One of the four quadrants (picked with ``random.randint(0, 3)``) is
    saved to the location of ``path`` mirrored from ``input_dir`` into
    ``output``.

    Returns:
        str: the output path, or ``""`` if the image could not be read.
    """
    ext = os.path.splitext(path)[1].lower()
    path_new = get_new_path(path, input_dir, output)
    assert ext in IMG_EXTENSIONS
    img = cv2.imread(path)
    if img is None:
        return ""

    # img.shape is (rows, cols, ...): axis 0 is the height, axis 1 the width.
    h, w = img.shape[:2]
    half_h, half_w = h // 2, w // 2

    # pos: 0 = top-left, 1 = bottom-left, 2 = top-right, 3 = bottom-right.
    pos = random.randint(0, 3)
    rows = slice(None, half_h) if pos in (0, 2) else slice(half_h, None)
    cols = slice(None, half_w) if pos in (0, 1) else slice(half_w, None)

    cv2.imwrite(path_new, img[rows, cols])
    return path_new
|
||||||
|
|
||||||
|
|
||||||
|
def m2ts_to_mp4(row, output_dir):
    """Convert the video referenced by ``row["path"]`` to an .mp4 file.

    The output is written directly into ``output_dir`` under the source
    file's base name with the extension replaced by ``.mp4``.

    Args:
        row: mapping with at least ``"path"`` and ``"relpath"`` entries; it is
            updated in place.
        output_dir: directory that receives the converted file.

    Returns:
        The same ``row``. On success ``"path"`` points at the new file and
        ``"relpath"`` has its extension replaced by ``.mp4``; on any failure
        both are set to ``""`` so the caller can filter the row out.
    """
    input_path = row["path"]
    # Swap the extension with splitext (as done for relpath below) rather than
    # str.replace, which is case-sensitive and matches anywhere in the name.
    stem = os.path.splitext(os.path.basename(input_path))[0]
    output_path = os.path.join(output_dir, stem + ".mp4")
    # create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    try:
        ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run(
            capture_stdout=True
        )
        row["path"] = output_path
        row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
    except Exception as e:
        # Best effort: report the failure and blank the row instead of aborting the batch.
        print(f"Error converting {input_path} to mp4: {e}")
        row["path"] = ""
        row["relpath"] = ""
    return row
|
||||||
|
|
||||||
|
|
||||||
|
def mkv_to_mp4(row, output_dir):
    """Convert the video referenced by ``row["path"]`` to an .mp4 file.

    The output is written directly into ``output_dir`` under the source
    file's base name with the extension replaced by ``.mp4``.

    Args:
        row: mapping with at least ``"path"`` and ``"relpath"`` entries; it is
            updated in place.
        output_dir: directory that receives the converted file.

    Returns:
        The same ``row``. On success ``"path"`` points at the new file and
        ``"relpath"`` has its extension replaced by ``.mp4``; on any failure
        both are set to ``""`` so the caller can filter the row out.
    """
    input_path = row["path"]
    # Swap the extension with splitext (as done for relpath below) rather than
    # str.replace, which is case-sensitive and matches anywhere in the name.
    stem = os.path.splitext(os.path.basename(input_path))[0]
    output_path = os.path.join(output_dir, stem + ".mp4")

    # create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    try:
        ffmpeg.input(input_path).output(output_path).overwrite_output().global_args("-loglevel", "quiet").run(
            capture_stdout=True
        )
        row["path"] = output_path
        row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
    except Exception as e:
        # Best effort: report the failure and blank the row instead of aborting the batch.
        print(f"Error converting {input_path} to mp4: {e}")
        row["path"] = ""
        row["relpath"] = ""
    return row
|
||||||
|
|
||||||
|
|
||||||
|
def mp4_to_mp4(row, output_dir):
    """Copy the .mp4 file referenced by ``row["path"]`` into ``output_dir``.

    Args:
        row: mapping with at least ``"path"`` and ``"relpath"`` entries; it is
            updated in place.
        output_dir: directory that receives the copy (flat, base name only).

    Returns:
        The same ``row``. On success ``"path"`` points at the copy and
        ``"relpath"`` ends in ``.mp4``; if the input is not an .mp4 file or
        the copy fails, both are set to ``""`` so the caller can filter the
        row out.
    """
    input_path = row["path"]

    # Check that the input file is an .mp4 file (case-insensitive).
    if not input_path.lower().endswith(".mp4"):
        print(f"Error: {input_path} is not an .mp4 file.")
        row["path"] = ""
        row["relpath"] = ""
        return row

    output_path = os.path.join(output_dir, os.path.basename(input_path))

    # create directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    try:
        # shutil.copy2 raises SameFileError when source and destination are the
        # same file; in that case the file is already in place, so skip the copy
        # instead of discarding a perfectly valid row.
        already_in_place = os.path.exists(output_path) and os.path.samefile(input_path, output_path)
        if not already_in_place:
            shutil.copy2(input_path, output_path)  # copy the file together with its metadata
        row["path"] = output_path
        row["relpath"] = os.path.splitext(row["relpath"])[0] + ".mp4"
    except Exception as e:
        # Best effort: report the failure and blank the row instead of aborting the batch.
        print(f"Error copying {input_path} to mp4: {e}")
        row["path"] = ""
        row["relpath"] = ""
    return row
|
||||||
|
|
||||||
|
|
||||||
|
def crop_to_square(input_path, output_path):
    """Center-crop a video to a square with ffmpeg and re-encode it with libx264.

    The square's side is the smaller of the input width and height; audio is
    dropped (``-an``) and only the video stream is mapped.

    Args:
        input_path: path of the source video.
        output_path: path of the cropped video to write.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        OSError: if the ffmpeg executable cannot be started.
    """
    crop_filter = (
        "crop='min(in_w,in_h)':'min(in_w,in_h)':"
        "'(in_w-min(in_w,in_h))/2':'(in_h-min(in_w,in_h))/2'"
    )
    # Pass an argument list instead of a shell string so that paths containing
    # spaces or shell metacharacters are handled safely.
    cmd = [
        "ffmpeg",
        "-nostdin",  # never wait on an interactive prompt in batch runs
        "-i",
        str(input_path),
        "-vf",
        crop_filter,
        "-c:v",
        "libx264",
        "-an",
        "-map",
        "0:v",
        str(output_path),
    ]
    # check=True surfaces ffmpeg failures to the caller, which wraps this call
    # in try/except; previously the exit status was silently ignored.
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=True)
|
||||||
|
|
||||||
|
|
||||||
|
def vid_crop_center(row, input_dir, output_dir):
    """Center-crop the video in ``row["path"]`` to a square under ``output_dir``.

    The output path mirrors the source path relative to ``input_dir``; the
    source must live inside ``input_dir``. On success the row's ``path``,
    ``height``, ``width``, ``aspect_ratio`` and ``resolution`` entries are
    updated for the square result; if an exception is raised while cropping,
    ``path`` is set to ``""``.

    Returns:
        The same ``row``, updated in place.
    """
    src = row["path"]
    rel = os.path.relpath(src, input_dir)
    assert not rel.startswith("..")
    dst = os.path.join(output_dir, rel)

    # create directory if it doesn't exist
    os.makedirs(os.path.dirname(dst), exist_ok=True)

    try:
        crop_to_square(src, dst)
        side = min(row["height"], row["width"])
        updates = (
            ("path", dst),
            ("height", side),
            ("width", side),
            ("aspect_ratio", 1.0),
            ("resolution", side**2),
        )
        for key, value in updates:
            row[key] = value
    except Exception as e:
        print(f"Error cropping {src} to center: {e}")
        row["path"] = ""
    return row
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the transform selected by ``--task`` over the rows of a metadata CSV.

    Reads ``args.meta_path``, applies the task to every row (in parallel
    unless ``--disable-parallel`` is given), writes the resulting table to a
    task-specific CSV path and exits with status 0.

    Raises:
        ValueError: for an unknown task, or when ``vid_frame_extract`` is
            requested without ``--points`` / ``--points_index``.
        SystemExit: always, with code 0, after the output CSV is saved.
    """
    args = parse_args()
    global USE_PANDARALLEL

    assert args.num_workers is None or not args.disable_parallel
    if args.disable_parallel:
        USE_PANDARALLEL = False
    if args.num_workers is not None:
        pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers)
    else:
        pandarallel.initialize(progress_bar=True)

    random.seed(args.seed)
    data = pd.read_csv(args.meta_path)
    if args.task == "img_rand_crop":
        data["path"] = apply(data["path"], lambda x: rand_crop(x, args.input_dir, args.output_dir))
        output_csv = args.meta_path.replace(".csv", "_rand_crop.csv")
    elif args.task == "img_resize_longer":
        data["path"] = apply(data["path"], lambda x: resize_longer(x, args.length, args.input_dir, args.output_dir))
        output_csv = args.meta_path.replace(".csv", f"_resize-longer-{args.length}.csv")
    elif args.task == "img_resize_shorter":
        data["path"] = apply(data["path"], lambda x: resize_shorter(x, args.length, args.input_dir, args.output_dir))
        output_csv = args.meta_path.replace(".csv", f"_resize-shorter-{args.length}.csv")
    elif args.task == "vid_frame_extract":
        points = args.points if args.points is not None else args.points_index
        if not points:
            raise ValueError("vid_frame_extract requires --points or --points_index")
        num_points = len(points)
        # Each source row is repeated once per point. The repeat count must
        # equal the stride used below (num_points); a fixed count of 3
        # misaligns rows and points whenever a different number is given.
        data = pd.DataFrame(np.repeat(data.values, num_points, axis=0), columns=data.columns)
        data["point"] = np.nan
        for i, point in enumerate(points):
            if isinstance(point, int):
                # --points_index values are parsed as int: absolute frame indices.
                data.loc[i::num_points, "point"] = point
            else:
                # --points values are parsed as float: scaled by the row's frame count.
                data.loc[i::num_points, "point"] = data.loc[i::num_points, "num_frames"] * point
        # NOTE(review): extract_frames is called with four positional arguments
        # (path, input_dir, output_dir, point) — confirm this matches the
        # signature of the extract_frames imported from .utils.
        data["path"] = apply(
            data, lambda x: extract_frames(x["path"], args.input_dir, args.output_dir, x["point"]), axis=1
        )
        output_csv = args.meta_path.replace(".csv", "_vid_frame_extract.csv")
    elif args.task == "m2ts_to_mp4":
        print(f"m2ts_to_mp4作业开始:{args.output_dir}")
        assert args.meta_path.endswith("_m2ts.csv"), "Input file must end with '_m2ts.csv'"
        m2ts_to_mp4_partial = lambda x: m2ts_to_mp4(x, args.output_dir)
        data = apply(data, m2ts_to_mp4_partial, axis=1)
        # Failed conversions come back with an empty path; drop them.
        data = data[data["path"] != ""]
        output_csv = args.meta_path.replace("_m2ts.csv", ".csv")
    elif args.task == "mkv_to_mp4":
        print(f"mkv_to_mp4作业开始:{args.output_dir}")
        assert args.meta_path.endswith("_mkv.csv"), "Input file must end with '_mkv.csv'"
        mkv_to_mp4_partial = lambda x: mkv_to_mp4(x, args.output_dir)
        data = apply(data, mkv_to_mp4_partial, axis=1)
        data = data[data["path"] != ""]
        output_csv = args.meta_path.replace("_mkv.csv", ".csv")
    elif args.task == "mp4_to_mp4":
        print(f"MP4复制作业开始:{args.output_dir}")
        mp4_to_mp4_partial = lambda x: mp4_to_mp4(x, args.output_dir)
        data = apply(data, mp4_to_mp4_partial, axis=1)
        data = data[data["path"] != ""]
        # This task overwrites the input metadata file with the filtered table.
        output_csv = args.meta_path
    elif args.task == "vid_crop_center":
        vid_crop_center_partial = lambda x: vid_crop_center(x, args.input_dir, args.output_dir)
        data = apply(data, vid_crop_center_partial, axis=1)
        data = data[data["path"] != ""]
        output_csv = args.meta_path.replace(".csv", "_center-crop.csv")
    else:
        raise ValueError
    data.to_csv(output_csv, index=False)
    print(f"Saved to {output_csv}")
    raise SystemExit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command-line options of the transform script.

    Returns:
        argparse.Namespace with ``task``, ``meta_path``, ``input_dir``,
        ``output_dir``, ``length``, ``disable_parallel``, ``num_workers``,
        ``seed``, ``points`` and ``points_index``.
    """
    supported_tasks = [
        "img_resize_longer",
        "img_resize_shorter",
        "img_rand_crop",
        "vid_frame_extract",
        "m2ts_to_mp4",
        "mkv_to_mp4",
        "mp4_to_mp4",
        "vid_crop_center",
    ]

    cli = argparse.ArgumentParser()
    cli.add_argument("--task", type=str, required=True, choices=supported_tasks)
    cli.add_argument("--meta_path", type=str, required=True)
    cli.add_argument("--input_dir", type=str)
    cli.add_argument("--output_dir", type=str)
    cli.add_argument("--length", type=int, default=1080)
    cli.add_argument("--disable-parallel", action="store_true")
    cli.add_argument("--num_workers", type=int, default=None)
    cli.add_argument("--seed", type=int, default=42, help="seed for random")
    # Two alternative ways to select frames: fractional points or integer indices.
    cli.add_argument("--points", nargs="+", type=float, default=None)
    cli.add_argument("--points_index", nargs="+", type=int, default=None)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the selected task only when executed directly.
if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Lower-case file extensions (with leading dot) recognised as images.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
# Lower-case file extensions (with leading dot) recognised as videos; used by is_video().
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||||
|
|
||||||
|
|
||||||
|
def is_video(filename):
    """Return True if ``filename`` has one of the extensions in ``VID_EXTENSIONS``.

    The comparison is case-insensitive.
    """
    _, suffix = os.path.splitext(filename)
    return suffix.lower() in VID_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def extract_frames(
    video_path,
    frame_inds=None,
    points=None,
    backend="opencv",
    return_length=False,
    num_frames=None,
):
    """
    Extract selected frames from a video as PIL images.

    Args:
        video_path (str): path to video
        frame_inds (List[int]): indices of frames to extract
        points (List[float]): values within [0, 1); multiply #frames to get frame indices
        backend (str): decoder to use, one of "av", "opencv" or "decord"
        return_length (bool): if True, also return the total frame count used
        num_frames (int): frame count to use instead of the one reported by the backend
    Return:
        List[PIL.Image], or (List[PIL.Image], int) when ``return_length`` is True

    At most one of ``frame_inds`` and ``points`` may be given. Indices at or
    beyond the frame count are clamped to the last frame.
    """
    assert backend in ["av", "opencv", "decord"]
    assert (frame_inds is None) or (points is None)

    if backend == "av":
        import av

        container = av.open(video_path)
        try:
            if num_frames is not None:
                total_frames = num_frames
            else:
                total_frames = container.streams.video[0].frames

            if points is not None:
                frame_inds = [int(p * total_frames) for p in points]

            frames = []
            for idx in frame_inds:
                if idx >= total_frames:
                    idx = total_frames - 1
                target_timestamp = int(idx * av.time_base / container.streams.video[0].average_rate)
                container.seek(target_timestamp)  # return the nearest key frame, not the precise timestamp!!!
                frame = next(container.decode(video=0)).to_image()
                frames.append(frame)
        finally:
            # Close the container even when seeking/decoding fails.
            container.close()

        if return_length:
            return frames, total_frames
        return frames

    elif backend == "decord":
        import decord

        container = decord.VideoReader(video_path, num_threads=1)
        if num_frames is not None:
            total_frames = num_frames
        else:
            total_frames = len(container)

        if points is not None:
            frame_inds = [int(p * total_frames) for p in points]

        frame_inds = np.array(frame_inds).astype(np.int32)
        frame_inds[frame_inds >= total_frames] = total_frames - 1
        frames = container.get_batch(frame_inds).asnumpy()  # [N, H, W, C]
        frames = [Image.fromarray(x) for x in frames]

        if return_length:
            return frames, total_frames
        return frames

    elif backend == "opencv":
        cap = cv2.VideoCapture(video_path)
        try:
            if num_frames is not None:
                total_frames = num_frames
            else:
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if points is not None:
                frame_inds = [int(p * total_frames) for p in points]

            frames = []
            for idx in frame_inds:
                if idx >= total_frames:
                    idx = total_frames - 1

                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)

                # HACK: sometimes OpenCV fails to read frames, return a black frame instead
                try:
                    ret, frame = cap.read()
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = Image.fromarray(frame)
                except Exception as e:
                    print(f"[Warning] Error reading frame {idx} from {video_path}: {e}")
                    # First, try to read the first frame
                    try:
                        print("[Warning] Try reading first frame.")
                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        ret, frame = cap.read()
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame = Image.fromarray(frame)
                    # If that fails, return a black frame
                    except Exception as e:
                        print(f"[Warning] Error in reading first frame from {video_path}: {e}")
                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        frame = Image.new("RGB", (width, height), (0, 0, 0))

                # HACK: if height or width is 0, return a black frame instead
                if frame.height == 0 or frame.width == 0:
                    height = width = 256
                    frame = Image.new("RGB", (width, height), (0, 0, 0))

                frames.append(frame)
        finally:
            # Release the capture handle; it was previously leaked on every call.
            cap.release()

        if return_length:
            return frames, total_frames
        return frames
    else:
        raise ValueError
|
||||||
Loading…
Reference in New Issue