diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
index e6cc615ec0..36116d2e96 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -139,7 +139,10 @@ def test_PitchShift(self):
         sample_rate = 8000
         n_steps = 4
         waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
-        self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
+        pitch_shift = T.PitchShift(sample_rate=sample_rate, n_steps=n_steps)
+        # dry run to initialize the lazy parameters before scripting
+        pitch_shift(waveform)
+        self._assert_consistency(pitch_shift, waveform)
 
     def test_PSD(self):
         tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py
index 9caeea4156..18403165e5 100644
--- a/test/torchaudio_unittest/transforms/transforms_test_impl.py
+++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py
@@ -1,5 +1,6 @@
 import torch
 import torchaudio.transforms as T
+from torchaudio.functional.functional import _get_sinc_resample_kernel
 from parameterized import param, parameterized
 from torchaudio_unittest.common_utils import (
     get_spectrogram,
@@ -147,3 +148,21 @@ def test_mvdr(self, dtype):
         mask_n = torch.rand(specgram.shape[-2:])
         specgram_enhanced = transform(specgram, mask_s, mask_n)
         assert specgram_enhanced.dtype == dtype
+
+    def test_pitch_shift_resample_kernel(self):
+        """The resampling kernel in PitchShift is identical to the one the helper function generates.
+        There should be no numerical difference caused by dtype conversion.
+        """
+        sample_rate = 8000
+        trans = T.PitchShift(sample_rate=sample_rate, n_steps=4)
+        trans.to(self.dtype).to(self.device)
+        # dry run to initialize the kernel
+        trans(torch.randn(2, 8000, dtype=self.dtype, device=self.device))
+
+        expected, _ = _get_sinc_resample_kernel(
+            trans.orig_freq,
+            sample_rate,
+            trans.gcd,
+            device=self.device,
+            dtype=self.dtype)
+        self.assertEqual(trans.kernel, expected)
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 6c77820bd9..665bf8c1f4 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -4,7 +4,7 @@
 import math
 import warnings
 from collections.abc import Sequence
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 
 import torch
 import torchaudio
@@ -1389,10 +1389,10 @@ def _get_sinc_resample_kernel(
     orig_freq: int,
     new_freq: int,
     gcd: int,
-    lowpass_filter_width: int,
-    rolloff: float,
-    resampling_method: str,
-    beta: Optional[float],
+    lowpass_filter_width: int = 6,
+    rolloff: float = 0.99,
+    resampling_method: str = "sinc_interpolation",
+    beta: Optional[float] = None,
     device: torch.device = torch.device("cpu"),
     dtype: Optional[torch.dtype] = None,
 ):
@@ -1635,6 +1635,39 @@ def pitch_shift(
     Returns:
         Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
""" + waveform_stretch = _stretch_waveform( + waveform, + n_steps, + bins_per_octave, + n_fft, + win_length, + hop_length, + window, + ) + rate = 2.0 ** (-float(n_steps) / bins_per_octave) + waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate) + + return _fix_waveform_shape(waveform_shift, waveform.size()) + + +def _stretch_waveform( + waveform: Tensor, + n_steps: int, + bins_per_octave: int = 12, + n_fft: int = 512, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + window: Optional[Tensor] = None, +) -> Tensor: + """ + Pitch shift helper function to preprocess and stretch waveform before resampling step. + + Args: + See pitch_shift arg descriptions. + + Returns: + Tensor: The preprocessed waveform stretched prior to resampling. + """ if hop_length is None: hop_length = n_fft // 4 if win_length is None: @@ -1666,7 +1699,24 @@ def pitch_shift( waveform_stretch = torch.istft( spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch ) - waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate) + return waveform_stretch + + +def _fix_waveform_shape( + waveform_shift: Tensor, + shape: List[int], +) -> Tensor: + """ + PitchShift helper function to process after resampling step to fix the shape back. + + Args: + waveform_shift(Tensor): The waveform after stretch and resample + shape (List[int]): The shape of initial waveform + + Returns: + Tensor: The pitch-shifted audio waveform of shape `(..., time)`. + """ + ori_len = shape[-1] shift_len = waveform_shift.size()[-1] if shift_len > ori_len: waveform_shift = waveform_shift[..., :ori_len] diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py index 92c4622b69..489901f9ea 100644 --- a/torchaudio/transforms/_transforms.py +++ b/torchaudio/transforms/_transforms.py @@ -6,10 +6,15 @@ import torch from torch import Tensor +from torch.nn.modules.lazy import LazyModuleMixin +from torch.nn.parameter import UninitializedParameter + from torchaudio import functional as F from torchaudio.functional.functional import ( _apply_sinc_resample_kernel, _get_sinc_resample_kernel, + _stretch_waveform, + _fix_waveform_shape, ) __all__ = [] @@ -1511,7 +1516,7 @@ def forward(self, waveform: Tensor) -> Tensor: ) -class PitchShift(torch.nn.Module): +class PitchShift(LazyModuleMixin, torch.nn.Module): r"""Shift the pitch of a waveform by ``n_steps`` steps. .. 
     .. devices:: CPU CUDA
@@ -1537,6 +1542,9 @@ class PitchShift(torch.nn.Module):
     """
     __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]
 
+    kernel: UninitializedParameter
+    width: int
+
     def __init__(
         self,
         sample_rate: int,
@@ -1548,7 +1556,7 @@ def __init__(
         window_fn: Callable[..., Tensor] = torch.hann_window,
         wkwargs: Optional[dict] = None,
     ) -> None:
-        super(PitchShift, self).__init__()
+        super().__init__()
         self.n_steps = n_steps
         self.bins_per_octave = bins_per_octave
         self.sample_rate = sample_rate
@@ -1557,6 +1565,27 @@ def __init__(
         self.hop_length = hop_length if hop_length is not None else self.win_length // 4
         window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.register_buffer("window", window)
+        rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+        self.orig_freq = int(sample_rate / rate)
+        self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))
+
+        if self.orig_freq != sample_rate:
+            self.width = -1
+            self.kernel = UninitializedParameter(device=None, dtype=None)
+
+    def initialize_parameters(self, input):
+        if self.has_uninitialized_params():
+            if self.orig_freq != self.sample_rate:
+                with torch.no_grad():
+                    kernel, self.width = _get_sinc_resample_kernel(
+                        self.orig_freq,
+                        self.sample_rate,
+                        self.gcd,
+                        dtype=input.dtype,
+                        device=input.device,
+                    )
+                    self.kernel.materialize(kernel.shape)
+                    self.kernel.copy_(kernel)
 
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
@@ -1566,10 +1595,10 @@ def forward(self, waveform: Tensor) -> Tensor:
         Returns:
             Tensor: The pitch-shifted audio of shape `(..., time)`.
         """
+        shape = waveform.size()
 
-        return F.pitch_shift(
+        waveform_stretch = _stretch_waveform(
             waveform,
-            self.sample_rate,
             self.n_steps,
             self.bins_per_octave,
             self.n_fft,
@@ -1578,6 +1607,23 @@ def forward(self, waveform: Tensor) -> Tensor:
             self.window,
         )
 
+        if self.orig_freq != self.sample_rate:
+            waveform_shift = _apply_sinc_resample_kernel(
+                waveform_stretch,
+                self.orig_freq,
+                self.sample_rate,
+                self.gcd,
+                self.kernel,
+                self.width,
+            )
+        else:
+            waveform_shift = waveform_stretch
+
+        return _fix_waveform_shape(
+            waveform_shift,
+            shape,
+        )
+
 
 class RNNTLoss(torch.nn.Module):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
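
Note for reviewers: with this patch, PitchShift derives from LazyModuleMixin, so its sinc resampling kernel is only materialized by the first forward call. A minimal usage sketch of the resulting dry-run requirement before TorchScript export (not part of the patch; mirrors what the updated tests do):

    import torch
    import torchaudio.transforms as T

    pitch_shift = T.PitchShift(sample_rate=8000, n_steps=4)

    # The kernel is an UninitializedParameter until the first call;
    # run the transform once so torch.jit.script sees concrete parameters.
    waveform = torch.randn(1, 8000)
    pitch_shift(waveform)

    scripted = torch.jit.script(pitch_shift)
    assert torch.allclose(scripted(waveform), pitch_shift(waveform))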