diff --git a/tests/test_metric.py b/tests/test_metric.py index 6d4b036..5f769c3 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -89,7 +89,7 @@ def test_forward_silence(self): # Generate random inputs (scale between -1 and 1) audio_lengths_samples = int(audio_length * sample_rate) - unprocessed_audio = torch.zeros(batch, audio_channels, audio_lengths_samples) + unprocessed_audio = torch.rand(batch, audio_channels, audio_lengths_samples) * convert_decibels_to_amplitude_ratio(-75) processed_audio = torch.rand(batch, audio_stems, audio_channels, audio_lengths_samples) * convert_decibels_to_amplitude_ratio(-60) target_audio = torch.zeros(batch, audio_stems, audio_channels, audio_lengths_samples) diff --git a/torch_log_wmse_audio_quality/metric.py b/torch_log_wmse_audio_quality/metric.py index 41074d0..bcbc665 100644 --- a/torch_log_wmse_audio_quality/metric.py +++ b/torch_log_wmse_audio_quality/metric.py @@ -95,8 +95,8 @@ def _calculate_log_wmse( Tensor: The logWMSE between the processed audio and target audio. """ - # Add EPS if input_rms is 0 (silence) to avoid NaNs - if input_rms.sum() == 0: + # Replace input_rms with ERROR_TOLERANCE_THRESHOLD if its sum is 0 (silence), or close to it, to avoid NaNs + if input_rms.sum() < ERROR_TOLERANCE_THRESHOLD: input_rms = torch.ones_like(input_rms) * ERROR_TOLERANCE_THRESHOLD # Calculate the scaling factor based on the input RMS