Fir filter input #65
base: main
Conversation
Thanks for these changes. Adding custom `FIRFilter` instances to the STFT loss will be great. I left some suggestions in my comments. Overall, let's try to remove the `FIRSequential` and retain the current default behavior by creating two prefilters: one is the default prefilter and the new one is the `custom_prefilter`. Also, to make the `FIRFilter` class cleaner, we will want to move the Butterworth parameter specification outside and perhaps pass in taps. This will need some more thinking. Thanks for your work on this, and feel free to tell me differently on any of my comments.
@@ -181,7 +184,10 @@ def __init__(
            raise ValueError(
                f"`sample_rate` must be supplied when `perceptual_weighting = True`."
            )
        self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
        if self.prefilter is None:
Can we adjust the logic here? My thinking is that if someone sets `perceptual_weighting=True` we should apply the default filter. If someone also specifies `prefilter=FIRFilter()`, then we should also apply this filter. This would create two separate logic branches here.
        if self.prefilter is None:
            self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
        else:
            self.prefilter = FIRSequential(FIRFilter(filter_type="aw", fs=sample_rate), self.prefilter)
Following the above, I think this means we can simplify this by removing the need for `FIRSequential`. To achieve this, we could have two separate prefilter attributes: for example, `self.user_prefilter` for the user-specified prefilter and `self.prefilter` for the perceptual weighting, as it is now. Then, in `forward()` we can check for either and apply it if needed. How does this sound?
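As a rough illustration of that idea, here is a minimal sketch of what the constructor could look like with two separate attributes. The class name `STFTLossSketch` and the attribute `self.user_prefilter` are placeholders for discussion, not existing auraloss API:

```python
import torch
from auraloss.perceptual import FIRFilter

class STFTLossSketch(torch.nn.Module):
    """Hypothetical sketch of the two-prefilter approach discussed above."""

    def __init__(self, perceptual_weighting=False, prefilter=None, sample_rate=None):
        super().__init__()
        self.perceptual_weighting = perceptual_weighting
        # user-supplied custom FIRFilter (or None), kept separate from the default
        self.user_prefilter = prefilter

        if self.perceptual_weighting:
            if sample_rate is None:
                raise ValueError(
                    "`sample_rate` must be supplied when `perceptual_weighting = True`."
                )
            # default A-weighting prefilter, preserving the current behavior
            self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
        else:
            self.prefilter = None
```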
@@ -209,7 +215,7 @@ def stft(self, x):
    def forward(self, input: torch.Tensor, target: torch.Tensor):
        bs, chs, seq_len = input.size()

        if self.perceptual_weighting:  # apply optional A-weighting via FIR filter
        if self.prefilter is not None:  # apply prefilter
Here we can add two checks: one for `if self.prefilter is not None:` and one for `if self.user_prefilter is not None:`.
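Continuing the sketch above, `forward()` could then apply whichever prefilters are set. This is only an illustration and assumes each filter takes and returns an `(input, target)` pair, as the current `FIRFilter` does:

```python
    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # apply the default perceptual-weighting prefilter if enabled
        if self.prefilter is not None:
            input, target = self.prefilter(input, target)
        # apply the user-supplied custom prefilter if one was given
        if self.user_prefilter is not None:
            input, target = self.user_prefilter(input, target)
        # ... remainder of the STFT loss computation goes here ...
```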
@@ -8,3 +8,9 @@ def apply_reduction(losses, reduction="none"):
    elif reduction == "sum":
        losses = losses.sum()
    return losses


class FIRSequential(torch.nn.Sequential):
Following the above, ideally we can completely remove the need for this class.
@@ -55,14 +55,17 @@ class FIRFilter(torch.nn.Module):
        a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates.
    """

    def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
    def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, butter_order=2, butter_freq=(250, 5000), butter_filter_type="bandpass", plot=False):
I think adding the Butterworth parameters to this constructor makes the behavior a bit confusing, since we also have the `filter_type` parameter, which will get ignored. I do not know the best solution yet, but it feels like we should instead just pass `taps` into this constructor. Then we would need to supply some helper functions for the highpass and Butterworth filters that produce a tensor of taps. What do you think?
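One possible shape for such helpers, sketched with hypothetical names (`butter_fir_taps` and `hp_fir_taps` do not exist in auraloss, and the highpass helper is a simplified first-order stand-in rather than the exact current implementation):

```python
import numpy as np
import scipy.signal
import torch

def butter_fir_taps(order=2, freq=(250, 5000), btype="bandpass", fs=44100, ntaps=101):
    """Approximate a Butterworth response with FIR taps by truncating its
    impulse response to `ntaps` samples (hypothetical helper)."""
    b, a = scipy.signal.butter(order, freq, btype=btype, fs=fs)
    impulse = np.zeros(ntaps)
    impulse[0] = 1.0
    taps = scipy.signal.lfilter(b, a, impulse)
    return torch.tensor(taps, dtype=torch.float32)

def hp_fir_taps(coef=0.85):
    """Simple first-order highpass / pre-emphasis taps, y[n] = x[n] - coef * x[n-1]
    (hypothetical helper, a simplified stand-in for the current "hp" filter)."""
    return torch.tensor([1.0, -coef], dtype=torch.float32)

# The constructor could then reduce to something like FIRFilter(taps, fs=44100),
# with callers building the taps via helpers, e.g.:
#     FIRFilter(taps=butter_fir_taps(order=2, freq=(250, 5000)), fs=44100)
#     FIRFilter(taps=hp_fir_taps(coef=0.85), fs=44100)
```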
@@ -132,6 +134,7 @@ def __init__(
        self.reduction = reduction
        self.mag_distance = mag_distance
        self.device = device
        self.prefilter = prefilter
As I suggest below, let's change this to something like `self.user_prefilter` or another name, maybe `self.custom_prefilter`, which is totally separate from `self.prefilter`, which will be used for the default perceptual weighting.
In this PR I add an optional `FIRFilter` input to `STFTLoss`. This filter automatically fills `self.prefilter` when provided, and `self.prefilter` is set to `None` if not. If only the `perceptual_weighting` flag is set, `self.prefilter` is set to an internally constructed `FIRFilter`.
If both an external `FIRFilter` is provided and the `perceptual_weighting` flag is set, an `nn.Sequential` variation (that allows for two inputs) is constructed to run both filters sequentially.
Tested on some audio input, and it appears to be working as expected. Below are some spectrograms of the different variations.
No Filter:
Only Perceptual Weighting:
Only External 4.5k Lowpass Filter:
Both Filters:
Also included are the changes to
`auraloss.perceptual.FIRFilter`
which allow for Butterworth filter construction, and a `FIRSequential` class in `auraloss.utils`
that inherits from `nn.Sequential` and allows for multiple inputs. I haven't tried using this branch in a model yet, but it has worked, returning losses as expected in my testing of just this repo.
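For reference, a minimal sketch of the kind of two-input `nn.Sequential` variant described above (the actual class in this branch may differ in its details):

```python
import torch

class FIRSequential(torch.nn.Sequential):
    """Sequential variant that threads an (input, target) pair through each
    contained filter, so that two FIR filters can be chained (sketch only)."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        for module in self:
            input, target = module(input, target)
        return input, target

# Usage sketch (assumes each filter's forward takes and returns (input, target)):
#     chain = FIRSequential(FIRFilter(filter_type="aw", fs=44100), custom_lowpass)
#     input, target = chain(input, target)
```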