diff --git a/auraloss/freq.py b/auraloss/freq.py
index a5efe70..8fc8e1e 100644
--- a/auraloss/freq.py
+++ b/auraloss/freq.py
@@ -2,7 +2,7 @@
 import numpy as np
 from typing import List, Any
 
-from .utils import apply_reduction
+from .utils import apply_reduction, FIRSequential
 from .perceptual import SumAndDifference, FIRFilter
 
 
@@ -84,6 +84,7 @@ class STFTLoss(torch.nn.Module):
             Default: 'mean'
         mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms.
         device (str, optional): Place the filterbanks on specified device. Default: None
+        prefilter (FIRFilter, optional): Custom prefilter, constructed with auraloss.perceptual.FIRFilter, applied to input and target before computing the loss. Default: None
 
     Returns:
         loss:
@@ -112,6 +113,7 @@ def __init__(
         reduction: str = "mean",
         mag_distance: str = "L1",
         device: Any = None,
+        prefilter: FIRFilter = None,
     ):
         super().__init__()
         self.fft_size = fft_size
@@ -132,6 +134,7 @@ def __init__(
         self.reduction = reduction
         self.mag_distance = mag_distance
         self.device = device
+        self.prefilter = prefilter
 
         self.spectralconv = SpectralConvergenceLoss()
         self.logstft = STFTMagnitudeLoss(
@@ -181,7 +184,10 @@ def __init__(
                 raise ValueError(
                     f"`sample_rate` must be supplied when `perceptual_weighting = True`."
                 )
-            self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
+            if self.prefilter is None:
+                self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
+            else:
+                self.prefilter = FIRSequential(FIRFilter(filter_type="aw", fs=sample_rate), self.prefilter)
 
     def stft(self, x):
         """Perform STFT.
@@ -209,7 +215,7 @@ def stft(self, x):
     def forward(self, input: torch.Tensor, target: torch.Tensor):
         bs, chs, seq_len = input.size()
 
-        if self.perceptual_weighting:  # apply optional A-weighting via FIR filter
+        if self.prefilter is not None:  # apply optional prefilter
             # since FIRFilter only support mono audio we will move channels to batch dim
             input = input.view(bs * chs, 1, -1)
             target = target.view(bs * chs, 1, -1)
diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py
index 1cedeb3..b6bbae2 100644
--- a/auraloss/perceptual.py
+++ b/auraloss/perceptual.py
@@ -55,7 +55,7 @@ class FIRFilter(torch.nn.Module):
         a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates.
""" - def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False): + def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_order=2, butter_freq=(250, 5000), butter_filter_type="bandpass", plot=False): """Initilize FIR pre-emphasis filtering module.""" super(FIRFilter, self).__init__() self.filter_type = filter_type @@ -63,6 +63,9 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) self.fs = fs self.ntaps = ntaps self.plot = plot + self.butter_order = butter_order + self.butter_freq = butter_freq + self.butter_filter_type = butter_filter_type import scipy.signal @@ -113,6 +116,27 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) if plot: from .plotting import compare_filters compare_filters(b, a, taps, fs=fs) + elif filter_type == "butter": + # Define digital butter filter + b, a = scipy.signal.butter(self.butter_order, self.butter_freq, self.butter_filter_type, analog=False, output="ba", fs=fs) + + # compute the digital filter frequency response + w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) + + # then we fit to 101 tap FIR filter with least squares + taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) + + # now implement this digital FIR filter as a Conv1d layer + self.fir = torch.nn.Conv1d( + 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 + ) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) + + if plot: + from .plotting import compare_filters + compare_filters(b, a, taps, fs=fs) + def forward(self, input, target): """Calculate forward propagation. diff --git a/auraloss/utils.py b/auraloss/utils.py index 3b36c69..b951f73 100644 --- a/auraloss/utils.py +++ b/auraloss/utils.py @@ -8,3 +8,9 @@ def apply_reduction(losses, reduction="none"): elif reduction == "sum": losses = losses.sum() return losses + +class FIRSequential(torch.nn.Sequential): + def forward(self, *inputs): + for module in self._modules.values(): + inputs = module(*inputs) + return inputs \ No newline at end of file