Skip to content

Commit

Permalink
move numpy method to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 25, 2022
1 parent 24bc248 commit 6f07572
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions test/torchaudio_unittest/common_utils/beamform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,9 @@ def mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel, diag_eps=1e-7, eps=1e-
scale = np.einsum("...c,...c->...", rtf.conj(), reference_channel[..., None, :])
beamform_weights = beamform_weights * scale[..., None]
return beamform_weights


def rtf_evd_numpy(psd):
_, v = np.linalg.eigh(psd)
rtf = v[..., -1]
return rtf
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def test_rtf_evd(self):
channel = 4
specgram = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
psd = np.einsum("fc,fd->fcd", specgram.conj(), specgram)
rtf = self._rtf_evd_numpy(psd)
rtf = beamform_utils.rtf_evd_numpy(psd)
rtf_audio = F.rtf_evd(torch.tensor(psd, dtype=self.complex_dtype, device=self.device))
self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

Expand Down

0 comments on commit 6f07572

Please sign in to comment.