diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index 8c5b53fd281..4ca2b3a03f7 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -34,11 +34,6 @@ conda activate ci conda install --quiet --yes libjpeg-turbo -c pytorch pip install --progress-bar=off --upgrade setuptools==72.1.0 -# See https://github.com/pytorch/vision/issues/6790 -if [[ "${PYTHON_VERSION}" != "3.11" ]]; then - pip install --progress-bar=off av!=10.0.0 -fi - echo '::endgroup::' if [[ "${OS_TYPE}" == windows && "${GPU_ARCH_TYPE}" == cuda ]]; then diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 8b341622181..c030b2f7493 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -34,12 +34,12 @@ jobs: CONDA_PATH=$(which conda) eval "$(${CONDA_PATH} shell.bash hook)" conda activate ci - # FIXME: not sure why we need this. `ldd torchvision/video_reader.so` shows that it - # already links against the one pulled from conda. However, at runtime it pulls from - # /lib64 - # Should we maybe always do this in `./.github/scripts/setup-env.sh` so that we don't - # have to pay attention in all other workflows? + + echo '::group::Install TorchCodec and ffmpeg' + conda install --quiet --yes ffmpeg + pip install --progress-bar=off --pre torchcodec --index-url="https://download.pytorch.org/whl/nightly/cpu" export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" + echo '::endgroup::' cd docs diff --git a/gallery/others/plot_optical_flow.py b/gallery/others/plot_optical_flow.py index 6296c8e667e..a80804e6db5 100644 --- a/gallery/others/plot_optical_flow.py +++ b/gallery/others/plot_optical_flow.py @@ -47,11 +47,10 @@ def plot(imgs, **imshow_kwargs): plt.tight_layout() # %% -# Reading Videos Using Torchvision +# Reading Videos Using TorchCodec # -------------------------------- -# We will first read a video using :func:`~torchvision.io.read_video`. -# Alternatively one can use the new :class:`~torchvision.io.VideoReader` API (if -# torchvision is built from source). +# We will first read a video using +# `TorchCodec `_. # The video we will use here is free of use from `pexels.com # `_, # credits go to `Pavel Danilyuk `_. @@ -67,16 +66,16 @@ def plot(imgs, **imshow_kwargs): _ = urlretrieve(video_url, video_path) # %% -# :func:`~torchvision.io.read_video` returns the video frames, audio frames and -# the metadata associated with the video. In our case, we only need the video -# frames. +# We use :class:`~torchcodec.decoders.VideoDecoder` to decode the video frames. +# TorchCodec returns frames in NCHW format by default. # # Here we will just make 2 predictions between 2 pre-selected pairs of frames, # namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a # single model input. -from torchvision.io import read_video -frames, _, _ = read_video(str(video_path), output_format="TCHW") +from torchcodec.decoders import VideoDecoder +decoder = VideoDecoder(str(video_path)) +frames = decoder[:] img1_batch = torch.stack([frames[100], frames[150]]) img2_batch = torch.stack([frames[101], frames[151]]) @@ -85,7 +84,7 @@ def plot(imgs, **imshow_kwargs): # %% # The RAFT model accepts RGB images. We first get the frames from -# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions +# the decoder and resize them to ensure their dimensions # are divisible by 8. Note that we explicitly use ``antialias=False``, because # this is how those models were trained. Then we use the transforms bundled into # the weights in order to preprocess the input and rescale its values to the diff --git a/test/common_utils.py b/test/common_utils.py index 24ebb1376c3..1459f52cbbe 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -18,7 +18,7 @@ import torch.testing from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair -from torchvision import io, tv_tensors +from torchvision import tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor @@ -166,6 +166,8 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): + from datasets_utils import create_video_file + names = [] for i in range(num_videos): if sizes is None: @@ -176,10 +178,9 @@ def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): f = 5 else: f = fps[i] - data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8) - name = os.path.join(tmpdir, f"{i}.mp4") - names.append(name) - io.write_video(name, data, fps=f) + name = f"{i}.mp4" + create_video_file(tmpdir, name, size=(size, 3, 300, 400), fps=f) + names.append(os.path.join(tmpdir, name)) return names diff --git a/test/datasets_utils.py b/test/datasets_utils.py index cbfb26b6c6b..46d82f5e784 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -66,7 +66,7 @@ class LazyImporter: """ MODULES = ( - "av", + "torchcodec", "lmdb", "pycocotools", "requests", @@ -669,17 +669,24 @@ class VideoDatasetTestCase(DatasetTestCase): - Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` s for the video and audio as well as an integer label. - - Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``). + - Overwrites the 'REQUIRED_PACKAGES' class attribute to require TorchCodec (``torchcodec``). + - Skips on non-Linux platforms and CUDA-only environments. - Adds the 'DEFAULT_FRAMES_PER_CLIP' class attribute. If no 'frames_per_clip' is provided by 'inject_fake_data()' and it is the last parameter without a default value in the dataset constructor, the value of the 'DEFAULT_FRAMES_PER_CLIP' class attribute is appended to the output. """ FEATURE_TYPES = (torch.Tensor, torch.Tensor, int) - REQUIRED_PACKAGES = ("av",) + REQUIRED_PACKAGES = ("torchcodec",) FRAMES_PER_CLIP = 1 + @classmethod + def setUpClass(cls): + if platform.system() != "Linux": + raise unittest.SkipTest("Video dataset tests are only supported on Linux.") + super().setUpClass() + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.dataset_args = self._set_default_frames_per_clip(self.dataset_args) @@ -864,13 +871,12 @@ def shape_test_for_stereo( assert dw == mw -@requires_lazy_imports("av") +@requires_lazy_imports("torchcodec") def create_video_file( root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = (1, 3, 10, 10), fps: float = 25, - **kwargs: Any, ) -> pathlib.Path: """Create a video file from random data. @@ -881,14 +887,15 @@ def create_video_file( ``(num_frames, num_channels, height, width)``. If scalar, the value is used for the height and width. If not provided, ``num_frames=1`` and ``num_channels=3`` are assumed. fps (float): Frame rate in frames per second. - kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`. Returns: - pathlib.Path: Path to the created image file. + pathlib.Path: Path to the created video file. Raises: - UsageError: If PyAV is not available. + UsageError: If TorchCodec is not available. """ + from torchcodec.encoders import VideoEncoder + if isinstance(size, int): size = (size, size) if len(size) == 2: @@ -902,11 +909,14 @@ def create_video_file( video = create_image_or_video_tensor(size) file = pathlib.Path(root) / name - torchvision.io.write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs) + + encoder = VideoEncoder(video, frame_rate=fps) + encoder.to_file(str(file)) + return file -@requires_lazy_imports("av") +@requires_lazy_imports("torchcodec") def create_video_folder( root: Union[str, pathlib.Path], name: Union[str, pathlib.Path], @@ -933,7 +943,7 @@ def create_video_folder( List[pathlib.Path]: Paths to all created video files. Raises: - UsageError: If PyAV is not available. + UsageError: If TorchCodec is not available. .. seealso:: @@ -944,7 +954,7 @@ def create_video_folder( def size(idx): num_frames = 1 num_channels = 3 - # The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and + # The 'libx264' video codec requires the height and # width of the video to be divisible by 2. height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int) * 2).tolist() return (num_frames, num_channels, height, width) diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 9e3826b2c13..222890da20c 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -1,12 +1,23 @@ +import sys + import pytest import torch from common_utils import assert_equal, get_list_of_videos -from torchvision import io from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler from torchvision.datasets.video_utils import VideoClips +try: + import torchcodec # noqa: F401 + + _torchcodec_available = True +except ImportError: + _torchcodec_available = False + -@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") +@pytest.mark.skipif( + not (_torchcodec_available and sys.platform == "linux"), + reason="this test requires torchcodec (linux only)", +) class TestDatasetsSamplers: def test_random_clip_sampler(self, tmpdir): video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index 51330911e50..6d066a382b3 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -1,9 +1,22 @@ +import sys + import pytest import torch from common_utils import assert_equal, get_list_of_videos -from torchvision import io from torchvision.datasets.video_utils import unfold, VideoClips +try: + import torchcodec # noqa: F401 + + _torchcodec_available = True +except ImportError: + _torchcodec_available = False + +_requires_torchcodec = pytest.mark.skipif( + not (_torchcodec_available and sys.platform == "linux"), + reason="this test requires torchcodec (linux only)", +) + class TestVideo: def test_unfold(self): @@ -31,7 +44,7 @@ def test_unfold(self): ) assert_equal(r, expected) - @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") + @_requires_torchcodec def test_video_clips(self, tmpdir): video_list = get_list_of_videos(tmpdir, num_videos=3) video_clips = VideoClips(video_list, 5, 5, num_workers=2) @@ -55,7 +68,7 @@ def test_video_clips(self, tmpdir): assert video_idx == v_idx assert clip_idx == c_idx - @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") + @_requires_torchcodec def test_video_clips_custom_fps(self, tmpdir): video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) num_frames = 4 diff --git a/test/test_io.py b/test/test_io.py index 84d30ee3297..5194421105a 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,251 +1,4 @@ -import contextlib -import os -import sys -import tempfile - import pytest -import torch -import torchvision.io as io -from common_utils import assert_equal, cpu_and_cuda - - -try: - import av - - # Do a version test too - io.video._check_av_available() -except ImportError: - av = None - - -VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") - - -def _create_video_frames(num_frames, height, width): - y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij") - data = [] - for i in range(num_frames): - xc = float(i) / num_frames - yc = 1 - float(i) / (2 * num_frames) - d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 - data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) - - return torch.stack(data, 0) - - -@contextlib.contextmanager -def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None): - if lossless: - if video_codec is not None: - raise ValueError("video_codec can't be specified together with lossless") - if options is not None: - raise ValueError("options can't be specified together with lossless") - video_codec = "libx264rgb" - options = {"crf": "0"} - - if video_codec is None: - video_codec = "libx264" - if options is None: - options = {} - - data = _create_video_frames(num_frames, height, width) - with tempfile.NamedTemporaryFile(suffix=".mp4") as f: - f.close() - io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options) - yield f.name, data - os.unlink(f.name) - - -@pytest.mark.skipif(av is None, reason="PyAV unavailable") -class TestVideo: - # compression adds artifacts, thus we add a tolerance of - # 6 in 0-255 range - TOLERANCE = 6 - - def test_write_read_video(self): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - lv, _, info = io.read_video(f_name) - assert_equal(data, lv) - assert info["video_fps"] == 5 - - def test_read_timestamps(self): - with temp_video(10, 300, 300, 5) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - # note: not all formats/codecs provide accurate information for computing the - # timestamps. For the format that we use here, this information is available, - # so we use it as a baseline - with av.open(f_name) as container: - stream = container.streams[0] - pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) - num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) - expected_pts = [i * pts_step for i in range(num_frames)] - - assert pts == expected_pts - - @pytest.mark.parametrize("start", range(5)) - @pytest.mark.parametrize("offset", range(1, 4)) - def test_read_partial_video(self, start, offset): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start : (start + offset)] - assert len(lv) == offset - assert_equal(s_data, lv) - - lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) - assert len(lv) == 4 - assert_equal(data[4:8], lv) - - @pytest.mark.parametrize("start", range(0, 80, 20)) - @pytest.mark.parametrize("offset", range(1, 4)) - def test_read_partial_video_bframes(self, start, offset): - # do not use lossless encoding, to test the presence of B-frames - options = {"bframes": "16", "keyint": "10", "min-keyint": "4"} - with temp_video(100, 300, 300, 5, options=options) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start : (start + offset)] - assert len(lv) == offset - assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE) - - lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) - assert len(lv) == 4 - assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE) - - def test_read_packed_b_frames_divx_file(self): - name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi" - f_name = os.path.join(VIDEO_DIR, name) - pts, fps = io.read_video_timestamps(f_name) - - assert pts == sorted(pts) - assert fps == 30 - - def test_read_timestamps_from_packet(self): - with temp_video(10, 300, 300, 5, video_codec="mpeg4") as (f_name, data): - pts, _ = io.read_video_timestamps(f_name) - # note: not all formats/codecs provide accurate information for computing the - # timestamps. For the format that we use here, this information is available, - # so we use it as a baseline - with av.open(f_name) as container: - stream = container.streams[0] - # make sure we went through the optimized codepath - assert b"Lavc" in stream.codec_context.extradata - pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) - num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) - expected_pts = [i * pts_step for i in range(num_frames)] - - assert pts == expected_pts - - def test_read_video_pts_unit_sec(self): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - lv, _, info = io.read_video(f_name, pts_unit="sec") - - assert_equal(data, lv) - assert info["video_fps"] == 5 - assert info == {"video_fps": 5} - - def test_read_timestamps_pts_unit_sec(self): - with temp_video(10, 300, 300, 5) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") - - with av.open(f_name) as container: - stream = container.streams[0] - pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) - num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) - expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)] - - assert pts == expected_pts - - @pytest.mark.parametrize("start", range(5)) - @pytest.mark.parametrize("offset", range(1, 4)) - def test_read_partial_video_pts_unit_sec(self, start, offset): - with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") - - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit="sec") - s_data = data[start : (start + offset)] - assert len(lv) == offset - assert_equal(s_data, lv) - - with av.open(f_name) as container: - stream = container.streams[0] - lv, _, _ = io.read_video( - f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit="sec" - ) - assert len(lv) == 4 - assert_equal(data[4:8], lv) - - def test_read_video_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix=".mp4") as f: - f.write(b"This is not an mpg4 file") - video, audio, info = io.read_video(f.name) - assert isinstance(video, torch.Tensor) - assert isinstance(audio, torch.Tensor) - assert video.numel() == 0 - assert audio.numel() == 0 - assert info == {} - - def test_read_video_timestamps_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix=".mp4") as f: - f.write(b"This is not an mpg4 file") - video_pts, video_fps = io.read_video_timestamps(f.name) - assert video_pts == [] - assert video_fps is None - - @pytest.mark.skip(reason="Temporarily disabled due to new pyav") - def test_read_video_partially_corrupted_file(self): - with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data): - with open(f_name, "r+b") as f: - size = os.path.getsize(f_name) - bytes_to_overwrite = size // 10 - # seek to the middle of the file - f.seek(5 * bytes_to_overwrite) - # corrupt 10% of the file from the middle - f.write(b"\xff" * bytes_to_overwrite) - # this exercises the container.decode assertion check - video, audio, info = io.read_video(f.name, pts_unit="sec") - # check that size is not equal to 5, but 3 - assert len(video) == 3 - # but the valid decoded content is still correct - assert_equal(video[:3], data[:3]) - # and the last few frames are wrong - with pytest.raises(AssertionError): - assert_equal(video, data) - - @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows") - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_write_video_with_audio(self, device, tmpdir): - f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") - video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") - - out_f_name = os.path.join(tmpdir, "testing.mp4") - io.video.write_video( - out_f_name, - video_tensor.to(device), - round(info["video_fps"]), - video_codec="libx264rgb", - options={"crf": "0"}, - audio_array=audio_tensor.to(device), - audio_fps=info["audio_fps"], - audio_codec="aac", - ) - - out_video_tensor, out_audio_tensor, out_info = io.read_video(out_f_name, pts_unit="sec") - - assert info["video_fps"] == out_info["video_fps"] - assert_equal(video_tensor, out_video_tensor) - - audio_stream = av.open(f_name).streams.audio[0] - out_audio_stream = av.open(out_f_name).streams.audio[0] - - assert info["audio_fps"] == out_info["audio_fps"] - assert audio_stream.rate == out_audio_stream.rate - assert pytest.approx(out_audio_stream.frames, rel=0.0, abs=1) == audio_stream.frames - assert audio_stream.frame_size == out_audio_stream.frame_size - - # TODO add tests for audio if __name__ == "__main__": diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index ad26299cff6..a95737a571d 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -4,13 +4,24 @@ from typing import Any, Optional, TypeVar, Union import torch -from torchvision.io import read_video, read_video_timestamps from .utils import tqdm T = TypeVar("T") +def _get_torchcodec(): + try: + import torchcodec # type: ignore[import-not-found] + except ImportError: + raise ImportError( + "Video decoding capabilities were removed from torchvision and migrated " + "to TorchCodec. Please install TorchCodec following instructions at " + "https://github.com/pytorch/torchcodec#installing-torchcodec" + ) + return torchcodec + + def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor: """ similar to tensor.unfold, but with the dilation @@ -47,7 +58,11 @@ def __len__(self) -> int: return len(self.video_paths) def __getitem__(self, idx: int) -> tuple[list[int], Optional[float]]: - return read_video_timestamps(self.video_paths[idx]) + torchcodec = _get_torchcodec() + decoder = torchcodec.decoders.VideoDecoder(self.video_paths[idx]) + num_frames = decoder.metadata.num_frames + fps = decoder.metadata.average_fps + return list(range(num_frames)), fps def _collate_fn(x: T) -> T: @@ -292,9 +307,27 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any] video_path = self.video_paths[video_idx] clip_pts = self.clips[video_idx][clip_idx] - start_pts = clip_pts[0].item() - end_pts = clip_pts[-1].item() - video, audio, info = read_video(video_path, start_pts, end_pts) + start_idx = int(clip_pts[0].item()) + end_idx = int(clip_pts[-1].item()) + + torchcodec = _get_torchcodec() + + dimension_order = "NHWC" if self.output_format == "THWC" else "NCHW" + decoder = torchcodec.decoders.VideoDecoder(video_path, dimension_order=dimension_order) + video = decoder.get_frames_at(indices=list(range(start_idx, end_idx + 1))).data + + # Audio via TorchCodec + fps = decoder.metadata.average_fps + start_sec = start_idx / fps + end_sec = (end_idx + 1) / fps + try: + audio_decoder = torchcodec.decoders.AudioDecoder(video_path) + audio_samples = audio_decoder.get_samples_played_in_range(start_seconds=start_sec, stop_seconds=end_sec) + audio = audio_samples.data + except Exception: + audio = torch.empty((1, 0), dtype=torch.float32) + + info = {"video_fps": fps} if self.frame_rate is not None: resampling_idx = self.resampling_idxs[video_idx][clip_idx] @@ -304,10 +337,6 @@ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any] info["video_fps"] = self.frame_rate assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" - if self.output_format == "TCHW": - # [T,H,W,C] --> [T,C,H,W] - video = video.permute(0, 3, 1, 2) - return video, audio, info, video_idx def __getstate__(self) -> dict[str, Any]: diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index a486b0275e1..02e28e107c6 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -18,6 +18,15 @@ except ImportError: pass +try: + from pytorch.vision.fb.io.video import ( # type: ignore[import-not-found] + read_video, + read_video_timestamps, + write_video, + ) +except ImportError: + pass + from .image import ( decode_avif, decode_gif, @@ -35,13 +44,9 @@ write_jpeg, write_png, ) -from .video import read_video, read_video_timestamps, write_video __all__ = [ - "write_video", - "read_video", - "read_video_timestamps", "ImageReadMode", "decode_image", "decode_jpeg", diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 5331b764d27..87fe36f2caa 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,453 +1,23 @@ -import gc -import math -import re -import warnings -from fractions import Fraction -from typing import Any, Optional, Union - -import numpy as np -import torch - -from ..utils import _log_api_usage_once -from ._video_deprecation_warning import _raise_video_deprecation_warning - -try: - import av - - av.logging.set_level(av.logging.ERROR) - if not hasattr(av.video.frame.VideoFrame, "pict_type"): - av = ImportError( - """\ -Your version of PyAV is too old for the necessary video operations in torchvision. -If you are on Python 3.5, you will have to build from source (the conda-forge -packages are not up-to-date). See -https://github.com/mikeboers/PyAV#installation for instructions on how to -install PyAV on your system. -""" - ) - try: - FFmpegError = av.FFmpegError # from av 14 https://github.com/PyAV-Org/PyAV/blob/main/CHANGELOG.rst - except AttributeError: - FFmpegError = av.AVError -except ImportError: - av = ImportError( - """\ -PyAV is not installed, and is necessary for the video operations in torchvision. -See https://github.com/mikeboers/PyAV#installation for instructions on how to -install PyAV on your system. -""" - ) - - -def _check_av_available() -> None: - if isinstance(av, Exception): - raise av - - -def _av_available() -> bool: - return not isinstance(av, Exception) - - -# PyAV has some reference cycles -_CALLED_TIMES = 0 -_GC_COLLECTION_INTERVAL = 10 - - -def write_video( - filename: str, - video_array: torch.Tensor, - fps: float, - video_codec: str = "libx264", - options: Optional[dict[str, Any]] = None, - audio_array: Optional[torch.Tensor] = None, - audio_fps: Optional[float] = None, - audio_codec: Optional[str] = None, - audio_options: Optional[dict[str, Any]] = None, -) -> None: - """ - [DEPRECATED] Writes a 4d tensor in [T, H, W, C] format in a video file. - - .. warning:: - - DEPRECATED: All the video decoding and encoding capabilities of torchvision - are deprecated from version 0.22 and will be removed in version 0.24. We - recommend that you migrate to - `TorchCodec `__, where we'll - consolidate the future decoding/encoding capabilities of PyTorch - - This function relies on PyAV (therefore, ultimately FFmpeg) to encode - videos, you can get more fine-grained control by referring to the other - options at your disposal within `the FFMpeg wiki - `_. - - Args: - filename (str): path where the video will be saved - video_array (Tensor[T, H, W, C]): tensor containing the individual frames, - as a uint8 tensor in [T, H, W, C] format - fps (Number): video frames per second - video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc. - options (Dict): dictionary containing options to be passed into the PyAV video stream. - The list of options is codec-dependent and can all - be found from `the FFMpeg wiki `_. - audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels - and N is the number of samples - audio_fps (Number): audio sample rate, typically 44100 or 48000 - audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc. - audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream. - The list of options is codec-dependent and can all - be found from `the FFMpeg wiki `_. - - Examples:: - >>> # Creating libx264 video with CRF 17, for visually lossless footage: - >>> - >>> from torchvision.io import write_video - >>> # 1000 frames of 100x100, 3-channel image. - >>> vid = torch.randn(1000, 100, 100, 3, dtype = torch.uint8) - >>> write_video("video.mp4", options = {"crf": "17"}) - - """ - _raise_video_deprecation_warning() - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(write_video) - _check_av_available() - video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True) - - # PyAV does not support floating point numbers with decimal point - # and will throw OverflowException in case this is not the case - if isinstance(fps, float): - fps = int(np.round(fps)) - - with av.open(filename, mode="w") as container: - stream = container.add_stream(video_codec, rate=fps) - stream.width = video_array.shape[2] - stream.height = video_array.shape[1] - stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24" - stream.options = options or {} - - if audio_array is not None: - audio_format_dtypes = { - "dbl": " 1 else "mono" - audio_sample_fmt = container.streams.audio[0].format.name - - format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt]) - audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype) - - frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout) - - frame.sample_rate = audio_fps - - for packet in a_stream.encode(frame): - container.mux(packet) - - for packet in a_stream.encode(): - container.mux(packet) - - for img in video_array: - frame = av.VideoFrame.from_ndarray(img, format="rgb24") - try: - frame.pict_type = "NONE" - except TypeError: - from av.video.frame import PictureType # noqa - - frame.pict_type = PictureType.NONE - - for packet in stream.encode(frame): - container.mux(packet) - - # Flush stream - for packet in stream.encode(): - container.mux(packet) - - -def _read_from_stream( - container: "av.container.Container", - start_offset: float, - end_offset: float, - pts_unit: str, - stream: "av.stream.Stream", - stream_name: dict[str, Optional[Union[int, tuple[int, ...], list[int]]]], -) -> list["av.frame.Frame"]: - global _CALLED_TIMES, _GC_COLLECTION_INTERVAL - _CALLED_TIMES += 1 - if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: - gc.collect() - - if pts_unit == "sec": - # TODO: we should change all of this from ground up to simply take - # sec and convert to MS in C++ - start_offset = int(math.floor(start_offset * (1 / stream.time_base))) - if end_offset != float("inf"): - end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) - else: - warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.") - - frames = {} - should_buffer = True - max_buffer_size = 5 - if stream.type == "video": - # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt) - # so need to buffer some extra frames to sort everything - # properly - extradata = stream.codec_context.extradata - # overly complicated way of finding if `divx_packed` is set, following - # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263 - if extradata and b"DivX" in extradata: - # can't use regex directly because of some weird characters sometimes... - pos = extradata.find(b"DivX") - d = extradata[pos:] - o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d) - if o is None: - o = re.search(rb"DivX(\d+)b(\d+)(\w)", d) - if o is not None: - should_buffer = o.group(3) == b"p" - seek_offset = start_offset - # some files don't seek to the right location, so better be safe here - seek_offset = max(seek_offset - 1, 0) - if should_buffer: - # FIXME this is kind of a hack, but we will jump to the previous keyframe - # so this will be safe - seek_offset = max(seek_offset - max_buffer_size, 0) - try: - # TODO check if stream needs to always be the video stream here or not - container.seek(seek_offset, any_frame=False, backward=True, stream=stream) - except FFmpegError: - # TODO add some warnings in this case - # print("Corrupted file?", container.name) - return [] - buffer_count = 0 - try: - for _idx, frame in enumerate(container.decode(**stream_name)): - frames[frame.pts] = frame - if frame.pts >= end_offset: - if should_buffer and buffer_count < max_buffer_size: - buffer_count += 1 - continue - break - except FFmpegError: - # TODO add a warning - pass - # ensure that the results are sorted wrt the pts - result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset] - if len(frames) > 0 and start_offset > 0 and start_offset not in frames: - # if there is no frame that exactly matches the pts of start_offset - # add the last frame smaller than start_offset, to guarantee that - # we will have all the necessary data. This is most useful for audio - preceding_frames = [i for i in frames if i < start_offset] - if len(preceding_frames) > 0: - first_frame_pts = max(preceding_frames) - result.insert(0, frames[first_frame_pts]) - return result - - -def _align_audio_frames( - aframes: torch.Tensor, audio_frames: list["av.frame.Frame"], ref_start: int, ref_end: float -) -> torch.Tensor: - start, end = audio_frames[0].pts, audio_frames[-1].pts - total_aframes = aframes.shape[1] - step_per_aframe = (end - start + 1) / total_aframes - s_idx = 0 - e_idx = total_aframes - if start < ref_start: - s_idx = int((ref_start - start) / step_per_aframe) - if end > ref_end: - e_idx = int((ref_end - end) / step_per_aframe) - return aframes[:, s_idx:e_idx] - - -def read_video( - filename: str, - start_pts: Union[float, Fraction] = 0, - end_pts: Optional[Union[float, Fraction]] = None, - pts_unit: str = "pts", - output_format: str = "THWC", -) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: - """[DEPRECATED] Reads a video from a file, returning both the video frames and the audio frames - - .. warning:: - - DEPRECATED: All the video decoding and encoding capabilities of torchvision - are deprecated from version 0.22 and will be removed in version 0.24. We - recommend that you migrate to - `TorchCodec `__, where we'll - consolidate the future decoding/encoding capabilities of PyTorch - - Args: - filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts. - start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): - The start presentation time of the video - end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): - The end presentation time - pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, - either 'pts' or 'sec'. Defaults to 'pts'. - output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW". - - Returns: - vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames - aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points - info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) - """ - _raise_video_deprecation_warning() - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_video) - - output_format = output_format.upper() - if output_format not in ("THWC", "TCHW"): - raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") - - _check_av_available() - - if end_pts is None: - end_pts = float("inf") - - if end_pts < start_pts: - raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") - - info = {} - video_frames = [] - audio_frames = [] - audio_timebase = Fraction(0, 1) - - try: - with av.open(filename, metadata_errors="ignore") as container: - if container.streams.audio: - audio_timebase = container.streams.audio[0].time_base - if container.streams.video: - video_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.video[0], - {"video": 0}, - ) - video_fps = container.streams.video[0].average_rate - # guard against potentially corrupted files - if video_fps is not None: - info["video_fps"] = float(video_fps) - - if container.streams.audio: - audio_frames = _read_from_stream( - container, - start_pts, - end_pts, - pts_unit, - container.streams.audio[0], - {"audio": 0}, - ) - info["audio_fps"] = container.streams.audio[0].rate - - except FFmpegError: - # TODO raise a warning? - pass - - vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames] - aframes_list = [frame.to_ndarray() for frame in audio_frames] - - if vframes_list: - vframes = torch.as_tensor(np.stack(vframes_list)) - else: - vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) - - if aframes_list: - aframes = np.concatenate(aframes_list, 1) - aframes = torch.as_tensor(aframes) - if pts_unit == "sec": - start_pts = int(math.floor(start_pts * (1 / audio_timebase))) - if end_pts != float("inf"): - end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) - aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) - else: - aframes = torch.empty((1, 0), dtype=torch.float32) - - if output_format == "TCHW": - # [T,H,W,C] --> [T,C,H,W] - vframes = vframes.permute(0, 3, 1, 2) - - return vframes, aframes, info - - -def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool: - extradata = container.streams[0].codec_context.extradata - if extradata is None: - return False - if b"Lavc" in extradata: - return True - return False - - -def _decode_video_timestamps(container: "av.container.Container") -> list[int]: - if _can_read_timestamps_from_packets(container): - # fast path - return [x.pts for x in container.demux(video=0) if x.pts is not None] - else: - return [x.pts for x in container.decode(video=0) if x.pts is not None] - - -def read_video_timestamps(filename: str, pts_unit: str = "pts") -> tuple[list[int], Optional[float]]: - """[DEPREACTED] List the video frames timestamps. - - .. warning:: - - DEPRECATED: All the video decoding and encoding capabilities of torchvision - are deprecated from version 0.22 and will be removed in version 0.24. We - recommend that you migrate to - `TorchCodec `__, where we'll - consolidate the future decoding/encoding capabilities of PyTorch - - Note that the function decodes the whole video frame-by-frame. - - Args: - filename (str): path to the video file - pts_unit (str, optional): unit in which timestamp values will be returned - either 'pts' or 'sec'. Defaults to 'pts'. - - Returns: - pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'): - presentation timestamps for each one of the frames in the video. - video_fps (float, optional): the frame rate for the video - - """ - _raise_video_deprecation_warning() - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_video_timestamps) - - _check_av_available() - - video_fps = None - pts = [] - - try: - with av.open(filename, metadata_errors="ignore") as container: - if container.streams.video: - video_stream = container.streams.video[0] - video_time_base = video_stream.time_base - try: - pts = _decode_video_timestamps(container) - except FFmpegError: - warnings.warn(f"Failed decoding frames for file {filename}") - video_fps = float(video_stream.average_rate) - except FFmpegError as e: - msg = f"Failed to open container for {filename}; Caught error: {e}" - warnings.warn(msg, RuntimeWarning) - - pts.sort() - - if pts_unit == "sec": - pts = [x * video_time_base for x in pts] - - return pts, video_fps +# This module re-exports video utilities from the internal fb location. +# The actual implementation lives in pytorch.vision.fb.io.video +from pytorch.vision.fb.io.video import ( # type: ignore[import-not-found] + _align_audio_frames, + _av_available, + _check_av_available, + _read_from_stream, + av, + read_video, + read_video_timestamps, + write_video, +) + +__all__ = [ + "read_video", + "read_video_timestamps", + "write_video", + "_read_from_stream", + "_align_audio_frames", + "_check_av_available", + "_av_available", + "av", +]