From e953ab3ba04fd8a6a619e01dd05c58544c429099 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Thu, 17 Feb 2022 16:42:06 +0000 Subject: [PATCH] refactor test --- .../functional/batch_consistency_test.py | 8 ++++---- .../functional/torchscript_consistency_impl.py | 7 +++++-- torchaudio/functional/functional.py | 10 ++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py index 31589524f4c..a87b739fe3a 100644 --- a/test/torchaudio_unittest/functional/batch_consistency_test.py +++ b/test/torchaudio_unittest/functional/batch_consistency_test.py @@ -297,9 +297,9 @@ def test_filtfilt(self): def test_compute_rtf_evd(self): torch.random.manual_seed(2434) + batch_size = 2 channel = 4 - spectrum = torch.rand(2, 5, channel, dtype=torch.cfloat) + n_fft_bin = 5 + spectrum = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat) psd = torch.einsum("...c,...d->...cd", spectrum, spectrum.conj()) - batchwise_output = F.compute_rtf_evd(psd) - itemwise_output = torch.stack([F.compute_rtf_evd(psd[i]) for i in range(2)]) - self.assertEqual(batchwise_output, itemwise_output) + self.assert_batch_consistency(F.compute_rtf_evd, (psd,)) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 15bb38aa326..77183395fa5 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -618,8 +618,11 @@ def test_phase_vocoder(self): self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance)) def test_compute_rtf_evd(self): - tensor = torch.rand(129, 4, 4, dtype=torch.cfloat) - self._assert_consistency_complex(F.compute_rtf_evd, tensor) + batch_size = 2 + channel = 4 + n_fft_bin = 129 + tensor = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype) + self._assert_consistency_complex(F.compute_rtf_evd, (tensor,)) class FunctionalFloat32Only(TestBaseMixin): diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 04445453a8c..3e05c6feec5 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1636,13 +1636,15 @@ def rnnt_loss( def compute_rtf_evd(psd_s: Tensor) -> Tensor: r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition. + Args: - psd_s (Tensor): The complex-valued covariance matrix of target speech. + psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech. Tensor of dimension `(..., freq, channel, channel)` + Returns: Tensor: The estimated complex-valued RTF of target speech. - Tensor of dimension `(..., freq, channel)` + Tensor of dimension `(..., freq, channel)` """ - w, v = torch.linalg.eigh(psd_s) # (..., freq, channel, channel) - rtf = v[..., -1] + _, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order + rtf = v[..., -1] # choose the eigenvector with max eigenvalue return rtf