Skip to content

Commit

Permalink
add compute_rtf_evd to functional
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 17, 2022
1 parent 9cf59e7 commit e0ff8bc
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ treble_biquad

.. autofunction:: spectral_centroid

compute_rtf_evd
---------------

.. autofunction:: compute_rtf_evd

:hidden:`Loss`
~~~~~~~~~~~~~~

Expand Down
9 changes: 9 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,12 @@ def test_filtfilt(self):
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)
self.assert_batch_consistency(F.filtfilt, inputs=(x, a, b))

def test_compute_rtf_evd(self):
    """Batched ``F.compute_rtf_evd`` must equal stacking its per-item results."""
    torch.random.manual_seed(2434)
    n_channel = 4
    # Random complex spectrum of shape (batch=2, freq=5, channel).
    spec = torch.rand(2, 5, n_channel, dtype=torch.cfloat)
    # Outer product over the channel axis gives a (2, 5, channel, channel) PSD.
    psd = torch.einsum("...c,...d->...cd", spec, spec.conj())
    # Run once on the whole batch, then once per batch item for comparison.
    batched = F.compute_rtf_evd(psd)
    per_item = torch.stack([F.compute_rtf_evd(item) for item in psd])
    self.assertEqual(batched, per_item)
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,10 @@ def test_phase_vocoder(self):
)[..., None]
self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance))

def test_compute_rtf_evd(self):
    """Run the complex-input consistency check on ``F.compute_rtf_evd``."""
    # Complex input laid out as (freq, channel, channel).
    tensor = torch.rand(129, 4, 4, dtype=torch.cfloat)
    # The helper is given the function's arguments as a tuple (see the
    # phase_vocoder test, which passes ``(tensor, rate, phase_advance)``),
    # so the single argument is wrapped in a 1-tuple rather than passed bare.
    self._assert_consistency_complex(F.compute_rtf_evd, (tensor,))


class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
edit_distance,
pitch_shift,
rnnt_loss,
compute_rtf_evd,
)

__all__ = [
Expand Down Expand Up @@ -94,4 +95,5 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_rtf_evd",
]
15 changes: 15 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_rtf_evd",
]


Expand Down Expand Up @@ -1631,3 +1632,17 @@ def rnnt_loss(
return costs.sum()

return costs


def compute_rtf_evd(psd_s: Tensor) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.

    The RTF is taken to be the eigenvector of ``psd_s`` associated with the
    largest eigenvalue. Eigenvectors are only defined up to a unit-modulus
    complex factor, so the result is determined up to a global phase.

    Args:
        psd_s (Tensor): The complex-valued covariance matrix of target speech.
            Tensor of dimension `(..., freq, channel, channel)`

    Returns:
        Tensor: The estimated complex-valued RTF of target speech.
            Tensor of dimension `(..., freq, channel)`
    """
    # ``torch.linalg.eigh`` returns eigenvalues in ascending order, with the
    # matching eigenvectors stored as columns; the eigenvalues are not needed.
    _, eigenvectors = torch.linalg.eigh(psd_s)  # (..., freq, channel, channel)
    # The last column therefore belongs to the largest eigenvalue.
    return eigenvectors[..., -1]

0 comments on commit e0ff8bc

Please sign in to comment.