Skip to content

Commit

Permalink
Update perceptual.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sai-soum authored Feb 8, 2024
1 parent 01de139 commit a4a021f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions auraloss/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False)
if ntaps % 2 == 0:
raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")

if filter_type == "lp":
if filter_type == "hp":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor([1, coef, 0]).view(1, 1, -1)
elif filter_type == "hp":
self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
elif filter_type == "lp":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
self.fir.weight.data = torch.tensor([1, coef, 0]).view(1, 1, -1)
elif filter_type == "fd":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
Expand Down

0 comments on commit a4a021f

Please sign in to comment.