Skip to content

Commit

Permalink
Refactor MVDR module (#2383)
Browse files Browse the repository at this point in the history
Summary:
- Use `apply_beamforming`, `rtf_evd`, `rtf_power`, `mvdr_weights_souden`, `mvdr_weights_rtf` methods under `torchaudio.functional` to replace the class methods.
- Refactor docstrings in `PSD` and `MVDR`.
- Put `_get_mvdr_vector` outside of `MVDR` class as it doesn't call self methods inside.
- Since MVDR uses einsum for matrix operations, packing and unpacking batches are not necessary. It can be tested by the [batch_consistency_test](https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/transforms/batch_consistency_test.py#L202). Removed it from the code.

Pull Request resolved: #2383

Reviewed By: carolineechen, mthrok

Differential Revision: D36338373

Pulled By: nateanl

fbshipit-source-id: a48a6ae2825657e5967a19656245596cdf037c5f
  • Loading branch information
nateanl authored and facebook-github-bot committed May 12, 2022
1 parent 0963968 commit f5036c7
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 229 deletions.
37 changes: 19 additions & 18 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,17 +1748,18 @@ def psd(
.. properties:: Autograd TorchScript
Args:
specgram (Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)`
mask (Tensor or None, optional): Real-valued time-frequency mask
for normalization. Tensor of dimension `(..., freq, time)`
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`.
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: ``None``)
normalize (bool, optional): whether to normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``)
normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
Returns:
Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)`
torch.Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor with dimensions `(..., freq, channel, channel)`
"""
specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
# outer product:
Expand All @@ -1780,14 +1781,14 @@ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> t
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
input (torch.Tensor): Tensor with dimensions `(..., channel, channel)`.
dim1 (int, optional): The first dimension of the diagonal matrix.
(Default: ``-1``)
dim2 (int, optional): The second dimension of the diagonal matrix.
(Default: ``-2``)
Returns:
Tensor: trace of the input Tensor
Tensor: The trace of the input Tensor.
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
Expand All @@ -1799,12 +1800,12 @@ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.T
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
reg (float, optional): regularization factor (Default: 1e-8)
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``)
mat (torch.Tensor): Input matrix with dimensions `(..., channel, channel)`.
reg (float, optional): Regularization factor. (Default: 1e-8)
eps (float, optional): Value to avoid the correlation matrix is all-zero. (Default: ``1e-8``)
Returns:
Tensor: regularized matrix (..., channel, channel)
Tensor: Regularized matrix with dimensions `(..., channel, channel)`.
"""
# Add eps
C = mat.size(-1)
Expand Down
Loading

0 comments on commit f5036c7

Please sign in to comment.