Skip to content

Commit

Permalink
Changed scaling factor so RMS doesn't need to = 0, rather just be low…
Browse files Browse the repository at this point in the history
…er than the error threshold to replace with min value.
  • Loading branch information
crlandsc committed May 22, 2024
1 parent 130f86e commit 5ad4abf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_forward_silence(self):

# Generate random inputs (scale between -1 and 1)
audio_lengths_samples = int(audio_length * sample_rate)
unprocessed_audio = torch.zeros(batch, audio_channels, audio_lengths_samples)
unprocessed_audio = torch.rand(batch, audio_channels, audio_lengths_samples) * convert_decibels_to_amplitude_ratio(-75)
processed_audio = torch.rand(batch, audio_stems, audio_channels, audio_lengths_samples) * convert_decibels_to_amplitude_ratio(-60)
target_audio = torch.zeros(batch, audio_stems, audio_channels, audio_lengths_samples)

Expand Down
4 changes: 2 additions & 2 deletions torch_log_wmse_audio_quality/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def _calculate_log_wmse(
Tensor: The logWMSE between the processed audio and target audio.
"""

# Add EPS if input_rms is 0 (silence) to avoid NaNs
if input_rms.sum() == 0:
# Add EPS if input_rms is 0 (silence), or close to it, to avoid NaNs
if input_rms.sum() < ERROR_TOLERANCE_THRESHOLD:
input_rms = torch.ones_like(input_rms) * ERROR_TOLERANCE_THRESHOLD

# Calculate the scaling factor based on the input RMS
Expand Down

0 comments on commit 5ad4abf

Please sign in to comment.