sglang_v0.5.2/vision_0.23.0/test/test_videoapi.py

313 lines
13 KiB
Python

import collections
import os
import urllib
import pytest
import torch
import torchvision
from pytest import approx
from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_CPU_VIDEO_DECODER, VideoReader
# WARNING: these tests have been skipped forever on the CI because the video ops
# are never properly available. This is bad, but things have been in a terrible
# state for a long time already as we write this comment, and we'll hopefully be
# able to get rid of this all soon.
try:
    import av
except ImportError:
    # PyAV is not installed: tests guarded by `av is None` are skipped.
    av = None
else:
    # PyAV imported; also run torchvision's own availability (version) check and
    # treat its ImportError exactly like PyAV being missing.
    try:
        torchvision.io.video._check_av_available()
    except ImportError:
        av = None
# Directory (next to this file) holding the video assets exercised by the tests.
VIDEO_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
    "videos",
)
# Field names of the expected-metadata record attached to each test video.
CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
# namedtuple accepts the field names as a sequence directly.
GroundTruth = collections.namedtuple("GroundTruth", CheckerConfig)
def backends():
    """List the video backends that tests are parametrized over.

    ``"video_reader"`` is always included; ``"pyav"`` is appended only when the
    optional PyAV import at module level succeeded (``av`` is not ``None``).
    """
    if av is None:
        return ["video_reader"]
    return ["video_reader", "pyav"]
def fate(name, path="."):
    """Download and return a path to a sample from the FFmpeg test suite.

    See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_

    Args:
        name: Path of the sample relative to the FATE suite root, e.g.
            ``"sub/MovText_capability_tester.mp4"``.
        path: Local directory to download the sample into.

    Returns:
        ``os.path.join(path, file_name)``, where ``file_name`` is the last
        ``/``-separated component of ``name``.
    """
    # Take the LAST path component: `split("/")[1]` returned a directory name for
    # nested samples ("a/b/c.mp4") and raised IndexError for names without "/".
    file_name = name.rsplit("/", 1)[-1]
    download_url("http://fate.ffmpeg.org/fate-suite/" + name, path, file_name)
    return os.path.join(path, file_name)
# Expected metadata for every asset in VIDEO_DIR, listed as
# (file name, duration, video_fps, audio_sample_rate); insertion order is kept.
test_videos = {
    file_name: GroundTruth(duration=duration, video_fps=video_fps, audio_sample_rate=audio_sample_rate)
    for file_name, duration, video_fps, audio_sample_rate in (
        ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", 2.0, 30.0, None),
        ("SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi", 2.0, 30.0, None),
        ("TrumanShow_wave_f_nm_np1_fr_med_26.avi", 2.0, 30.0, None),
        ("v_SoccerJuggling_g23_c01.avi", 8.0, 29.97, None),
        ("v_SoccerJuggling_g24_c01.avi", 8.0, 29.97, None),
        ("R6llTwEh07w.mp4", 10.0, 30.0, 44100),
        ("SOX5yA1l24A.mp4", 11.0, 29.97, 48000),
        ("WUzgd7C1pWA.mp4", 11.0, 29.97, 48000),
    )
}
@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
    """Tests for ``torchvision.io.VideoReader``.

    Decoded frames, timestamps and metadata are cross-checked against PyAV
    (when it is importable), against decoding the same file from an in-memory
    ``bytes`` buffer, and against the expectations stored in ``test_videos``.
    """

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_frame_reading(self, test_video, backend):
        """Video and audio frames and pts decoded by VideoReader must match PyAV's."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        with av.open(full_path) as av_reader:
            if av_reader.streams.video:
                av_frames, vr_frames = [], []
                av_pts, vr_pts = [], []
                # get av frames; permute(2, 0, 1) moves the last axis of PyAV's RGB
                # array to the front so it can be compared element-wise below
                for av_frame in av_reader.decode(av_reader.streams.video[0]):
                    av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
                    av_pts.append(av_frame.pts * av_frame.time_base)
                # get vr frames
                video_reader = VideoReader(full_path, "video")
                for vr_frame in video_reader:
                    vr_frames.append(vr_frame["data"])
                    vr_pts.append(vr_frame["pts"])
                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)
                # compare the frames and pts (lengths are equal, so zip drops nothing)
                for av_data, vr_data, av_ts, vr_ts in zip(av_frames, vr_frames, av_pts, vr_pts):
                    assert float(av_ts) == approx(vr_ts, abs=0.1)
                    mean_delta = torch.mean(torch.abs(av_data.float() - vr_data.float()))
                    # On average the difference is very small and caused by decoding.
                    # The threshold 2.55 is 1% of 255.
                    # TODO: assess empirically how to set this threshold.
                    assert mean_delta.item() < 2.55
                # release the decoded video before decoding the audio
                del vr_frames, av_frames, vr_pts, av_pts
        # test audio reading compared to PyAV
        with av.open(full_path) as av_reader:
            if av_reader.streams.audio:
                av_frames, vr_frames = [], []
                av_pts, vr_pts = [], []
                # get av frames
                for av_frame in av_reader.decode(av_reader.streams.audio[0]):
                    av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
                    av_pts.append(av_frame.pts * av_frame.time_base)
                # get vr frames
                video_reader = VideoReader(full_path, "audio")
                for vr_frame in video_reader:
                    vr_frames.append(vr_frame["data"])
                    vr_pts.append(vr_frame["pts"])
                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)
                # compare the frames and pts
                for av_data, vr_data, av_ts, vr_ts in zip(av_frames, vr_frames, av_pts, vr_pts):
                    assert float(av_ts) == approx(vr_ts, abs=0.1)
                    max_delta = torch.max(torch.abs(av_data.float() - vr_data.float()))
                    # every sample must agree to within an absolute difference of 1e-3
                    assert max_delta.item() < 0.001

    @pytest.mark.parametrize("stream", ["video", "audio"])
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
        """Decoding ``stream`` from a file path and from in-memory bytes must agree."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        reader_md = VideoReader(full_path).get_metadata()
        if stream not in reader_md:
            return
        # get vr frames read from the file
        vr_frames, vr_pts = [], []
        for vr_frame in VideoReader(full_path, stream):
            vr_frames.append(vr_frame["data"])
            vr_pts.append(vr_frame["pts"])
        # get vr frames read from memory
        with open(full_path, "rb") as f:
            fbytes = f.read()
        vr_frames_mem, vr_pts_mem = [], []
        for vr_frame_from_mem in VideoReader(fbytes, stream):
            vr_frames_mem.append(vr_frame_from_mem["data"])
            vr_pts_mem.append(vr_frame_from_mem["pts"])
        # same number of frames
        assert len(vr_frames) == len(vr_frames_mem)
        assert len(vr_pts) == len(vr_pts_mem)
        # compare the frames and pts
        for file_data, mem_data, file_ts, mem_ts in zip(vr_frames, vr_frames_mem, vr_pts, vr_pts_mem):
            assert file_ts == mem_ts
            mean_delta = torch.mean(torch.abs(file_data.float() - mem_data.float()))
            # The threshold 2.55 is 1% of 255, averaged over the frame.
            # TODO: assess empirically how to set this threshold.
            assert mean_delta.item() < 2.55

    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("backend", backends())
    def test_metadata(self, test_video, config, backend):
        """
        Test that the video fps and duration reported by VideoReader's metadata
        match the expected values recorded in ``test_videos``.
        """
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        reader = VideoReader(full_path, "video")
        reader_md = reader.get_metadata()
        assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
        assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)

    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_seek_start(self, test_video, backend):
        """Seeking to 0 or to a negative time must replay the whole stream."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        video_reader = VideoReader(full_path, "video")
        num_frames = sum(1 for _ in video_reader)
        # Now seek the container to 0 and iterate again: seeking to the start
        # is often imprecise and may not land exactly at 0.
        video_reader.seek(0)
        assert sum(1 for _ in video_reader) == num_frames
        # now seek the container to < 0 to check for unexpected behaviour
        video_reader.seek(-1)
        assert sum(1 for _ in video_reader) == num_frames

    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", ["video_reader"])
    def test_accurateseek_middle(self, test_video, backend):
        """Seeking to half the duration must yield about half the frames, near that time."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        stream = "video"
        video_reader = VideoReader(full_path, stream)
        md = video_reader.get_metadata()
        duration = md[stream]["duration"][0]
        if duration is None:
            return
        num_frames = sum(1 for _ in video_reader)
        video_reader.seek(duration / 2)
        middle_num_frames = sum(1 for _ in video_reader)
        assert middle_num_frames < num_frames
        assert middle_num_frames == approx(num_frames // 2, abs=1)
        video_reader.seek(duration / 2)
        frame = next(video_reader)
        # the first frame after the seek must be within one frame period of the target
        frame_period = 1 / md[stream]["fps"][0]
        lb = duration / 2 - frame_period
        ub = duration / 2 + frame_period
        assert lb <= frame["pts"] <= ub

    def test_fate_suite(self):
        """Metadata of a downloaded FATE subtitle sample must carry a subtitles duration."""
        # Import the submodule explicitly: `import urllib` alone does not
        # guarantee that `urllib.error` has been loaded.
        import urllib.error

        # TODO: remove the try-except statement once the connectivity issues are resolved
        try:
            video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
        except (urllib.error.URLError, ConnectionError) as error:
            pytest.skip(f"Skipping due to connectivity issues: {error}")
        try:
            vr = VideoReader(video_path)
            metadata = vr.get_metadata()
            assert metadata["subtitles"]["duration"] is not None
        finally:
            # remove the downloaded sample even when the assertion fails
            os.remove(video_path)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("backend", backends())
    def test_keyframe_reading(self, test_video, config, backend):
        """Keyframe-only seeks in VideoReader must land on the keyframes PyAV reports."""
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        # Collect all keyframe timestamps with PyAV; the container is closed on exit.
        with av.open(full_path) as av_reader:
            # Check for a video stream BEFORE indexing it (the original indexed first).
            if not av_reader.streams.video:
                return
            # reduce the stream to only keyframes
            av_stream = av_reader.streams.video[0]
            av_stream.codec_context.skip_frame = "NONKEY"
            av_keyframes = [float(av_frame.pts * av_frame.time_base) for av_frame in av_reader.decode(av_stream)]
        if len(av_keyframes) > 1:
            # Seek (keyframes only) to the midpoint between each pair of consecutive
            # keyframes; the result is expected to match the earlier keyframe of the
            # pair. Seeking to the duration then covers the last keyframe.
            video_reader = VideoReader(full_path, "video")
            vr_keyframes = []
            for prev_kf, next_kf in zip(av_keyframes, av_keyframes[1:]):
                seek_val = (next_kf + prev_kf) / 2
                data = next(video_reader.seek(seek_val, True))
                vr_keyframes.append(data["pts"])
            data = next(video_reader.seek(config.duration, True))
            vr_keyframes.append(data["pts"])
            assert len(av_keyframes) == len(vr_keyframes)
            # NOTE: this video gets different keyframe with different
            # loaders (0.333 pyav, 0.666 for us)
            if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
                for av_kf, vr_kf in zip(av_keyframes, vr_keyframes):
                    assert av_kf == approx(vr_kf, rel=0.001)

    def test_src(self):
        """Invalid ``src`` arguments must raise informative errors."""
        with pytest.raises(ValueError, match="src cannot be empty"):
            VideoReader(src="")
        with pytest.raises(ValueError, match="src must be either string"):
            VideoReader(src=2)
        with pytest.raises(TypeError, match="unexpected keyword argument"):
            VideoReader(path="path")
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly reports
    # test failures to the calling shell / CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))