# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py
import argparse
import os
import os.path as osp

import cv2
import numpy as np
import torch

from opensora.utils.ckpt_utils import download_model

from .networks.amt_g import Model
from .utils.utils import InputPadder, img2tensor, tensor2img
hf_endpoint = os.environ.get("HF_ENDPOINT")
|
|
if hf_endpoint is None:
|
|
hf_endpoint = "https://huggingface.co"
|
|
VID_EXT = [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm"]
|
|
network_cfg = {
|
|
"params": {
|
|
"corr_radius": 3,
|
|
"corr_lvls": 4,
|
|
"num_flows": 5,
|
|
},
|
|
}
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def init():
|
|
"""
|
|
initialize the device and the anchor resolution.
|
|
"""
|
|
|
|
if device == "cuda":
|
|
anchor_resolution = 1024 * 512
|
|
anchor_memory = 1500 * 1024**2
|
|
anchor_memory_bias = 2500 * 1024**2
|
|
vram_avail = torch.cuda.get_device_properties(device).total_memory
|
|
print("VRAM available: {:.1f} MB".format(vram_avail / 1024**2))
|
|
else:
|
|
# Do not resize in cpu mode
|
|
anchor_resolution = 8192 * 8192
|
|
anchor_memory = 1
|
|
anchor_memory_bias = 0
|
|
vram_avail = 1
|
|
|
|
return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail
|
|
|
|
|
|
def get_input_video_from_path(input_path):
|
|
"""
|
|
Get the input video from the input_path.
|
|
|
|
params:
|
|
input_path: str, the path of the input video.
|
|
devices: str, the device to run the model.
|
|
returns:
|
|
inputs: list, the list of the input frames.
|
|
scale: float, the scale of the input frames.
|
|
padder: InputPadder, the padder to pad the input frames.
|
|
"""
|
|
|
|
anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init()
|
|
|
|
if osp.splitext(input_path)[-1].lower() in VID_EXT:
|
|
vcap = cv2.VideoCapture(input_path)
|
|
|
|
inputs = []
|
|
w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
|
|
scale = 1 if scale > 1 else scale
|
|
scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
|
|
if scale < 1:
|
|
print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
|
|
padding = int(16 / scale)
|
|
padder = InputPadder((h, w), padding)
|
|
while True:
|
|
ret, frame = vcap.read()
|
|
if ret is False:
|
|
break
|
|
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
frame_t = img2tensor(frame).to(device)
|
|
frame_t = padder.pad(frame_t)
|
|
inputs.append(frame_t)
|
|
print(f"Loading the [video] from {input_path}, the number of frames [{len(inputs)}]")
|
|
else:
|
|
raise TypeError("Input should be a video.")
|
|
|
|
return inputs, scale, padder
|
|
|
|
|
|
def load_model(ckpt):
|
|
"""
|
|
load the frame interpolation model.
|
|
"""
|
|
params = network_cfg.get("params", {})
|
|
model = Model(**params)
|
|
model.load_state_dict(ckpt["state_dict"])
|
|
model = model.to(device)
|
|
model.eval()
|
|
return model
|
|
|
|
|
|
def interpolater(model, inputs, scale, padder, iters=1):
|
|
"""
|
|
interpolating with the interpolation model.
|
|
|
|
params:
|
|
model: nn.Module, the frame interpolation model.
|
|
inputs: list, the list of the input frames.
|
|
scale: float, the scale of the input frames.
|
|
iters: int, the number of iterations of interpolation. The final frames model generating is 2 ** iters * (m - 1) + 1 and m is input frames.
|
|
returns:
|
|
outputs: list, the list of the output frames.
|
|
"""
|
|
|
|
print("Start frame interpolation:")
|
|
embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)
|
|
|
|
for i in range(iters):
|
|
print(f"Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}")
|
|
outputs = [inputs[0]]
|
|
for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
|
|
in_0 = in_0.to(device)
|
|
in_1 = in_1.to(device)
|
|
with torch.no_grad():
|
|
imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)["imgt_pred"]
|
|
outputs += [imgt_pred.cpu(), in_1.cpu()]
|
|
inputs = outputs
|
|
|
|
outputs = padder.unpad(*outputs)
|
|
return outputs
|
|
|
|
|
|
def write(outputs, input_path, output_path, fps=30):
|
|
"""
|
|
write results to the output_path.
|
|
"""
|
|
|
|
if osp.exists(output_path) is False:
|
|
os.makedirs(output_path)
|
|
|
|
size = outputs[0].shape[2:][::-1]
|
|
|
|
_, file_name_with_extension = os.path.split(input_path)
|
|
file_name, _ = os.path.splitext(file_name_with_extension)
|
|
|
|
save_video_path = f"{output_path}/fps{fps}_{file_name}.mp4"
|
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
|
writer = cv2.VideoWriter(save_video_path, fourcc, fps, size)
|
|
|
|
for i, imgt_pred in enumerate(outputs):
|
|
imgt_pred = tensor2img(imgt_pred)
|
|
imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
|
|
writer.write(imgt_pred)
|
|
print(f"Demo video is saved to [{save_video_path}]")
|
|
|
|
writer.release()
|
|
|
|
|
|
def process(
|
|
model,
|
|
image_path,
|
|
output_path,
|
|
fps,
|
|
iters,
|
|
):
|
|
inputs, scale, padder = get_input_video_from_path(image_path)
|
|
outputs = interpolater(model, inputs, scale, padder, iters)
|
|
write(outputs, image_path, output_path, fps)
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("input", help="Input video.")
|
|
parser.add_argument("--ckpt", type=str, default="./pretrained_models/amt-g.pth", help="The pretrained model.")
|
|
parser.add_argument(
|
|
"--niters",
|
|
type=int,
|
|
default=1,
|
|
help="Iter of Interpolation. The number of frames will be double after per iter.",
|
|
)
|
|
parser.add_argument("--output_path", type=str, default="samples", help="Output path.")
|
|
parser.add_argument("--fps", type=int, default=8, help="Frames rate of the output video.")
|
|
parser.add_argument("--folder", action="store_true", help="If the input is a folder, set this flag.")
|
|
args = parser.parse_args()
|
|
|
|
times_frame = 2**args.niters
|
|
old_fps = args.fps
|
|
args.fps = args.fps * times_frame
|
|
print(f"Interpolation will turn {old_fps}fps video to {args.fps}fps video.")
|
|
args.input = os.path.expanduser(args.input)
|
|
args.ckpt = os.path.expanduser(args.ckpt)
|
|
args.folder = osp.splitext(args.input)[-1].lower() not in VID_EXT
|
|
args.ckpt = download_model(local_path=args.ckpt, url=hf_endpoint + "/lalala125/AMT/resolve/main/amt-g.pth")
|
|
return args
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args()
|
|
ckpt_path = args.ckpt
|
|
input_path = args.input
|
|
output_path = args.output_path
|
|
iters = int(args.niters)
|
|
fps = int(args.fps)
|
|
|
|
model = load_model(ckpt_path)
|
|
|
|
if args.folder:
|
|
for file in os.listdir(input_path):
|
|
if osp.splitext(file)[-1].lower() in VID_EXT:
|
|
vid_path = os.path.join(input_path, file)
|
|
process(model, vid_path, output_path, fps, iters)
|
|
else:
|
|
process(model, input_path, output_path, fps, iters)
|
|
|
|
print("Interpolation is done.")
|
|
print(f"Output path: {output_path}")
|