mysora/opensora/datasets/utils.py

420 lines
14 KiB
Python

import math
import os
import random
import re
from typing import Any
import numpy as np
import pandas as pd
import requests
import torch
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image
from . import video_transforms
from .read_video import read_video
try:
import dask.dataframe as dd
SUPPORT_DASK = True
except:
SUPPORT_DASK = False
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
regex = re.compile(
r"^(?:http|ftp)s?://" # http:// or https://
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
r"(?::\d+)?" # optional port
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
def is_img(path):
ext = os.path.splitext(path)[-1].lower()
return ext in IMG_EXTENSIONS
def is_vid(path):
ext = os.path.splitext(path)[-1].lower()
return ext in VID_EXTENSIONS
def is_url(url):
return re.match(regex, url) is not None
def read_file(input_path, memory_efficient=False):
if input_path.endswith(".csv"):
assert not memory_efficient, "Memory efficient mode is not supported for CSV files"
return pd.read_csv(input_path)
elif input_path.endswith(".parquet"):
columns = None
if memory_efficient:
columns = ["path", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution"]
if SUPPORT_DASK:
ret = dd.read_parquet(input_path, columns=columns).compute()
else:
ret = pd.read_parquet(input_path, columns=columns)
return ret
else:
raise NotImplementedError(f"Unsupported file format: {input_path}")
def download_url(input_path):
output_dir = "cache"
os.makedirs(output_dir, exist_ok=True)
base_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, base_name)
img_data = requests.get(input_path).content
with open(output_path, "wb", encoding="utf-8") as handler:
handler.write(img_data)
print(f"URL {input_path} downloaded to {output_path}")
return output_path
def temporal_random_crop(
vframes: torch.Tensor, num_frames: int, frame_interval: int, return_frame_indices: bool = False
) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
total_frames = len(vframes)
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
assert (
end_frame_ind - start_frame_ind >= num_frames
), f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}"
frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
video = vframes[frame_indices]
if return_frame_indices:
return video, frame_indices
else:
return video
def get_transforms_video(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "image_size must be square for center crop"
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(image_size[0]),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.ResizeCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "rand_size_crop":
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.RandomSizedCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform_video
def get_transforms_image(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "Image size must be square for center crop"
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "resize_crop":
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
elif name == "rand_size_crop":
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: rand_size_crop_arr(pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
image = pil_loader(path)
if transform is None:
transform = get_transforms_image(image_size=image_size, name=transform_name)
image = transform(image)
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if transform is None:
transform = get_transforms_video(image_size=image_size, name=transform_name)
video = transform(vframes) # T C H W
video = video.permute(1, 0, 2, 3)
return video
def read_from_path(path, image_size, transform_name="center"):
if is_url(path):
path = download_url(path)
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
else:
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
def save_sample(
x,
save_path=None,
fps=8,
normalize=True,
value_range=(-1, 1),
force_video=False,
verbose=True,
crf=23,
):
"""
Args:
x (Tensor): shape [C, T, H, W]
"""
assert x.ndim == 4
if not force_video and x.shape[1] == 1: # T = 1: save as image
save_path += ".png"
x = x.squeeze(1)
save_image([x], save_path, normalize=normalize, value_range=value_range)
else:
save_path += ".mp4"
if normalize:
low, high = value_range
x.clamp_(min=low, max=high)
x.sub_(low).div_(max(high - low, 1e-5))
x = x.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
write_video(save_path, x, fps=fps, video_codec="h264", options={"crf": str(crf)})
if verbose:
print(f"Saved to {save_path}")
return save_path
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
def rand_size_crop_arr(pil_image, image_size):
"""
Randomly crop image for height and width, ranging from image_size[0] to image_size[1]
"""
arr = np.array(pil_image)
# get random target h w
height = random.randint(image_size[0], image_size[1])
width = random.randint(image_size[0], image_size[1])
# ensure that h w are factors of 8
height = height - height % 8
width = width - width % 8
# get random start pos
h_start = random.randint(0, max(len(arr) - height, 0))
w_start = random.randint(0, max(len(arr[0]) - height, 0))
# crop
return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width])
def resize_crop_to_fill(pil_image, image_size):
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, round(w * rh)
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = 0
j = int(round((sw - tw) / 2.0))
else:
sh, sw = round(h * rw), tw
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = int(round((sh - th) / 2.0))
j = 0
arr = np.array(image)
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
return Image.fromarray(arr[i : i + th, j : j + tw])
def map_target_fps(
fps: float,
max_fps: float,
) -> tuple[float, int]:
"""
Map fps to a new fps that is less than max_fps.
Args:
fps (float): Original fps.
max_fps (float): Maximum fps.
Returns:
tuple[float, int]: New fps and sampling interval.
"""
if math.isnan(fps):
return 0, 1
if fps < max_fps:
return fps, 1
sampling_interval = math.ceil(fps / max_fps)
new_fps = math.floor(fps / sampling_interval)
return new_fps, sampling_interval
def sync_object_across_devices(obj: Any, rank: int = 0):
"""
Synchronizes any picklable object across devices in a PyTorch distributed setting
using `broadcast_object_list` with CUDA support.
Parameters:
obj (Any): The object to synchronize. Can be any picklable object (e.g., list, dict, custom class).
rank (int): The rank of the device from which to broadcast the object state. Default is 0.
Note: Ensure torch.distributed is initialized before using this function and CUDA is available.
"""
# Move the object to a list for broadcasting
object_list = [obj]
# Broadcast the object list from the source rank to all other ranks
dist.broadcast_object_list(object_list, src=rank, device="cuda")
# Retrieve the synchronized object
obj = object_list[0]
return obj
def rescale_image_by_path(path: str, height: int, width: int):
"""
Rescales an image to the specified height and width and saves it back to the original path.
Args:
path (str): The file path of the image.
height (int): The target height of the image.
width (int): The target width of the image.
"""
try:
# read image
image = Image.open(path)
# check if image is valid
if image is None:
raise ValueError("The image is invalid or empty.")
# resize image
resize_transform = transforms.Resize((width, height))
resized_image = resize_transform(image)
# save resized image back to the original path
resized_image.save(path)
except Exception as e:
print(f"Error rescaling image: {e}")
def rescale_video_by_path(path: str, height: int, width: int):
"""
Rescales an MP4 video (without audio) to the specified height and width.
Args:
path (str): The file path of the video.
height (int): The target height of the video.
width (int): The target width of the video.
"""
try:
# Read video and metadata
video, info = read_video(path, backend="av")
# Check if video is valid
if video is None or video.size(0) == 0:
raise ValueError("The video is invalid or empty.")
# Resize video frames using a performant method
resize_transform = transforms.Compose([transforms.Resize((width, height))])
resized_video = torch.stack([resize_transform(frame) for frame in video])
# Save resized video back to the original path
resized_video = resized_video.permute(0, 2, 3, 1)
write_video(path, resized_video, fps=int(info["video_fps"]), video_codec="h264")
except Exception as e:
print(f"Error rescaling video: {e}")
def save_tensor_to_disk(tensor, path, exist_handling="overwrite"):
if os.path.exists(path):
if exist_handling == "ignore":
return
elif exist_handling == "raise":
raise UserWarning(f"File {path} already exists, rewriting!")
torch.save(tensor, path)
def save_tensor_to_internet(tensor, path):
raise NotImplementedError("save_tensor_to_internet is not implemented yet!")
def save_latent(latent, path, exist_handling="overwrite"):
if path.startswith(("http://", "https://", "ftp://", "sftp://")):
save_tensor_to_internet(latent, path)
else:
save_tensor_to_disk(latent, path, exist_handling=exist_handling)
def cache_latents(latents, path, exist_handling="overwrite"):
for i in range(latents.shape[0]):
save_latent(latents[i], path[i], exist_handling=exist_handling)