Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-picked 0.12] Modify Pitchshift for faster resampling #2441

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def test_PitchShift(self):
sample_rate = 8000
n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
pitch_shift = T.PitchShift(sample_rate=sample_rate, n_steps=n_steps)
# dry-run for initializing parameters
pitch_shift(waveform)
self._assert_consistency(pitch_shift, waveform)

def test_PSD(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
Expand Down
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_test_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torchaudio.transforms as T
from torchaudio.functional.functional import _get_sinc_resample_kernel
from parameterized import param, parameterized
from torchaudio_unittest.common_utils import (
get_spectrogram,
Expand Down Expand Up @@ -147,3 +148,21 @@ def test_mvdr(self, dtype):
mask_n = torch.rand(specgram.shape[-2:])
specgram_enhanced = transform(specgram, mask_s, mask_n)
assert specgram_enhanced.dtype == dtype

def test_pitch_shift_resample_kernel(self):
"""The resampling kernel in PitchShift is identical to what helper function generates.
There should be no numerical difference caused by dtype conversion.
"""
sample_rate = 8000
trans = T.PitchShift(sample_rate=sample_rate, n_steps=4)
trans.to(self.dtype).to(self.device)
# dry run to initialize the kernel
trans(torch.randn(2, 8000, dtype=self.dtype, device=self.device))

expected, _ = _get_sinc_resample_kernel(
trans.orig_freq,
sample_rate,
trans.gcd,
device=self.device,
dtype=self.dtype)
self.assertEqual(trans.kernel, expected)
62 changes: 56 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List

import torch
import torchaudio
Expand Down Expand Up @@ -1389,10 +1389,10 @@ def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float],
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
):
Expand Down Expand Up @@ -1635,6 +1635,39 @@ def pitch_shift(
Returns:
Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
"""
waveform_stretch = _stretch_waveform(
waveform,
n_steps,
bins_per_octave,
n_fft,
win_length,
hop_length,
window,
)
rate = 2.0 ** (-float(n_steps) / bins_per_octave)
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)

return _fix_waveform_shape(waveform_shift, waveform.size())


def _stretch_waveform(
waveform: Tensor,
n_steps: int,
bins_per_octave: int = 12,
n_fft: int = 512,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window: Optional[Tensor] = None,
) -> Tensor:
"""
Pitch shift helper function to preprocess and stretch waveform before resampling step.

Args:
See pitch_shift arg descriptions.

Returns:
Tensor: The preprocessed waveform stretched prior to resampling.
"""
if hop_length is None:
hop_length = n_fft // 4
if win_length is None:
Expand Down Expand Up @@ -1666,7 +1699,24 @@ def pitch_shift(
waveform_stretch = torch.istft(
spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
)
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
return waveform_stretch


def _fix_waveform_shape(
waveform_shift: Tensor,
shape: List[int],
) -> Tensor:
"""
PitchShift helper function to process after resampling step to fix the shape back.

Args:
waveform_shift(Tensor): The waveform after stretch and resample
shape (List[int]): The shape of initial waveform

Returns:
Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
"""
ori_len = shape[-1]
shift_len = waveform_shift.size()[-1]
if shift_len > ori_len:
waveform_shift = waveform_shift[..., :ori_len]
Expand Down
54 changes: 50 additions & 4 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@

import torch
from torch import Tensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedParameter

from torchaudio import functional as F
from torchaudio.functional.functional import (
_apply_sinc_resample_kernel,
_get_sinc_resample_kernel,
_stretch_waveform,
_fix_waveform_shape,
)

__all__ = []
Expand Down Expand Up @@ -1511,7 +1516,7 @@ def forward(self, waveform: Tensor) -> Tensor:
)


class PitchShift(torch.nn.Module):
class PitchShift(LazyModuleMixin, torch.nn.Module):
r"""Shift the pitch of a waveform by ``n_steps`` steps.

.. devices:: CPU CUDA
Expand All @@ -1537,6 +1542,9 @@ class PitchShift(torch.nn.Module):
"""
__constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]

kernel: UninitializedParameter
width: int

def __init__(
self,
sample_rate: int,
Expand All @@ -1548,7 +1556,7 @@ def __init__(
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(PitchShift, self).__init__()
super().__init__()
self.n_steps = n_steps
self.bins_per_octave = bins_per_octave
self.sample_rate = sample_rate
Expand All @@ -1557,6 +1565,27 @@ def __init__(
self.hop_length = hop_length if hop_length is not None else self.win_length // 4
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer("window", window)
rate = 2.0 ** (-float(n_steps) / bins_per_octave)
self.orig_freq = int(sample_rate / rate)
self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))

if self.orig_freq != sample_rate:
self.width = -1
self.kernel = UninitializedParameter(device=None, dtype=None)

def initialize_parameters(self, input):
if self.has_uninitialized_params():
if self.orig_freq != self.sample_rate:
with torch.no_grad():
kernel, self.width = _get_sinc_resample_kernel(
self.orig_freq,
self.sample_rate,
self.gcd,
dtype=input.dtype,
device=input.device,
)
self.kernel.materialize(kernel.shape)
self.kernel.copy_(kernel)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -1566,10 +1595,10 @@ def forward(self, waveform: Tensor) -> Tensor:
Returns:
Tensor: The pitch-shifted audio of shape `(..., time)`.
"""
shape = waveform.size()

return F.pitch_shift(
waveform_stretch = _stretch_waveform(
waveform,
self.sample_rate,
self.n_steps,
self.bins_per_octave,
self.n_fft,
Expand All @@ -1578,6 +1607,23 @@ def forward(self, waveform: Tensor) -> Tensor:
self.window,
)

if self.orig_freq != self.sample_rate:
waveform_shift = _apply_sinc_resample_kernel(
waveform_stretch,
self.orig_freq,
self.sample_rate,
self.gcd,
self.kernel,
self.width,
)
else:
waveform_shift = waveform_stretch

return _fix_waveform_shape(
waveform_shift,
shape,
)


class RNNTLoss(torch.nn.Module):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
Expand Down