Skip to content

Commit

Permalink
Adopt lazy protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored and Sean Kim committed Jun 10, 2022
1 parent 0a8c182 commit cf7c53a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def test_PitchShift(self):
sample_rate = 8000
n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
pitch_shift = T.PitchShift(sample_rate=sample_rate, n_steps=n_steps)
# dry-run for initializing parameters
pitch_shift(waveform)
self._assert_consistency(pitch_shift, waveform)

def test_PSD(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
Expand Down
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_test_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torchaudio.transforms as T
from torchaudio.functional.functional import _get_sinc_resample_kernel
from parameterized import param, parameterized
from torchaudio_unittest.common_utils import (
get_spectrogram,
Expand Down Expand Up @@ -148,3 +149,21 @@ def test_mvdr(self, dtype):
mask_n = torch.rand(specgram.shape[-2:])
specgram_enhanced = transform(specgram, mask_s, mask_n)
assert specgram_enhanced.dtype == dtype

def test_pitch_shift_resample_kernel(self):
"""The resampling kernel in PitchShift is identical to what helper function generates.
There should be no numerical difference caused by dtype conversion.
"""
sample_rate = 8000
trans = T.PitchShift(sample_rate=sample_rate, n_steps=4)
trans.to(self.dtype).to(self.device)
# dry run to initialize the kernel
trans(torch.randn(2, 8000, dtype=self.dtype, device=self.device))

expected, _ = _get_sinc_resample_kernel(
trans.orig_freq,
sample_rate,
trans.gcd,
device=self.device,
dtype=self.dtype)
self.assertEqual(trans.kernel, expected)
32 changes: 24 additions & 8 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import torch
from torch import Tensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedParameter

from torchaudio import functional as F
from torchaudio.functional.functional import (
_apply_sinc_resample_kernel,
Expand Down Expand Up @@ -1513,7 +1516,7 @@ def forward(self, waveform: Tensor) -> Tensor:
)


class PitchShift(torch.nn.Module):
class PitchShift(LazyModuleMixin, torch.nn.Module):
r"""Shift the pitch of a waveform by ``n_steps`` steps.
.. devices:: CPU CUDA
Expand All @@ -1539,6 +1542,9 @@ class PitchShift(torch.nn.Module):
"""
__constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]

kernel: UninitializedParameter
width: int

def __init__(
self,
sample_rate: int,
Expand All @@ -1550,7 +1556,7 @@ def __init__(
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None,
) -> None:
super(PitchShift, self).__init__()
super().__init__()
self.n_steps = n_steps
self.bins_per_octave = bins_per_octave
self.sample_rate = sample_rate
Expand All @@ -1564,12 +1570,22 @@ def __init__(
self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))

if self.orig_freq != sample_rate:
kernel, self.width = _get_sinc_resample_kernel(
self.orig_freq,
sample_rate,
self.gcd,
)
self.register_buffer("kernel", kernel)
self.width = -1
self.kernel = UninitializedParameter(device=None, dtype=None)

def initialize_parameters(self, input):
if self.has_uninitialized_params():
if self.orig_freq != self.sample_rate:
with torch.no_grad():
kernel, self.width = _get_sinc_resample_kernel(
self.orig_freq,
self.sample_rate,
self.gcd,
dtype=input.dtype,
device=input.device,
)
self.kernel.materialize(kernel.shape)
self.kernel.copy_(kernel)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand Down

0 comments on commit cf7c53a

Please sign in to comment.