Skip to content

Commit

Permalink
add compute_power_spectral_density_matrix to functional
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 15, 2022
1 parent 963905e commit 864f98e
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 1 deletion.
8 changes: 8 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,14 @@ treble_biquad

.. autofunction:: spectral_centroid

:hidden:`Multi-channel`
~~~~~~~~~~~~~~~~~~~~~~~

compute_power_spectral_density_matrix
-------------------------------------

.. autofunction:: compute_power_spectral_density_matrix

:hidden:`Loss`
~~~~~~~~~~~~~~

Expand Down
17 changes: 16 additions & 1 deletion test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_grad(
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
i = i.to(dtype=torch.cdouble if i.is_complex() else self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
Expand Down Expand Up @@ -250,6 +250,21 @@ def test_bandreject_biquad(self, central_freq, Q):
Q = torch.tensor(Q)
self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

@parameterized.expand(
[
(True,),
(False,),
]
)
def test_compute_power_spectral_density_matrix(self, use_mask):
torch.random.manual_seed(2434)
specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
if use_mask:
mask = torch.rand(10, 5)
else:
mask = None
self.assert_grad(F.compute_power_spectral_density_matrix, (specgram, mask))


class AutogradFloat32(TestBaseMixin):
def assert_grad(
Expand Down
9 changes: 9 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,12 @@ def test_filtfilt(self):
itemwise_output = torch.stack([F.filtfilt(x[i], a[i], b[i]) for i in range(self.batch_size)])

self.assertEqual(batchwise_output, itemwise_output)

def test_compute_power_spectral_density_matrix(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=100)
specgram = specgram.view(2, 3, specgram.size(-2), specgram.size(-1))
batchwise_output = F.compute_power_spectral_density_matrix(specgram)
itemwise_output = torch.stack([F.compute_power_spectral_density_matrix(specgram[i]) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,22 @@ def func(tensor):
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency_complex(func, tensor)

def test_compute_power_spectral_density_matrix(self):
def func(tensor):
mask = None
return F.compute_power_spectral_density_matrix(tensor, mask=mask)

tensor = torch.rand(2, 201, 100, dtype=torch.cfloat)
self._assert_consistency_complex(func, tensor)

def test_compute_power_spectral_density_matrix_with_mask(self):
def func(tensor):
mask = torch.rand(201, 100)
return F.compute_power_spectral_density_matrix(tensor, mask=mask)

tensor = torch.rand(2, 201, 100, dtype=torch.cfloat)
self._assert_consistency_complex(func, tensor)


class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
edit_distance,
pitch_shift,
rnnt_loss,
compute_power_spectral_density_matrix,
)

__all__ = [
Expand Down Expand Up @@ -94,4 +95,5 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_power_spectral_density_matrix",
]
36 changes: 36 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_power_spectral_density_matrix",
]


Expand Down Expand Up @@ -1631,3 +1632,38 @@ def rnnt_loss(
return costs.sum()

return costs


def compute_power_spectral_density_matrix(
specgram: Tensor,
mask: Optional[Tensor] = None,
normalize: bool = True,
eps: float = 1e-10,
) -> Tensor:
"""Compute cross-channel power spectral density (PSD) matrix.
Args:
specgram (Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)`
mask (Tensor or None, optional): Real-valued Time-Frequency mask
for normalization. Tensor of dimension `(..., freq, time)`
(Default: ``None``)
normalize (bool, optional): whether normalize the mask along the time dimension.
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``)
Returns:
Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)`
"""
specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
# outer product:
# (..., ch_1, time) x (..., ch_2, time) -> (..., time, ch_1, ch_2)
psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])

if mask is not None:
# Normalized mask along time dimension:
if normalize:
mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)

psd = psd * mask[..., None, None]

psd = psd.sum(dim=-3)
return psd

0 comments on commit 864f98e

Please sign in to comment.