Skip to content

Commit

Permalink
refactor test
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 17, 2022
1 parent e0ff8bc commit e953ab3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def test_filtfilt(self):

def test_compute_rtf_evd(self):
    """Run ``F.compute_rtf_evd`` through the shared batch-consistency check.

    The input PSD is built as the outer product of a random complex
    spectrum with its own conjugate, giving a tensor of shape
    ``(batch_size, n_fft_bin, channel, channel)``.
    """
    # Fixed seed so the random spectrum is reproducible between runs.
    torch.random.manual_seed(2434)
    batch_size = 2
    channel = 4
    n_fft_bin = 5
    spectrum = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
    # psd[..., c, d] = spectrum[..., c] * conj(spectrum[..., d])
    psd = torch.einsum("...c,...d->...cd", spectrum, spectrum.conj())
    # The batched-vs-itemwise comparison is delegated to the suite's helper
    # rather than hand-rolled here with a hard-coded batch size.
    self.assert_batch_consistency(F.compute_rtf_evd, (psd,))
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,11 @@ def test_phase_vocoder(self):
self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance))

def test_compute_rtf_evd(self):
    """Run ``F.compute_rtf_evd`` through ``_assert_consistency_complex``.

    The input is a random complex tensor of shape
    ``(batch_size, n_fft_bin, channel, channel)`` whose dtype is taken
    from ``self.complex_dtype`` instead of being hard-coded.
    """
    batch_size = 2
    channel = 4
    n_fft_bin = 129
    tensor = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
    # Arguments are passed as a tuple, matching the other callers of
    # ``_assert_consistency_complex`` in this suite.
    self._assert_consistency_complex(F.compute_rtf_evd, (tensor,))


class FunctionalFloat32Only(TestBaseMixin):
Expand Down
10 changes: 6 additions & 4 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,13 +1636,15 @@ def rnnt_loss(

def compute_rtf_evd(psd_s: Tensor) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.

    The RTF is taken to be the eigenvector of ``psd_s`` associated with its
    largest eigenvalue.

    Args:
        psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor of dimension `(..., freq, channel, channel)`

    Returns:
        Tensor: The estimated complex-valued RTF of target speech.
            Tensor of dimension `(..., freq, channel)`
    """
    # A single decomposition is enough; the eigenvalues themselves are not needed.
    _, v = torch.linalg.eigh(psd_s)  # v is sorted along with eigenvalues in ascending order
    rtf = v[..., -1]  # choose the eigenvector with max eigenvalue
    return rtf

0 comments on commit e953ab3

Please sign in to comment.