import gc
import math
import os
import re
import warnings
from fractions import Fraction

import av
import cv2
import numpy as np
import torch
from torchvision import get_video_backend
from torchvision.io.video import _check_av_available

# Fallback frame-count cap, used when the container reports 0 frames.
MAX_NUM_FRAMES = 2500


def read_video_av(
    filename: str,
    start_pts: float | Fraction = 0,
    end_pts: float | Fraction | None = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """
    Reads a video from a file, returning both the video frames and the audio frames.

    This method is modified from torchvision.io.video.read_video, with the following changes:

    1. will not extract audio frames and return empty for aframes
    2. remove checks and only support pyav
    3. add container.close() and gc.collect() to avoid thread leakage
    4. try our best to avoid memory leak

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors.
            Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels
            and `L` is the number of points. Always empty here since audio is not decoded.
        info (dict): metadata for the video and audio. Can contain the fields
            video_fps (float) and audio_fps (int).
    """
    # -- validate arguments --
    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
    # file existence
    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")
    # backend check
    assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
    _check_av_available()
    # end_pts check
    if end_pts is None:
        end_pts = float("inf")
    if end_pts < start_pts:
        raise ValueError(
            f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
        )

    # == get video info ==
    info = {}
    # TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU)
    container = av.open(filename, metadata_errors="ignore")
    # fps
    video_fps = container.streams.video[0].average_rate
    # guard against potentially corrupted files
    if video_fps is not None:
        info["video_fps"] = float(video_fps)
    # decode a single frame only to learn the spatial dimensions
    iter_video = container.decode(**{"video": 0})
    frame = next(iter_video).to_rgb().to_ndarray()
    height, width = frame.shape[:2]
    total_frames = container.streams.video[0].frames
    if total_frames == 0:
        total_frames = MAX_NUM_FRAMES
        warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
    container.close()
    del container

    # HACK: must create before iterating stream
    # use np.zeros will not actually allocate memory
    # use np.ones will lead to a little memory leak
    video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)

    # == read ==
    try:
        # TODO: The reading has memory leak (4G for 8 workers 1 GPU)
        container = av.open(filename, metadata_errors="ignore")
        assert container.streams.video is not None
        video_frames = _read_from_stream(
            video_frames,
            container,
            start_pts,
            end_pts,
            pts_unit,
            container.streams.video[0],
            {"video": 0},
            filename=filename,
        )
    except av.AVError as e:
        print(f"[Warning] Error while reading video {filename}: {e}")

    # clone so the returned tensor owns its memory; the big numpy buffer can then be freed
    vframes = torch.from_numpy(video_frames).clone()
    del video_frames
    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)
    # audio is deliberately not decoded (see docstring); return an empty placeholder
    aframes = torch.empty((1, 0), dtype=torch.float32)
    return vframes, aframes, info


def _read_from_stream(
    video_frames: np.ndarray,
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: dict[str, int | tuple[int, ...] | list[int] | None],
    filename: str | None = None,
) -> np.ndarray:
    """
    Decode frames from `stream` into the pre-allocated `video_frames` buffer and
    return the `[start_offset, end_offset]` slice of it (a copy).

    `video_frames` must be a (T, H, W, 3) uint8 array; decoding stops when it is
    full. The container is always closed before returning.
    """
    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"

    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError as e:
        print(f"[Warning] Error while seeking video {filename}: {e}")
        # return an empty (0, H, W, 3) slice so the caller's torch.from_numpy still works
        return video_frames[:0].copy()

    # == main ==
    buffer_count = 0
    frames_pts = []
    cnt = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames_pts.append(frame.pts)
            video_frames[cnt] = frame.to_rgb().to_ndarray()
            cnt += 1
            if cnt >= len(video_frames):
                break
            if frame.pts >= end_offset:
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError as e:
        print(f"[Warning] Error while reading video {filename}: {e}")

    # garbage collection for thread leakage
    container.close()
    del container
    # NOTE: manually garbage collect to close pyav threads
    gc.collect()

    # ensure that the results are sorted wrt the pts
    # NOTE: here we assert frames_pts is sorted
    start_ptr = 0
    end_ptr = cnt
    while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
        start_ptr += 1
    while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
        end_ptr -= 1
    if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        if start_ptr > 0:
            start_ptr -= 1
    result = video_frames[start_ptr:end_ptr].copy()
    return result


def read_video_cv2(video_path):
    """
    Read an entire video with OpenCV.

    Args:
        video_path (str): path to the video file.

    Returns:
        frames (Tensor[T, C, H, W]): all decoded frames, RGB, uint8.
        vinfo (dict): metadata; contains "video_fps" (float).

    Raises:
        ValueError: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    vinfo = {
        "video_fps": fps,
    }
    frames = []
    while True:
        # Read a frame from the video
        ret, frame = cap.read()
        # If frame is not read correctly, break the loop
        if not ret:
            break
        frames.append(frame[:, :, ::-1])  # BGR to RGB
        # NOTE(review): waitKey/'q' handling looks like a leftover from display code;
        # kept to preserve behavior when a highgui window happens to exist
        if cv2.waitKey(25) & 0xFF == ord("q"):
            break
    # Release the video capture object and close all windows
    cap.release()
    cv2.destroyAllWindows()
    frames = np.stack(frames)
    frames = torch.from_numpy(frames)  # [T, H, W, C=3]
    frames = frames.permute(0, 3, 1, 2)
    return frames, vinfo


def read_video(video_path, backend="av"):
    """
    Read a video with the chosen backend.

    Args:
        video_path (str): path to the video file.
        backend (str, optional): "av" (PyAV, default) or "cv2" (OpenCV).

    Returns:
        vframes (Tensor[T, C, H, W]): the decoded video frames.
        vinfo (dict): metadata; may contain "video_fps".

    Raises:
        ValueError: if `backend` is not "av" or "cv2".
    """
    if backend == "cv2":
        vframes, vinfo = read_video_cv2(video_path)
    elif backend == "av":
        vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
    else:
        raise ValueError(f"Unsupported backend: {backend}")
    return vframes, vinfo