Skip to content

Commit

Permalink
rename method to psd
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 17, 2022
1 parent 9f0351e commit fd896a1
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,14 @@ def test_bandreject_biquad(self, central_freq, Q):
(False,),
]
)
def test_compute_power_spectral_density_matrix(self, use_mask):
def test_psd(self, use_mask):
torch.random.manual_seed(2434)
specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
if use_mask:
mask = torch.rand(10, 5)
else:
mask = None
self.assert_grad(F.compute_power_spectral_density_matrix, (specgram, mask))
self.assert_grad(F.psd, (specgram, mask))


class AutogradFloat32(TestBaseMixin):
Expand Down
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_filtfilt(self):
b = torch.rand(self.batch_size, 3)
self.assert_batch_consistency(F.filtfilt, inputs=(x, a, b))

def test_compute_power_spectral_density_matrix(self):
def test_psd_matrix(self):
batch_size = 2
channel = 3
sample_rate = 44100
Expand All @@ -304,9 +304,9 @@ def test_compute_power_spectral_density_matrix(self):
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=batch_size * channel)
specgram = common_utils.get_spectrogram(waveform, n_fft=n_fft, hop_length=100)
specgram = specgram.view(batch_size, channel, n_fft_bin, specgram.size(-1))
self.assert_batch_consistency(F.compute_power_spectral_density_matrix, (specgram,))
self.assert_batch_consistency(F.psd, (specgram,))

def test_compute_power_spectral_density_matrix_with_mask(self):
def test_psd_with_mask(self):
batch_size = 2
channel = 3
sample_rate = 44100
Expand All @@ -316,4 +316,4 @@ def test_compute_power_spectral_density_matrix_with_mask(self):
specgram = common_utils.get_spectrogram(waveform, n_fft=n_fft, hop_length=100)
specgram = specgram.view(batch_size, channel, n_fft_bin, specgram.size(-1))
mask = torch.rand(batch_size, n_fft_bin, specgram.size(-1))
self.assert_batch_consistency(F.compute_power_spectral_density_matrix, (specgram, mask))
self.assert_batch_consistency(F.psd, (specgram, mask))
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,9 @@ def test_phase_vocoder(self):
)[..., None]
self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance))

def test_compute_power_spectral_density_matrix(self):
def test_psd(self):
def func(specgram):
return F.compute_power_spectral_density_matrix(specgram)
return F.psd(specgram)

batch_size = 2
channel = 4
Expand All @@ -628,9 +628,9 @@ def func(specgram):
tensor = torch.rand(batch_size, channel, n_fft_bin, frame, dtype=self.complex_dtype)
self._assert_consistency_complex(func, (tensor,))

def test_compute_power_spectral_density_matrix_with_mask(self):
def test_psd_with_mask(self):
def func(specgram, mask):
return F.compute_power_spectral_density_matrix(specgram, mask)
return F.psd(specgram, mask)

batch_size = 2
channel = 4
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
edit_distance,
pitch_shift,
rnnt_loss,
compute_power_spectral_density_matrix,
psd,
)

__all__ = [
Expand Down Expand Up @@ -95,5 +95,5 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_power_spectral_density_matrix",
"psd",
]
6 changes: 3 additions & 3 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_power_spectral_density_matrix",
"psd",
]


Expand Down Expand Up @@ -1634,7 +1634,7 @@ def rnnt_loss(
return costs


def compute_power_spectral_density_matrix(
def psd(
specgram: Tensor,
mask: Optional[Tensor] = None,
normalize: bool = True,
Expand All @@ -1648,7 +1648,7 @@ def compute_power_spectral_density_matrix(
mask (Tensor or None, optional): Real-valued time-frequency mask
for normalization. Tensor of dimension `(..., freq, time)`
(Default: ``None``)
normalize (bool, optional): whether to normalize the mask along the time dimension.
normalize (bool, optional): whether to normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``)
Returns:
Expand Down

0 comments on commit fd896a1

Please sign in to comment.