313 lines
13 KiB
Python
313 lines
13 KiB
Python
import collections
|
|
import os
|
|
import urllib
|
|
|
|
import pytest
|
|
import torch
|
|
import torchvision
|
|
from pytest import approx
|
|
from torchvision.datasets.utils import download_url
|
|
from torchvision.io import _HAS_CPU_VIDEO_DECODER, VideoReader
|
|
|
|
|
|
# WARNING: these tests have been skipped forever on the CI because the video ops
|
|
# are never properly available. This is bad, but things have been in a terrible
|
|
# state for a long time already as we write this comment, and we'll hopefully be
|
|
# able to get rid of this all soon.
|
|
|
|
|
|
try:
    import av

    # Importing is not enough on its own: also run torchvision's PyAV version
    # check, which raises ImportError for a build it cannot use.
    torchvision.io.video._check_av_available()
except ImportError:
    # Sentinel consulted by ``backends()`` and by the ``skipif`` markers.
    av = None


# Directory holding the sample clips listed in ``test_videos``.
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

# Field names of the per-clip expected metadata record.
CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
GroundTruth = collections.namedtuple("GroundTruth", CheckerConfig)
|
|
|
|
|
|
def backends():
    """Return the list of video backends to parametrize tests over.

    ``"video_reader"`` is always present; ``"pyav"`` is appended only when the
    module-level PyAV import succeeded (i.e. ``av`` is not ``None``).
    """
    optional = [] if av is None else ["pyav"]
    return ["video_reader", *optional]
|
|
|
|
|
|
def fate(name, path="."):
    """Download and return a path to a sample from the FFmpeg test suite.

    See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_

    Args:
        name: path of the sample relative to the FATE suite root, using ``/``
            as separator, e.g. ``"sub/MovText_capability_tester.mp4"``.
        path: local directory the sample is downloaded into.

    Returns:
        ``os.path.join(path, <file name of the sample>)``.
    """
    # Keep only the final component of ``name`` as the local file name, so that
    # nested names ("a/b/c.mp4") and top-level names ("c.mp4") both work;
    # indexing ``split("/")[1]`` would pick the wrong component or raise.
    file_name = name.rsplit("/", 1)[-1]
    download_url("http://fate.ffmpeg.org/fate-suite/" + name, path, file_name)
    return os.path.join(path, file_name)
|
|
|
|
|
|
# Expected metadata per clip in VIDEO_DIR, as rows of
# (file name, duration in seconds, video fps, audio sample rate or None).
test_videos = {
    file_name: GroundTruth(duration=duration, video_fps=fps, audio_sample_rate=sample_rate)
    for file_name, duration, fps, sample_rate in (
        ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", 2.0, 30.0, None),
        ("SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi", 2.0, 30.0, None),
        ("TrumanShow_wave_f_nm_np1_fr_med_26.avi", 2.0, 30.0, None),
        ("v_SoccerJuggling_g23_c01.avi", 8.0, 29.97, None),
        ("v_SoccerJuggling_g24_c01.avi", 8.0, 29.97, None),
        ("R6llTwEh07w.mp4", 10.0, 30.0, 44100),
        ("SOX5yA1l24A.mp4", 11.0, 29.97, 48000),
        ("WUzgd7C1pWA.mp4", 11.0, 29.97, 48000),
    )
}
|
|
|
|
|
|
@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
    """Tests for ``VideoReader``: decoding vs PyAV, in-memory input, metadata, seeking."""

    @staticmethod
    def _collect_frames(video_reader):
        """Drain ``video_reader`` and return its ``(data, pts)`` lists in iteration order."""
        frames, pts = [], []
        for frame in video_reader:
            frames.append(frame["data"])
            pts.append(frame["pts"])
        return frames, pts

    def _check_video_against_pyav(self, full_path):
        """Compare every video frame and pts from VideoReader with PyAV's decode of ``full_path``."""
        with av.open(full_path) as av_reader:
            if not av_reader.streams.video:
                return
            av_frames, av_pts = [], []
            for av_frame in av_reader.decode(av_reader.streams.video[0]):
                # permute(2, 0, 1) moves the last (presumably channel) axis of PyAV's
                # RGB ndarray first, to match VideoReader's frame layout.
                av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
                av_pts.append(av_frame.pts * av_frame.time_base)

        vr_frames, vr_pts = self._collect_frames(VideoReader(full_path, "video"))

        # same number of frames
        assert len(vr_frames) == len(av_frames)
        assert len(vr_pts) == len(av_pts)

        for av_ts, vr_ts, av_img, vr_img in zip(av_pts, vr_pts, av_frames, vr_frames):
            assert float(av_ts) == approx(vr_ts, abs=0.1)
            # The decoders differ slightly; require the mean absolute pixel
            # difference to stay below 1% of the 0-255 range (2.55).
            # TODO: assess empirically how to set this threshold.
            mean_delta = torch.mean(torch.abs(av_img.float() - vr_img.float()))
            assert mean_delta.item() < 2.55

    def _check_audio_against_pyav(self, full_path):
        """Compare every audio frame and pts from VideoReader with PyAV's decode of ``full_path``."""
        with av.open(full_path) as av_reader:
            if not av_reader.streams.audio:
                return
            av_frames, av_pts = [], []
            for av_frame in av_reader.decode(av_reader.streams.audio[0]):
                # Transpose PyAV's 2-D array so its axes line up with VideoReader's audio frames.
                av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
                av_pts.append(av_frame.pts * av_frame.time_base)

        vr_frames, vr_pts = self._collect_frames(VideoReader(full_path, "audio"))

        # same number of frames
        assert len(vr_frames) == len(av_frames)
        assert len(vr_pts) == len(av_pts)

        for av_ts, vr_ts, av_samples, vr_samples in zip(av_pts, vr_pts, av_frames, vr_frames):
            assert float(av_ts) == approx(vr_ts, abs=0.1)
            # Audio must match almost exactly: the largest absolute sample
            # difference must stay below 1e-3 (0.1% of full scale, assuming float
            # samples normalized to [-1, 1] -- TODO confirm).
            max_delta = torch.max(torch.abs(av_samples.float() - vr_samples.float()))
            assert max_delta.item() < 0.001

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_frame_reading(self, test_video, backend):
        """Decoded video and audio frames must agree with PyAV's decode of the same file."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        self._check_video_against_pyav(full_path)
        # test audio reading compared to PYAV
        self._check_audio_against_pyav(full_path)

    @pytest.mark.parametrize("stream", ["video", "audio"])
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
        """Reading from a file path and from the file's bytes must yield the same frames."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)

        # Nothing to compare when the clip lacks this stream.
        if stream not in VideoReader(full_path).get_metadata():
            return

        vr_frames, vr_pts = self._collect_frames(VideoReader(full_path, stream))

        with open(full_path, "rb") as f:
            fbytes = f.read()
        vr_frames_mem, vr_pts_mem = self._collect_frames(VideoReader(fbytes, stream))

        # same number of frames
        assert len(vr_frames) == len(vr_frames_mem)
        assert len(vr_pts) == len(vr_pts_mem)

        for pts, pts_mem, frame, frame_mem in zip(vr_pts, vr_pts_mem, vr_frames, vr_frames_mem):
            assert pts == pts_mem
            # Mean absolute difference below 1% of the 0-255 range (2.55).
            # TODO: assess empirically how to set this threshold.
            mean_delta = torch.mean(torch.abs(frame.float() - frame_mem.float()))
            assert mean_delta.item() < 2.55

    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("backend", backends())
    def test_metadata(self, test_video, config, backend):
        """
        Test that the metadata returned via pyav corresponds to the one returned
        by the new video decoder API
        """
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        reader = VideoReader(full_path, "video")
        reader_md = reader.get_metadata()
        assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
        assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)

    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_seek_start(self, test_video, backend):
        """Seeking to 0 (or a negative time) must replay the whole stream."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        video_reader = VideoReader(full_path, "video")
        num_frames = sum(1 for _ in video_reader)

        # Seeking to the start is often imprecise and may not land at 0;
        # make sure every frame is produced again.
        video_reader.seek(0)
        assert sum(1 for _ in video_reader) == num_frames

        # A negative timestamp must behave like seeking to the start.
        video_reader.seek(-1)
        assert sum(1 for _ in video_reader) == num_frames

    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", ["video_reader"])
    def test_accurateseek_middle(self, test_video, backend):
        """Seeking to half the duration must land within one frame of the target."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        stream = "video"
        video_reader = VideoReader(full_path, stream)
        md = video_reader.get_metadata()
        duration = md[stream]["duration"][0]
        if duration is None:
            return

        num_frames = sum(1 for _ in video_reader)

        # Roughly half of the frames remain after seeking to the midpoint.
        video_reader.seek(duration / 2)
        middle_num_frames = sum(1 for _ in video_reader)
        assert middle_num_frames < num_frames
        assert middle_num_frames == approx(num_frames // 2, abs=1)

        # The first frame after the seek is at most one frame period away.
        video_reader.seek(duration / 2)
        frame = next(video_reader)
        frame_period = 1 / md[stream]["fps"][0]
        assert duration / 2 - frame_period <= frame["pts"] <= duration / 2 + frame_period

    def test_fate_suite(self):
        """A FATE sample with a subtitle stream must report a subtitle duration."""
        # ``import urllib`` alone does not guarantee that the ``urllib.error``
        # submodule is loaded; import it explicitly for the except clause.
        import urllib.error

        # TODO: remove the try-except statement once the connectivity issues are resolved
        try:
            video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
        except (urllib.error.URLError, ConnectionError) as error:
            pytest.skip(f"Skipping due to connectivity issues: {error}")

        try:
            vr = VideoReader(video_path)
            metadata = vr.get_metadata()
            assert metadata["subtitles"]["duration"] is not None
        finally:
            # Remove the download even when the assertion fails, so it does not
            # linger in the assets directory.
            os.remove(video_path)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("backend", backends())
    def test_keyframe_reading(self, test_video, config, backend):
        """Keyframe-only seeks must land on the keyframes PyAV reports."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)

        # Collect keyframe timestamps with PyAV by making its decoder skip
        # non-key frames. Check the stream exists *before* indexing it, and
        # close the container when done.
        with av.open(full_path) as av_reader:
            if not av_reader.streams.video:
                return
            av_stream = av_reader.streams.video[0]
            av_stream.codec_context.skip_frame = "NONKEY"
            av_keyframes = [float(av_frame.pts * av_frame.time_base) for av_frame in av_reader.decode(av_stream)]

        if len(av_keyframes) <= 1:
            return

        # Seek (second argument True, presumably keyframes-only) halfway between
        # consecutive keyframes; the returned frame is expected to be the earlier one.
        video_reader = VideoReader(full_path, "video")
        vr_keyframes = []
        for prev_kf, cur_kf in zip(av_keyframes, av_keyframes[1:]):
            data = next(video_reader.seek((cur_kf + prev_kf) / 2, True))
            vr_keyframes.append(data["pts"])

        # Seeking to the end is expected to return the last keyframe.
        data = next(video_reader.seek(config.duration, True))
        vr_keyframes.append(data["pts"])

        assert len(av_keyframes) == len(vr_keyframes)
        # NOTE: this video gets different keyframe with different
        # loaders (0.333 pyav, 0.666 for us)
        if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
            for av_kf, vr_kf in zip(av_keyframes, vr_keyframes):
                assert av_kf == approx(vr_kf, rel=0.001)

    def test_src(self):
        """Invalid ``src`` values and unknown keyword arguments are rejected."""
        with pytest.raises(ValueError, match="src cannot be empty"):
            VideoReader(src="")
        with pytest.raises(ValueError, match="src must be either string"):
            VideoReader(src=2)
        with pytest.raises(TypeError, match="unexpected keyword argument"):
            VideoReader(path="path")
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly hands this file to pytest.
    pytest.main(args=[__file__])
|