From 5a3bfe89d6360de90e17e383f41de3c90e8a5f9c Mon Sep 17 00:00:00 2001 From: Lubosz Sarnecki Date: Sun, 14 Dec 2025 10:00:43 +0100 Subject: [PATCH] df/io: Add support for torchaudio 2.9. torchaudio drops support for meta data functionality in favor of using torchcodec. See https://github.com/pytorch/audio/issues/3902 This patch implements a new path for torchaudio versions 2.9+ by checking the version number. In uses the new torchcodec API for retrieving metadata and decoding. Signed-off-by: Lubosz Sarnecki --- DeepFilterNet/df/io.py | 45 +++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/DeepFilterNet/df/io.py b/DeepFilterNet/df/io.py index 58024b152..f46c2014b 100644 --- a/DeepFilterNet/df/io.py +++ b/DeepFilterNet/df/io.py @@ -6,17 +6,30 @@ from loguru import logger from numpy import ndarray from torch import Tensor +from packaging import version -try: - from torchaudio import AudioMetaData + +if version.parse(ta.__version__) >= version.parse("2.9.0"): + from torchcodec.decoders import AudioDecoder + from torchcodec._core import AudioStreamMetadata + + AudioStreamMetadataType = AudioStreamMetadata TA_RESAMPLE_SINC = "sinc_interp_hann" TA_RESAMPLE_KAISER = "sinc_interp_kaiser" -except ImportError: - from torchaudio.backend.common import AudioMetaData +else: + try: + from torchaudio import AudioMetaData + + TA_RESAMPLE_SINC = "sinc_interp_hann" + TA_RESAMPLE_KAISER = "sinc_interp_kaiser" + except ImportError: + from torchaudio.backend.common import AudioMetaData - TA_RESAMPLE_SINC = "sinc_interpolation" - TA_RESAMPLE_KAISER = "kaiser_window" + TA_RESAMPLE_SINC = "sinc_interpolation" + TA_RESAMPLE_KAISER = "kaiser_window" + + AudioStreamMetadataType = AudioMetaData from df.logger import warn_once from df.utils import download_file, get_cache_dir, get_git_root @@ -24,7 +37,7 @@ def load_audio( file: str, sr: Optional[int] = None, verbose=True, **kwargs -) -> Tuple[Tensor, AudioMetaData]: +) -> Tuple[Tensor, AudioStreamMetadataType]: """Loads an audio file using torchaudio. Args: @@ -43,10 +56,20 @@ def load_audio( rkwargs = {} if "method" in kwargs: rkwargs["method"] = kwargs.pop("method") - info: AudioMetaData = ta.info(file, **ikwargs) - if "num_frames" in kwargs and sr is not None: - kwargs["num_frames"] *= info.sample_rate // sr - audio, orig_sr = ta.load(file, **kwargs) + + if version.parse(ta.__version__) >= version.parse("2.9.0"): + decoder = AudioDecoder(file) + info: AudioStreamMetadata = decoder.metadata + samples = decoder.get_all_samples() + audio = samples.data + orig_sr = samples.sample_rate + else: + info: AudioMetaData = ta.info(file, **ikwargs) + + if "num_frames" in kwargs and sr is not None: + kwargs["num_frames"] *= info.sample_rate // sr + audio, orig_sr = ta.load(file, **kwargs) + if sr is not None and orig_sr != sr: if verbose: warn_once(