420 lines
14 KiB
Python
420 lines
14 KiB
Python
import math
|
|
import os
|
|
import random
|
|
import re
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import requests
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
|
from torchvision.io import write_video
|
|
from torchvision.utils import save_image
|
|
|
|
from . import video_transforms
|
|
from .read_video import read_video
|
|
|
|
try:
|
|
import dask.dataframe as dd
|
|
|
|
SUPPORT_DASK = True
|
|
except:
|
|
SUPPORT_DASK = False
|
|
|
|
# Lower-case, dot-prefixed file extensions treated as videos by `is_vid`
# and `read_from_path`.
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# Case-insensitive URL matcher used by `is_url`: (s)http/(s)ftp scheme, then a
# domain name, "localhost" or a dotted IPv4 address, an optional port and an
# optional path.
regex = re.compile(
    "".join(
        (
            r"^(?:http|ftp)s?://",  # scheme
            r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|",  # domain
            r"localhost|",  # localhost
            r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})",  # IPv4 address
            r"(?::\d+)?",  # optional port
            r"(?:/?|[/?]\S+)$",  # optional path / query
        )
    ),
    flags=re.IGNORECASE,
)
|
|
|
|
|
|
def is_img(path):
    """Return True when ``path`` ends with an image extension from torchvision's ``IMG_EXTENSIONS``.

    The comparison is case-insensitive because the suffix is lower-cased first.
    """
    _, suffix = os.path.splitext(path)
    return suffix.lower() in IMG_EXTENSIONS
|
|
|
|
|
|
def is_vid(path):
    """Return True when ``path`` ends with one of ``VID_EXTENSIONS`` (case-insensitive)."""
    _, suffix = os.path.splitext(path)
    return suffix.lower() in VID_EXTENSIONS
|
|
|
|
|
|
def is_url(url):
    """Return True if ``url`` matches the module-level URL pattern ``regex``."""
    match = regex.match(url)
    return match is not None
|
|
|
|
|
|
def read_file(input_path, memory_efficient=False):
    """Load a metadata table from a ``.csv`` or ``.parquet`` file into a pandas DataFrame.

    Args:
        input_path: Path ending in ``.csv`` or ``.parquet``.
        memory_efficient: For parquet only, load just the columns
            path/num_frames/height/width/aspect_ratio/fps/resolution.
            Must be False for CSV files.

    Returns:
        pandas.DataFrame with the file contents. Parquet is read through dask
        (then materialized with ``.compute()``) when dask is importable,
        otherwise directly with pandas.

    Raises:
        AssertionError: ``memory_efficient`` requested for a CSV file.
        NotImplementedError: Any other file extension.
    """
    if input_path.endswith(".csv"):
        assert not memory_efficient, "Memory efficient mode is not supported for CSV files"
        return pd.read_csv(input_path)

    if not input_path.endswith(".parquet"):
        raise NotImplementedError(f"Unsupported file format: {input_path}")

    wanted_columns = (
        ["path", "num_frames", "height", "width", "aspect_ratio", "fps", "resolution"]
        if memory_efficient
        else None
    )
    if SUPPORT_DASK:
        return dd.read_parquet(input_path, columns=wanted_columns).compute()
    return pd.read_parquet(input_path, columns=wanted_columns)
|
|
|
|
|
|
def download_url(input_path, output_dir="cache"):
    """Download ``input_path`` into ``output_dir`` and return the local file path.

    The local file name is the last component of the URL *path*, so any query
    string or fragment is dropped (``.../a.png?sig=1`` -> ``a.png``). Callers
    such as ``read_from_path`` can then detect the extension correctly.

    Args:
        input_path: URL to fetch.
        output_dir: Directory the file is written to (created if missing).
            Defaults to ``"cache"``, relative to the current working directory.

    Returns:
        str: Path of the downloaded file.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    from urllib.parse import urlparse

    os.makedirs(output_dir, exist_ok=True)
    # Fall back to a fixed name for URLs whose path ends in "/".
    base_name = os.path.basename(urlparse(input_path).path) or "download"
    output_path = os.path.join(output_dir, base_name)

    response = requests.get(input_path)
    # Fail loudly instead of caching an HTML error page under a media file name.
    response.raise_for_status()

    # Binary mode takes no `encoding` argument; passing one raises ValueError.
    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path
|
|
|
|
|
|
def temporal_random_crop(
    vframes: torch.Tensor, num_frames: int, frame_interval: int, return_frame_indices: bool = False
) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
    """Pick ``num_frames`` evenly spaced frames from a random temporal window.

    The window length requested from ``video_transforms.TemporalRandomCrop`` is
    ``num_frames * frame_interval``. The frames are then sampled with
    ``np.linspace`` between its first and last index, inclusive.

    Args:
        vframes: Frames indexed along dim 0.
        num_frames: Number of frames to return.
        frame_interval: Nominal stride between sampled frames.
        return_frame_indices: Also return the integer indices that were used.

    Returns:
        The selected frames, or ``(frames, indices)`` when ``return_frame_indices``.
    """
    sampler = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
    start, end = sampler(len(vframes))

    assert end - start >= num_frames, f"Not enough frames to sample, {end} - {start} < {num_frames}"

    picked = np.linspace(start, end - 1, num_frames, dtype=int)
    clip = vframes[picked]
    return (clip, picked) if return_frame_indices else clip
|
|
|
|
|
|
def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the video preprocessing pipeline selected by ``name``.

    Every pipeline has three steps: ``video_transforms.ToTensorVideo``, then a
    spatial crop chosen by ``name``, then ``Normalize`` with mean 0.5 and std 0.5
    on each of 3 channels, i.e. ``(x - 0.5) / 0.5``.

    Args:
        name: ``None`` (no transform), ``"center"``, ``"resize_crop"`` or
            ``"rand_size_crop"``.
        image_size: Size forwarded to the crop step. It must be square for
            ``"center"``, where only ``image_size[0]`` is used.

    Returns:
        A ``transforms.Compose`` pipeline, or ``None`` when ``name`` is ``None``.

    Raises:
        AssertionError: ``"center"`` with a non-square ``image_size``.
        NotImplementedError: Unrecognised ``name``.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        crop_factory, crop_arg = video_transforms.UCFCenterCropVideo, image_size[0]
    elif name == "resize_crop":
        crop_factory, crop_arg = video_transforms.ResizeCrop, image_size
    elif name == "rand_size_crop":
        crop_factory, crop_arg = video_transforms.RandomSizedCrop, image_size
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    steps = [
        video_transforms.ToTensorVideo(),  # TCHW
        crop_factory(crop_arg),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the image preprocessing pipeline selected by ``name``.

    Each pipeline has three steps: a PIL-level crop chosen by ``name``, then
    ``ToTensor``, then ``Normalize`` with mean 0.5 and std 0.5 on each of 3
    channels.

    Args:
        name: ``None`` (no transform), ``"center"`` (``center_crop_arr``),
            ``"resize_crop"`` (``resize_crop_to_fill``) or ``"rand_size_crop"``
            (``rand_size_crop_arr``).
        image_size: Size forwarded to the crop helper. It must be square for
            ``"center"``, where only ``image_size[0]`` is used.

    Returns:
        A ``transforms.Compose`` pipeline, or ``None`` when ``name`` is ``None``.

    Raises:
        AssertionError: ``"center"`` with a non-square ``image_size``.
        NotImplementedError: Unrecognised ``name``.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"

        def crop(pil_image):
            return center_crop_arr(pil_image, image_size[0])

    elif name == "resize_crop":

        def crop(pil_image):
            return resize_crop_to_fill(pil_image, image_size)

    elif name == "rand_size_crop":

        def crop(pil_image):
            return rand_size_crop_arr(pil_image, image_size)

    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    return transforms.Compose(
        [
            transforms.Lambda(crop),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
|
|
|
|
|
|
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    """Load an image and repeat it ``num_frames`` times as a static clip.

    Args:
        path: Image file path (read with torchvision's ``pil_loader``).
        transform: Callable applied to the PIL image. When None, one is built
            with ``get_transforms_image(image_size=..., name=transform_name)``.
        transform_name: Pipeline name used only when ``transform`` is None.
        num_frames: Number of times the transformed image is repeated.
        image_size: Target size used only when ``transform`` is None.

    Returns:
        The transformed image repeated along a new dim, with that dim swapped
        into position 1. This gives (C, T, H, W) assuming the transform yields
        (C, H, W).
    """
    pil_image = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    frame = transform(pil_image)
    stacked = frame.unsqueeze(0).repeat(num_frames, 1, 1, 1)  # T first
    return stacked.permute(1, 0, 2, 3)
|
|
|
|
|
|
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    """Decode a video file and return its transformed frames as (C, T, H, W).

    Frames come from ``torchvision.io.read_video`` in TCHW layout (timestamps in
    seconds); the audio track and metadata are discarded.

    Args:
        path: Video file path.
        transform: Callable applied to the TCHW frame tensor. When None, one is
            built with ``get_transforms_video(image_size=..., name=transform_name)``.
        transform_name: Pipeline name used only when ``transform`` is None.
        image_size: Target size used only when ``transform`` is None.
    """
    frames, _audio, _meta = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    processed = transform(frames)  # T C H W
    return processed.permute(1, 0, 2, 3)
|
|
|
|
|
|
def read_from_path(path, image_size, transform_name="center"):
    """Read a local file or URL as a (C, T, H, W) clip.

    URLs are first downloaded with ``download_url``. Files whose extension is in
    ``VID_EXTENSIONS`` are decoded as videos; anything else must carry an image
    extension and is turned into a single-frame clip.

    Raises:
        AssertionError: The extension is neither a video nor an image extension.
    """
    local_path = download_url(path) if is_url(path) else path
    suffix = os.path.splitext(local_path)[-1].lower()

    if suffix in VID_EXTENSIONS:
        return read_video_from_path(local_path, image_size=image_size, transform_name=transform_name)

    assert suffix in IMG_EXTENSIONS, f"Unsupported file format: {suffix}"
    return read_image_from_path(local_path, image_size=image_size, transform_name=transform_name)
|
|
|
|
|
|
def save_sample(
    x,
    save_path=None,
    fps=8,
    normalize=True,
    value_range=(-1, 1),
    force_video=False,
    verbose=True,
    crf=23,
):
    """Save a sample as a PNG image or an H.264-encoded MP4 video.

    A single-frame input (``T == 1``) is written to ``<save_path>.png`` unless
    ``force_video`` is set; every other input goes to ``<save_path>.mp4``.

    Args:
        x (Tensor): shape [C, T, H, W]. The caller's tensor is not modified.
        save_path (str): Output path without extension; ``.png`` or ``.mp4``
            is appended. Required despite the ``None`` default.
        fps (int): Frame rate of the written video.
        normalize (bool): Map ``value_range`` to [0, 1] before writing pixels.
            For PNGs this is delegated to ``save_image``. For videos values are
            clamped to ``value_range`` and rescaled; without it, values are
            expected to already lie in [0, 1].
        value_range (tuple): ``(low, high)`` used when ``normalize`` is True.
        force_video (bool): Write an MP4 even when ``T == 1``.
        verbose (bool): Print the written path.
        crf (int): Constant rate factor passed to the H.264 encoder.

    Returns:
        str: The path that was written, including its extension.

    Raises:
        ValueError: ``save_path`` is None.
    """
    assert x.ndim == 4
    if save_path is None:
        raise ValueError("save_path must be provided")

    if not force_video and x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        save_image([x.squeeze(1)], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # Out-of-place clamp creates a private copy, so the in-place ops
            # below never touch the caller's tensor.
            x = x.clamp(min=low, max=high)
            x.sub_(low).div_(max(high - low, 1e-5))

        # `mul` (not `mul_`) also protects the caller's tensor when normalize is
        # False. Adding 0.5 before the truncating uint8 cast rounds to nearest.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)

        write_video(save_path, x, fps=fps, video_codec="h264", options={"crf": str(crf)})
    if verbose:
        print(f"Saved to {save_path}")
    return save_path
|
|
|
|
|
|
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is halved with BOX filtering while its short side is at least
    ``2 * image_size``. It is then bicubically rescaled so the short side equals
    ``image_size``, and the central ``image_size`` x ``image_size`` square is
    returned as a PIL image.
    """
    img = pil_image
    while min(img.size) >= 2 * image_size:
        img = img.resize(tuple(side // 2 for side in img.size), resample=Image.BOX)

    ratio = image_size / min(img.size)
    img = img.resize(tuple(round(side * ratio) for side in img.size), resample=Image.BICUBIC)

    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top : top + image_size, left : left + image_size])
|
|
|
|
|
|
def rand_size_crop_arr(pil_image, image_size):
    """
    Randomly crop a window whose height and width are each drawn from a range.

    Args:
        pil_image: Input PIL image.
        image_size: ``(low, high)``, the inclusive bounds for the random target
            height and width (note: this is a range, not an (H, W) pair).

    Returns:
        PIL.Image: The cropped region. Both target sides are rounded down to a
        multiple of 8. If the image is smaller than the target along an axis,
        the crop is limited to the image extent on that axis.
    """
    arr = np.array(pil_image)
    img_h, img_w = arr.shape[:2]

    # get random target h w
    height = random.randint(image_size[0], image_size[1])
    width = random.randint(image_size[0], image_size[1])
    # round h w down to multiples of 8
    height -= height % 8
    width -= width % 8

    # get random start pos; each offset is bounded by the crop size on the SAME
    # axis (the horizontal one previously used `height`, which could shift the
    # window past the right edge and return a narrower crop than requested).
    h_start = random.randint(0, max(img_h - height, 0))
    w_start = random.randint(0, max(img_w - width, 0))

    # crop
    return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width])
|
|
|
|
|
|
def resize_crop_to_fill(pil_image, image_size):
    """Resize so the image covers ``image_size`` = (H, W), then center-crop the overflow.

    The image is bicubically scaled by the larger of the two per-axis ratios, so
    one side matches the target exactly and the other is at least as large. The
    overflowing axis is then cropped around its center.
    """
    src_w, src_h = pil_image.size  # PIL reports (W, H)
    dst_h, dst_w = image_size
    scale_h, scale_w = dst_h / src_h, dst_w / src_w

    if scale_h > scale_w:
        new_h, new_w = dst_h, round(src_w * scale_h)
        resized = pil_image.resize((new_w, new_h), Image.BICUBIC)
        top, left = 0, int(round((new_w - dst_w) / 2.0))
    else:
        new_h, new_w = round(src_h * scale_w), dst_w
        resized = pil_image.resize((new_w, new_h), Image.BICUBIC)
        top, left = int(round((new_h - dst_h) / 2.0)), 0

    pixels = np.array(resized)
    assert top + dst_h <= pixels.shape[0] and left + dst_w <= pixels.shape[1]
    return Image.fromarray(pixels[top : top + dst_h, left : left + dst_w])
|
|
|
|
|
|
def map_target_fps(
    fps: float,
    max_fps: float,
) -> tuple[float, int]:
    """
    Map fps to a new fps that does not exceed max_fps by subsampling frames.

    Args:
        fps (float): Original fps.
        max_fps (float): Maximum fps.

    Returns:
        tuple[float, int]: New fps and sampling interval. NaN fps maps to
        ``(0, 1)``. An fps below ``max_fps`` is returned unchanged with interval
        1. Otherwise the interval is ``ceil(fps / max_fps)`` and the new fps is
        ``floor(fps / interval)``.
    """
    if math.isnan(fps):
        # Unknown frame rate: report 0 and keep every frame.
        return 0, 1
    if fps < max_fps:
        return fps, 1
    step = math.ceil(fps / max_fps)
    return math.floor(fps / step), step
|
|
|
|
|
|
def sync_object_across_devices(obj: Any, rank: int = 0):
    """
    Broadcast a picklable object from one rank to every rank in the process group.

    Uses ``torch.distributed.broadcast_object_list`` with ``device="cuda"``.

    Parameters:
        obj (Any): Object to share. Only the value held by ``rank`` survives;
            every other rank's argument is replaced.
        rank (int): Source rank of the broadcast. Default is 0.

    Returns:
        The object as held by ``rank``.

    Note: torch.distributed must be initialized and CUDA available before calling.
    """
    payload = [obj]  # broadcast_object_list fills the list in place
    dist.broadcast_object_list(payload, src=rank, device="cuda")
    return payload[0]
|
|
|
|
|
|
def rescale_image_by_path(path: str, height: int, width: int):
    """
    Rescale an image to the given height and width and overwrite the original file.

    Best-effort: any exception is printed and swallowed.

    Args:
        path (str): The file path of the image.
        height (int): The target height of the image.
        width (int): The target width of the image.
    """
    try:
        # The context manager releases the source file handle before the same
        # path is written back. Image.open raises rather than returning None,
        # so no separate validity check is needed.
        with Image.open(path) as image:
            # torchvision's Resize takes (height, width); passing (width, height)
            # transposed the output size for non-square targets.
            resized_image = transforms.Resize((height, width))(image)

        # save resized image back to the original path
        resized_image.save(path)
    except Exception as e:
        print(f"Error rescaling image: {e}")
|
|
|
|
|
|
def rescale_video_by_path(path: str, height: int, width: int):
    """
    Rescale a video to the given height and width and overwrite it in place.

    The file is re-encoded with H.264 and no audio track is written.
    Best-effort: any exception is printed and swallowed.

    Args:
        path (str): The file path of the video.
        height (int): The target height of the video.
        width (int): The target width of the video.
    """
    try:
        # Read video and metadata
        video, info = read_video(path, backend="av")

        # Check if video is valid
        if video is None or video.size(0) == 0:
            raise ValueError("The video is invalid or empty.")

        # Resize takes (height, width) and works on the last two dims of a
        # batched tensor, so all frames are resized in one call. This assumes
        # read_video yields (T, C, H, W) frames; the permute below relies on it
        # too (TODO confirm against .read_video).
        resized_video = transforms.Resize((height, width))(video)

        # write_video expects (T, H, W, C)
        resized_video = resized_video.permute(0, 2, 3, 1)
        # Round rather than truncate the frame rate (29.97 -> 30, not 29) so
        # the duration of the rewritten clip stays close to the original.
        fps = int(round(info["video_fps"]))
        write_video(path, resized_video, fps=fps, video_codec="h264")
    except Exception as e:
        print(f"Error rescaling video: {e}")
|
|
|
|
|
|
def save_tensor_to_disk(tensor, path, exist_handling="overwrite"):
    """Serialize ``tensor`` to ``path`` with ``torch.save``.

    Args:
        tensor: Anything ``torch.save`` accepts.
        path: Destination file path.
        exist_handling: Policy when ``path`` already exists. ``"ignore"``
            returns without writing. ``"raise"`` raises ``UserWarning`` without
            writing. Any other value, including the default ``"overwrite"``,
            overwrites the file.

    Raises:
        UserWarning: ``path`` exists and ``exist_handling == "raise"``.
    """
    if os.path.exists(path):
        if exist_handling == "ignore":
            return
        if exist_handling == "raise":
            # Nothing is written on this path, so the message must not claim a rewrite.
            raise UserWarning(f"File {path} already exists, refusing to overwrite!")
    torch.save(tensor, path)
|
|
|
|
|
|
def save_tensor_to_internet(tensor, path):
    """Placeholder for uploading a tensor to a remote location; it always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "save_tensor_to_internet is not implemented yet!"
    raise NotImplementedError(message)
|
|
|
|
|
|
def save_latent(latent, path, exist_handling="overwrite"):
    """Save ``latent`` remotely or to disk, depending on the scheme of ``path``.

    Paths starting with http://, https://, ftp:// or sftp:// go to
    ``save_tensor_to_internet``. Everything else goes to
    ``save_tensor_to_disk`` with the given ``exist_handling``.
    """
    remote_prefixes = ("http://", "https://", "ftp://", "sftp://")
    if not path.startswith(remote_prefixes):
        save_tensor_to_disk(latent, path, exist_handling=exist_handling)
        return
    save_tensor_to_internet(latent, path)
|
|
|
|
|
|
def cache_latents(latents, path, exist_handling="overwrite"):
    """Save each entry along dim 0 of ``latents`` to the path at the same index of ``path``.

    Args:
        latents: Batched latents indexed along dim 0.
        path: Sequence of destination paths; needs at least ``latents.shape[0]`` entries.
        exist_handling: Forwarded to ``save_latent``.
    """
    for i in range(latents.shape[0]):
        latent = latents[i]
        if isinstance(latent, torch.Tensor):
            # `latents[i]` is a view. torch.save serializes a view's whole
            # underlying storage, which would write the entire batch into every
            # per-sample file; cloning stores only this sample.
            latent = latent.clone()
        save_latent(latent, path[i], exist_handling=exist_handling)
|