# mysora/opensora/datasets/datasets.py
import os
import random
import numpy as np
import pandas as pd
import torch
from PIL import ImageFile
from torchvision.datasets.folder import pil_loader
from opensora.registry import DATASETS
from .read_video import read_video
from .utils import get_transforms_image, get_transforms_video, is_img, map_target_fps, read_file, temporal_random_crop
# Tolerate images whose files were truncated (e.g. interrupted downloads)
# instead of raising during decode.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Extra per-sample metadata keys copied verbatim into each returned item.
VALID_KEYS = ("neg", "path")
# Target shard count for EfficientParquet: rows_per_shard = ceil(total / K).
K = 10000
class Iloc:
    """Emulates pandas' ``.iloc`` accessor over shard-backed data.

    Indexing an instance yields a lazy :class:`Item` that resolves column
    values on demand, loading the relevant text shard only when required.
    """

    def __init__(self, data, sharded_folder, sharded_folders, rows_per_shard):
        # Keep the handles needed to build an Item for any global row index.
        self.data = data
        self.sharded_folder = sharded_folder
        self.sharded_folders = sharded_folders
        self.rows_per_shard = rows_per_shard

    def __getitem__(self, index):
        """Return a lazy row accessor for global row ``index``."""
        item = Item(
            index,
            self.data,
            self.sharded_folder,
            self.sharded_folders,
            self.rows_per_shard,
        )
        return item
class Item:
    """Lazy accessor for a single row whose text columns live in sharded parquet files."""

    def __init__(self, index, data, sharded_folder, sharded_folders, rows_per_shard):
        self.index = index
        self.data = data
        self.sharded_folder = sharded_folder
        self.sharded_folders = sharded_folders
        self.rows_per_shard = rows_per_shard

    def _shard_location(self):
        # Map the global row index to (shard parquet path, row index within shard).
        shard_no, local_idx = divmod(self.index, self.rows_per_shard)
        shard_path = os.path.join(self.sharded_folder, self.sharded_folders[shard_no])
        return shard_path, local_idx

    def __getitem__(self, key):
        """Return column ``key`` for this row, reading the text shard if needed."""
        if key in self.data.columns:
            # Fast path: column already present in the in-memory frame.
            return self.data[key].iloc[self.index]
        shard_path, local_idx = self._shard_location()
        try:
            shard_df = pd.read_parquet(shard_path, engine="fastparquet")
            # Sanity check: the shard row must describe the same sample.
            path = shard_df["path"].iloc[local_idx]
            assert path == self.data["path"].iloc[self.index]
        except Exception as e:
            print(f"Error reading {shard_path}: {e}")
            raise
        return shard_df[key].iloc[local_idx]

    def to_dict(self):
        """Merge the in-memory row with its shard row into one dict.

        On any shard failure this degrades gracefully to an empty caption.
        """
        ret = dict(self.data.iloc[self.index].to_dict())
        shard_path, local_idx = self._shard_location()
        try:
            shard_df = pd.read_parquet(shard_path, engine="fastparquet")
            path = shard_df["path"].iloc[local_idx]
            assert path == self.data["path"].iloc[self.index]
            ret.update(shard_df.iloc[local_idx].to_dict())
        except Exception as e:
            # Best effort: keep the base columns, blank out the text.
            print(f"Error reading {shard_path}: {e}")
            ret.update({"text": ""})
        return ret
class EfficientParquet:
    """Wraps a dataframe whose heavy text columns are stored as sharded parquet files."""

    def __init__(self, df, sharded_folder):
        self.data = df
        self.total_rows = len(df)
        # Ceil-divide so that roughly K shards cover every row.
        self.rows_per_shard = (self.total_rows + K - 1) // K
        self.sharded_folder = sharded_folder
        assert os.path.exists(sharded_folder), f"Sharded folder {sharded_folder} does not exist."
        # Sorted listing so shard_idx -> file mapping is deterministic.
        self.sharded_folders = sorted(os.listdir(sharded_folder))

    def __len__(self):
        return self.total_rows

    @property
    def iloc(self):
        # Mirror pandas' DataFrame.iloc so callers index rows uniformly.
        return Iloc(self.data, self.sharded_folder, self.sharded_folders, self.rows_per_shard)
@DATASETS.register_module("text")
class TextDataset(torch.utils.data.Dataset):
    """Dataset yielding text prompts (plus optional metadata) for generation.

    Args:
        data_path: path to a table file understood by ``read_file``.
        tokenize_fn: optional callable mapping a caption string to a dict of
            tensors; each tensor is squeezed on dim 0 before being returned.
        fps_max: cap applied when remapping a sample's fps via
            ``map_target_fps``; ``None`` disables the cap.
        vmaf: if True, append the rounded VMAF motion score to the caption.
        memory_efficient: forwarded to ``read_file``; enables shard-backed
            loading via :meth:`to_efficient`.
    """

    def __init__(
        self,
        data_path: str = None,
        tokenize_fn: callable = None,
        fps_max: int = 16,
        vmaf: bool = False,
        memory_efficient: bool = False,
        **kwargs,
    ):
        self.data_path = data_path
        self.data = read_file(data_path, memory_efficient=memory_efficient)
        self.memory_efficient = memory_efficient
        self.tokenize_fn = tokenize_fn
        self.vmaf = vmaf
        # A huge sentinel stands in for "no fps cap"; getitem() also relies on
        # `fps_max < 999` to decide whether to mention FPS in the caption.
        if fps_max is not None:
            self.fps_max = fps_max
        else:
            self.fps_max = 999999999

    def to_efficient(self):
        """Swap the in-memory dataframe for a shard-backed wrapper.

        The shard folder is expected next to the data file, named like the
        file without its extension.
        """
        if self.memory_efficient:
            # Fix: the previous `self.data_path.split(".")[0]` truncated at the
            # FIRST dot anywhere in the path (e.g. "/data/v1.2/meta.parquet"
            # -> "/data/v1"); splitext strips only the final extension.
            addition_data_path = os.path.splitext(self.data_path)[0]
            self._data = self.data
            self.data = EfficientParquet(self._data, addition_data_path)

    def getitem(self, index: int) -> dict:
        """Build the text-side sample dict for row ``index``.

        Returns a dict with ``sampling_interval``, optionally ``text`` (with
        FPS / motion-score postfix and tokenizer outputs), ``ref``, ``name``
        or ``index``, plus any keys in VALID_KEYS present in the row.
        """
        ret = dict()
        sample = self.data.iloc[index].to_dict()
        sample_fps = sample.get("fps", np.nan)
        new_fps, sampling_interval = map_target_fps(sample_fps, self.fps_max)
        ret.update({"sampling_interval": sampling_interval})
        if "text" in sample:
            ret["text"] = sample.pop("text")
            postfixs = []
            # Only advertise FPS when a real cap is in effect (sentinel >= 999).
            if new_fps != 0 and self.fps_max < 999:
                postfixs.append(f"{new_fps} FPS")
            if self.vmaf and "score_vmafmotion" in sample and not np.isnan(sample["score_vmafmotion"]):
                # +0.5 then int() rounds the score to the nearest integer.
                postfixs.append(f"{int(sample['score_vmafmotion'] + 0.5)} motion score")
            postfix = " " + ", ".join(postfixs) + "." if postfixs else ""
            ret["text"] = ret["text"] + postfix
            if self.tokenize_fn is not None:
                ret.update({k: v.squeeze(0) for k, v in self.tokenize_fn(ret["text"]).items()})
        if "ref" in sample:  # i2v & v2v reference
            ret["ref"] = sample.pop("ref")
        # name of the generated sample
        if "name" in sample:  # sample name (`dataset_idx`)
            ret["name"] = sample.pop("name")
        else:
            ret["index"] = index  # use index for name
        valid_sample = {k: v for k, v in sample.items() if k in VALID_KEYS}
        ret.update(valid_sample)
        return ret

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data)
@DATASETS.register_module("video_text")
class VideoTextDataset(TextDataset):
    """Dataset serving (video-or-image, caption) pairs.

    The sampler encodes ``(index, num_frames, height, width)`` into a string
    index ("i-t-h-w"); :meth:`getitem` decodes it and loads the media.
    """

    def __init__(
        self,
        transform_name: str = None,
        bucket_class: str = "Bucket",
        rand_sample_interval: int = None,  # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.transform_name = transform_name
        self.bucket_class = bucket_class
        self.rand_sample_interval = rand_sample_interval

    def get_image(self, index: int, height: int, width: int) -> dict:
        """Load one image and return it as a single-frame C,T,H,W tensor."""
        sample = self.data.iloc[index]
        path = sample["path"]
        # loading
        image = pil_loader(path)
        # transform
        transform = get_transforms_image(self.transform_name, (height, width))
        image = transform(image)
        # CHW -> CTHW
        video = image.unsqueeze(1)
        return {"video": video}

    def get_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
        """Decode a clip, temporally crop ``num_frames``, transform to C,T,H,W."""
        sample = self.data.iloc[index]
        path = sample["path"]
        # loading
        vframes, vinfo = read_video(path, backend="av")
        if self.rand_sample_interval is not None:
            # Randomly pick an interval in [1, rand_sample_interval], capped by
            # what the clip length allows. Fix: clamp to >= 1 so clips shorter
            # than num_frames don't make random.randint(1, 0) raise ValueError
            # (which silently dropped otherwise-usable samples in getitem).
            video_allowed_max = max(1, min(len(vframes) // num_frames, self.rand_sample_interval))
            sampling_interval = random.randint(1, video_allowed_max)
        # Sampling video frames
        video = temporal_random_crop(vframes, num_frames, sampling_interval)
        video = video.clone()
        del vframes  # release the full decode before the transform allocates
        # transform
        transform = get_transforms_video(self.transform_name, (height, width))
        video = transform(video)  # T C H W
        video = video.permute(1, 0, 2, 3)  # -> C T H W
        ret = {"video": video}
        return ret

    def get_image_or_video(self, index: int, num_frames: int, height: int, width: int, sampling_interval: int) -> dict:
        """Dispatch on the file extension: image vs. video loading."""
        sample = self.data.iloc[index]
        path = sample["path"]
        if is_img(path):
            return self.get_image(index, height, width)
        return self.get_video(index, num_frames, height, width, sampling_interval)

    def getitem(self, index: str) -> dict:
        # a hack to pass in the (time, height, width) info from sampler
        index, num_frames, height, width = [int(val) for val in index.split("-")]
        ret = dict()
        ret.update(super().getitem(index))
        try:
            ret.update(self.get_image_or_video(index, num_frames, height, width, ret["sampling_interval"]))
        except Exception as e:
            path = self.data.iloc[index]["path"]
            print(f"video {path}: {e}")
            # None signals a failed sample; a filtering collate_fn can drop it.
            return None
        return ret

    def __getitem__(self, index):
        return self.getitem(index)
@DATASETS.register_module("cached_video_text")
class CachedVideoTextDataset(VideoTextDataset):
    """Video-text dataset that can additionally serve precomputed latents.

    Args:
        transform_name: name of the video/image transform pipeline.
        bucket_class: bucket strategy name used by the sampler.
        rand_sample_interval: see :class:`VideoTextDataset`.
        cached_video: load precomputed video latents from ``latents_path``.
        cached_text: load precomputed T5 and CLIP text embeddings.
        return_latents_path: also return the cache file paths.
        load_original_video: if True, decode the raw video as well.
    """

    def __init__(
        self,
        transform_name: str = None,
        bucket_class: str = "Bucket",
        rand_sample_interval: int = None,  # random sample_interval value from [1, min(rand_sample_interval, video_allowed_max)]
        cached_video: bool = False,
        cached_text: bool = False,
        return_latents_path: bool = False,
        load_original_video: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.transform_name = transform_name
        self.bucket_class = bucket_class
        self.rand_sample_interval = rand_sample_interval
        self.cached_video = cached_video
        self.cached_text = cached_text
        self.return_latents_path = return_latents_path
        self.load_original_video = load_original_video

    def get_latents(self, path):
        """Load a cached tensor from ``path``, falling back to zeros on failure.

        NOTE(review): ``torch.load`` without ``weights_only=True`` unpickles
        arbitrary objects — fine for trusted local caches, unsafe for
        untrusted files; confirm the cache source before relaxing this.
        """
        try:
            latents = torch.load(path, map_location=torch.device("cpu"))
        except Exception as e:
            print(f"Error loading latents from {path}: {e}")
            # Fix: plain zeros. The previous zeros_like(randn(1, 1, 1, 1))
            # produced the same value but needlessly drew from — and thereby
            # advanced — the global RNG state, hurting reproducibility.
            return torch.zeros(1, 1, 1, 1)
        return latents

    def get_conditioning_latents(self, index: int) -> dict:
        """Collect whichever cached latents/embeddings this dataset is configured for."""
        sample = self.data.iloc[index]
        latents_path = sample["latents_path"]
        text_t5_path = sample["text_t5_path"]
        text_clip_path = sample["text_clip_path"]
        res = dict()
        if self.cached_video:
            res["video_latents"] = self.get_latents(latents_path)
        if self.cached_text:
            res["text_t5"] = self.get_latents(text_t5_path)
            res["text_clip"] = self.get_latents(text_clip_path)
        if self.return_latents_path:
            res["latents_path"] = latents_path
            res["text_t5_path"] = text_t5_path
            res["text_clip_path"] = text_clip_path
        return res

    def getitem(self, index: str) -> dict:
        # a hack to pass in the (time, height, width) info from sampler
        real_index, num_frames, height, width = [int(val) for val in index.split("-")]
        ret = dict()
        if self.load_original_video:
            base = super().getitem(index)
            if base is None:
                # Fix: the parent returns None when video decoding fails;
                # ret.update(None) previously raised an uncaught TypeError
                # and crashed the dataloader worker. Propagate the sentinel.
                return None
            ret.update(base)
        try:
            ret.update(self.get_conditioning_latents(real_index))
        except Exception as e:
            path = self.data.iloc[real_index]["path"]
            print(f"video {path}: {e}")
            return None
        return ret

    def __getitem__(self, index):
        return self.getitem(index)