From 92cf006700d6d00a4c31a7f71587bc29a2ff9d48 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Sat, 9 Sep 2023 16:49:39 -0700 Subject: [PATCH 01/10] add butter filter --- auraloss/perceptual.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index 1cedeb3..5aeae3b 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -55,7 +55,7 @@ class FIRFilter(torch.nn.Module): a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates. """ - def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False): + def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_order=2, butter_freq=(250, 5000), butter_filter_type="bandpass", plot=False): """Initilize FIR pre-emphasis filtering module.""" super(FIRFilter, self).__init__() self.filter_type = filter_type @@ -63,6 +63,9 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) self.fs = fs self.ntaps = ntaps self.plot = plot + self.butter_order = butter_order + self.butter_freq = butter_freq + self.butter_filter_type = butter_filter_type import scipy.signal @@ -113,6 +116,30 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) if plot: from .plotting import compare_filters compare_filters(b, a, taps, fs=fs) + elif filter_type == "butter": + # Define butter filter + filts = signal.lti(*signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) + + # convert analog filter to digital filter + b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) + + # compute the digital filter frequency response + w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) + + # then we fit to 101 tap FIR filter with least squares + taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) + + # now implement this digital FIR filter as a Conv1d layer + self.fir = torch.nn.Conv1d( + 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 + ) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) + + if plot: + from .plotting import compare_filters + compare_filters(b, a, taps, fs=fs) + def forward(self, input, target): """Calculate forward propagation. From 2a2d899216bf6b1002f2310c0e46b87df7b35d4d Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Sat, 9 Sep 2023 17:10:09 -0700 Subject: [PATCH 02/10] package --- auraloss/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index 5aeae3b..b30b2ed 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -118,7 +118,7 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_orde compare_filters(b, a, taps, fs=fs) elif filter_type == "butter": # Define butter filter - filts = signal.lti(*signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) + filts = scipy.signal.lti(*signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) # convert analog filter to digital filter b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) From ebd8748c0d21547d51c38c07d1b1dff058521c20 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Sat, 9 Sep 2023 17:13:26 -0700 Subject: [PATCH 03/10] again --- auraloss/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index b30b2ed..e054bac 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -118,7 +118,7 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_orde compare_filters(b, a, taps, fs=fs) elif filter_type == "butter": # Define butter filter - filts = scipy.signal.lti(*signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) + filts = scipy.signal.lti(*scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) # convert analog filter to digital filter b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) From 4bc86c42f1f9f988636f268944360708c54f202e Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Sat, 9 Sep 2023 17:32:34 -0700 Subject: [PATCH 04/10] add sr --- auraloss/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index e054bac..9e5d3fd 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -118,7 +118,7 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_orde compare_filters(b, a, taps, fs=fs) elif filter_type == "butter": # Define butter filter - filts = scipy.signal.lti(*scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) + filts = scipy.signal.lti(*scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, fs=fs, analog=True)) # convert analog filter to digital filter b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) From 68912a3a7d80e33a5afc80c08f69b4caf474b993 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Sat, 9 Sep 2023 17:42:49 -0700 Subject: [PATCH 05/10] trying digital init --- auraloss/perceptual.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index 9e5d3fd..e2e9ab9 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -118,10 +118,10 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_orde compare_filters(b, a, taps, fs=fs) elif filter_type == "butter": # Define butter filter - filts = scipy.signal.lti(*scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, fs=fs, analog=True)) - + # filts = scipy.signal.lti(*scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) + b, a = scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=False, output="ba", fs=fs) # convert analog filter to digital filter - b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) + # b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) # compute the digital filter frequency response w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) From f069d8234b770553149b251f3bd22d66d0144769 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Sat, 9 Sep 2023 17:46:29 -0700 Subject: [PATCH 06/10] cleanup --- auraloss/perceptual.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index e2e9ab9..b6bbae2 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -117,11 +117,8 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_orde from .plotting import compare_filters compare_filters(b, a, taps, fs=fs) elif filter_type == "butter": - # Define butter filter - # filts = scipy.signal.lti(*scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=True)) + # Define digital butter filter b, a = scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=False, output="ba", fs=fs) - # convert analog filter to digital filter - # b, a = scipy.signal.bilinear(filts.num, filts.den, fs=fs) # compute the digital filter frequency response w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) From de822db31671be531e88233b8335519bbc2d5667 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Fri, 20 Oct 2023 11:50:33 -0700 Subject: [PATCH 07/10] adding prefilter input --- auraloss/freq.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/auraloss/freq.py b/auraloss/freq.py index a5efe70..7229413 100644 --- a/auraloss/freq.py +++ b/auraloss/freq.py @@ -84,6 +84,7 @@ class STFTLoss(torch.nn.Module): Default: 'mean' mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms. device (str, optional): Place the filterbanks on specified device. Default: None + prefilter (FIRFilter, optional): apply customizable FIRFilter constructed by auraloss.perceptual.FIRFilter to STFT loss. Default: None Returns: loss: @@ -112,6 +113,7 @@ def __init__( reduction: str = "mean", mag_distance: str = "L1", device: Any = None, + prefilter: FIRFilter = None, ): super().__init__() self.fft_size = fft_size @@ -132,6 +134,7 @@ def __init__( self.reduction = reduction self.mag_distance = mag_distance self.device = device + self.prefilter = prefilter self.spectralconv = SpectralConvergenceLoss() self.logstft = STFTMagnitudeLoss( @@ -181,7 +184,10 @@ def __init__( raise ValueError( f"`sample_rate` must be supplied when `perceptual_weighting = True`." ) - self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) + if self.prefilter is None: + self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) + else: + self.prefilter = torch.nn.Sequential(FIRFilter(filter_type="aw", fs=sample_rate), self.prefilter) def stft(self, x): """Perform STFT. @@ -209,7 +215,7 @@ def stft(self, x): def forward(self, input: torch.Tensor, target: torch.Tensor): bs, chs, seq_len = input.size() - if self.perceptual_weighting: # apply optional A-weighting via FIR filter + if self.prefilter is not None: # apply prefilter # since FIRFilter only support mono audio we will move channels to batch dim input = input.view(bs * chs, 1, -1) target = target.view(bs * chs, 1, -1) From c9b93e5e740dff968c31cdd20b3990f63691293f Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Fri, 20 Oct 2023 12:12:23 -0700 Subject: [PATCH 08/10] add custom sequential class for multiple inputs --- auraloss/freq.py | 4 ++-- auraloss/utils.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/auraloss/freq.py b/auraloss/freq.py index 7229413..8fc8e1e 100644 --- a/auraloss/freq.py +++ b/auraloss/freq.py @@ -2,7 +2,7 @@ import numpy as np from typing import List, Any -from .utils import apply_reduction +from .utils import apply_reduction, FIRSequential from .perceptual import SumAndDifference, FIRFilter @@ -187,7 +187,7 @@ def __init__( if self.prefilter is None: self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) else: - self.prefilter = torch.nn.Sequential(FIRFilter(filter_type="aw", fs=sample_rate), self.prefilter) + self.prefilter = FIRSequential(FIRFilter(filter_type="aw", fs=sample_rate), self.prefilter) def stft(self, x): """Perform STFT. diff --git a/auraloss/utils.py b/auraloss/utils.py index 3b36c69..885375c 100644 --- a/auraloss/utils.py +++ b/auraloss/utils.py @@ -8,3 +8,11 @@ def apply_reduction(losses, reduction="none"): elif reduction == "sum": losses = losses.sum() return losses + +class FIRSequential(torch.nn.Sequential): + def __init__(self): + super().__init__() + def forward(self, *inputs): + for module in self._modules.values: + inputs = module(*inputs) + return inputs \ No newline at end of file From 9be14892c2473724b62230cfe84cd3168e580c37 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Fri, 20 Oct 2023 12:17:56 -0700 Subject: [PATCH 09/10] remove init --- auraloss/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/auraloss/utils.py b/auraloss/utils.py index 885375c..7b0ee6a 100644 --- a/auraloss/utils.py +++ b/auraloss/utils.py @@ -10,8 +10,6 @@ def apply_reduction(losses, reduction="none"): return losses class FIRSequential(torch.nn.Sequential): - def __init__(self): - super().__init__() def forward(self, *inputs): for module in self._modules.values: inputs = module(*inputs) From 2759c6ef1ae90f29ce34d5d90d43fb11490b7ba8 Mon Sep 17 00:00:00 2001 From: Jeff Sontag Date: Fri, 20 Oct 2023 12:21:49 -0700 Subject: [PATCH 10/10] missing parentheses --- auraloss/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auraloss/utils.py b/auraloss/utils.py index 7b0ee6a..b951f73 100644 --- a/auraloss/utils.py +++ b/auraloss/utils.py @@ -11,6 +11,6 @@ def apply_reduction(losses, reduction="none"): class FIRSequential(torch.nn.Sequential): def forward(self, *inputs): - for module in self._modules.values: + for module in self._modules.values(): inputs = module(*inputs) return inputs \ No newline at end of file