-
Notifications
You must be signed in to change notification settings - Fork 2
/
Losses.py
121 lines (93 loc) · 3.51 KB
/
Losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import torch.nn.functional as F
from Utils import *
def mae_loss(y_pred, y_true):
    """Mean absolute error (L1 loss) between a prediction and its target.

    Args:
        y_pred: Predicted tensor.
        y_true: Target tensor; the difference ``y_true - y_pred`` follows
            the usual tensor broadcasting rules.

    Returns:
        Scalar tensor holding the mean of ``|y_true - y_pred|`` over all
        elements.
    """
    abs_error = (y_true - y_pred).abs()
    return abs_error.mean()
def mse_loss(y_pred, y_true):
    """Mean squared error (L2 loss) between a prediction and its target.

    Args:
        y_pred: Predicted tensor.
        y_true: Target tensor; the difference ``y_true - y_pred`` follows
            the usual tensor broadcasting rules.

    Returns:
        Scalar tensor holding the mean of ``(y_true - y_pred) ** 2`` over
        all elements.
    """
    squared_error = (y_true - y_pred).square()
    return squared_error.mean()
def raw2spec_mae_loss(y_pred, y_true, stft):
    """L1 loss between spectrograms derived from the raw inputs.

    Both inputs are passed through ``stft_mag_transform`` and the mean
    absolute difference of the two results is returned.

    Args:
        y_pred: Predicted tensor (presumably a raw waveform - confirm
            against ``stft_mag_transform``).
        y_true: Target tensor of the same kind as ``y_pred``.
        stft: Transform object; it is moved to ``y_pred``'s device before
            being handed to ``stft_mag_transform``.

    Returns:
        Mean absolute difference between the two transformed inputs.
    """
    transform = stft.to(y_pred.device)
    spec_pred = stft_mag_transform(y_pred, transform)
    spec_true = stft_mag_transform(y_true, transform)
    return (spec_true - spec_pred).abs().mean()
def raw2spec_mse_loss(y_pred, y_true, stft):
    """L2 loss between spectrograms derived from the raw inputs.

    Both inputs are passed through ``stft_mag_transform`` and the mean
    squared difference of the two results is returned.

    Args:
        y_pred: Predicted tensor (presumably a raw waveform - confirm
            against ``stft_mag_transform``).
        y_true: Target tensor of the same kind as ``y_pred``.
        stft: Transform object; it is moved to ``y_pred``'s device before
            being handed to ``stft_mag_transform``.

    Returns:
        Mean squared difference between the two transformed inputs.
    """
    transform = stft.to(y_pred.device)
    spec_pred = stft_mag_transform(y_pred, transform)
    spec_true = stft_mag_transform(y_true, transform)
    return (spec_true - spec_pred).square().mean()
def fnorm(matrix):
    """Frobenius norm of a tensor.

    Args:
        matrix: Input tensor of any shape; every element contributes.

    Returns:
        Scalar tensor: the square root of the sum of squared absolute
        values of all elements.
    """
    squared_magnitudes = matrix.abs().pow(2)
    return squared_magnitudes.sum().sqrt()
def spectral_convergence_loss(y_pred, y_true):
    """Spectral convergence loss between two spectrograms.

    Args:
        y_pred: Predicted spectrogram tensor.
        y_true: Reference spectrogram tensor.

    Returns:
        ``fnorm(y_true - y_pred) / fnorm(y_true)``, i.e. the Frobenius norm
        of the error relative to the Frobenius norm of the reference.
    """
    error_norm = fnorm(y_true - y_pred)
    reference_norm = fnorm(y_true)
    return error_norm / reference_norm
def raw2spec_spectral_convergence_loss(y_pred, y_true, stft):
    """Spectral convergence loss on spectrograms derived from the raw inputs.

    Both inputs are passed through ``stft_mag_transform``; the loss is the
    Frobenius norm of the spectrogram error divided by the Frobenius norm of
    the reference spectrogram.

    Args:
        y_pred: Predicted tensor (presumably a raw waveform - confirm
            against ``stft_mag_transform``).
        y_true: Target tensor of the same kind as ``y_pred``.
        stft: Transform object exposing ``.to(device)`` and a ``dB`` flag;
            it is moved to ``y_pred``'s device before use.

    Returns:
        Scalar tensor with the spectral convergence loss.

    Raises:
        AssertionError: If ``stft.dB`` is set, because this loss is meant
            for linearly scaled spectrograms.
    """
    stft = stft.to(y_pred.device)
    # Raise explicitly instead of ``assert cond, print(...)``: that form used
    # print()'s return value (None) as the error message and was removed
    # entirely when Python runs with -O.
    if stft.dB:
        raise AssertionError('Spectral convergence uses linearly scaled spectrograms')
    y_pred_2spec = stft_mag_transform(y_pred, stft)
    y_true_2spec = stft_mag_transform(y_true, stft)
    return fnorm(y_true_2spec - y_pred_2spec) / fnorm(y_true_2spec)
def neg_si_sdr(y_pred, y_true, zero_mean=False, epsilon=1e-10):
    """Negative scale-invariant signal-to-distortion ratio (SI-SDR) in dB.

    All reductions run over dimension 2, so the inputs need at least three
    dimensions (presumably ``(batch, channel, time)`` - confirm with callers).

    Args:
        y_pred: Estimated signal tensor.
        y_true: Reference signal tensor.
        zero_mean: When true, subtract each signal's mean along dimension 2
            before computing the ratio.
        epsilon: Small constant added to the reference energy, to the noise
            energy and inside the logarithm to avoid division by zero and
            ``log10(0)``.

    Returns:
        Scalar tensor: minus the mean SI-SDR (in dB) over the remaining
        dimensions, so that lower is better.
    """
    if zero_mean:
        y_true = y_true - y_true.mean(dim=2, keepdim=True)
        y_pred = y_pred - y_pred.mean(dim=2, keepdim=True)

    # Project the estimate onto the reference to get the scaled target.
    inner_product = (y_pred * y_true).sum(dim=2, keepdim=True)
    reference_energy = y_true.pow(2).sum(dim=2, keepdim=True) + epsilon
    target = inner_product * y_true / reference_energy

    # Whatever the projection does not explain is treated as noise.
    residual = y_pred - target

    target_power = target.pow(2).sum(dim=2)
    residual_power = residual.pow(2).sum(dim=2) + epsilon
    sdr_db = 10 * torch.log10(target_power / residual_power + epsilon)
    return -sdr_db.mean()
def total_loss(y_pred, y_true, stft):
    """Weighted sum of a waveform L1, a spectrogram L1 and a spectral convergence loss.

    The three terms use fixed weights ``1``, ``1e-1`` and ``2e-2``
    respectively. Spectrograms are obtained via ``stft_mag_transform``.

    Args:
        y_pred: Predicted tensor (presumably a raw waveform - confirm
            against ``stft_mag_transform``).
        y_true: Target tensor of the same kind as ``y_pred``.
        stft: Transform object exposing ``.to(device)`` and a ``dB`` flag;
            it is moved to ``y_pred``'s device before use.

    Returns:
        Scalar tensor with the combined loss.

    Raises:
        AssertionError: If ``stft.dB`` is set, because the spectral
            convergence term is meant for linearly scaled spectrograms.
    """
    stft = stft.to(y_pred.device)
    # Raise explicitly instead of ``assert cond, print(...)``: that form used
    # print()'s return value (None) as the error message and was removed
    # entirely when Python runs with -O.
    if stft.dB:
        raise AssertionError('Spectral convergence uses linearly scaled spectrograms')
    y_pred_2spec = stft_mag_transform(y_pred, stft)
    y_true_2spec = stft_mag_transform(y_true, stft)

    w1 = 1
    w2 = 1e-1
    w3 = 2e-2

    # l1: L1 on the raw inputs, l2: L1 on spectrograms,
    # l3: spectral convergence on spectrograms.
    l1 = torch.mean(torch.abs(y_true - y_pred))
    l2 = torch.mean(torch.abs(y_true_2spec - y_pred_2spec))
    l3 = fnorm(y_true_2spec - y_pred_2spec) / fnorm(y_true_2spec)
    return w1 * l1 + w2 * l2 + w3 * l3
def perceptual_loss(y_pred, y_true, stft, weights=(1, 1e-1, 2e-2)):
    """Weighted sum of a waveform L1, a spectrogram L1 and a spectral convergence loss.

    Spectrograms are obtained via ``stft_mag_transform``.

    Args:
        y_pred: Predicted tensor (presumably a raw waveform - confirm
            against ``stft_mag_transform``).
        y_true: Target tensor of the same kind as ``y_pred``.
        stft: Transform object exposing ``.to(device)`` and a ``dB`` flag;
            it is moved to ``y_pred``'s device before use.
        weights: Three-element sequence ``(w1, w2, w3)`` weighting, in
            order, the L1 loss on the raw inputs, the L1 loss on the
            spectrograms and the spectral convergence loss.

    Returns:
        Scalar tensor with the combined loss.

    Raises:
        AssertionError: If ``stft.dB`` is set, because the spectral
            convergence term is meant for linearly scaled spectrograms.
        ValueError: If ``weights`` does not unpack into exactly three values.
    """
    stft = stft.to(y_pred.device)
    # Raise explicitly instead of ``assert cond, print(...)``: that form used
    # print()'s return value (None) as the error message and was removed
    # entirely when Python runs with -O.
    if stft.dB:
        raise AssertionError('Spectral convergence uses linearly scaled spectrograms')
    y_pred_2spec = stft_mag_transform(y_pred, stft)
    y_true_2spec = stft_mag_transform(y_true, stft)

    w1, w2, w3 = weights

    l1 = torch.mean(torch.abs(y_true - y_pred))
    l2 = torch.mean(torch.abs(y_true_2spec - y_pred_2spec))
    l3 = fnorm(y_true_2spec - y_pred_2spec) / fnorm(y_true_2spec)
    return w1 * l1 + w2 * l2 + w3 * l3
@vae_recon_wrapper
def vae_perceptual_loss(*args, stft):
    """Perceptual loss routed through ``vae_recon_wrapper``.

    Forwards all positional arguments and the keyword-only ``stft`` to
    ``perceptual_loss``. How the decorator maps its own inputs onto these
    arguments is defined in ``Utils`` and not visible here - presumably it
    extracts the reconstruction and target from the VAE output; confirm there.
    """
    recon_loss = perceptual_loss(*args, stft=stft)
    return recon_loss
def kld_loss(mu, logvar):
    """KL divergence between ``N(mu, exp(logvar))`` and a standard normal.

    Args:
        mu: Tensor of means.
        logvar: Tensor of log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))``
        summed over dimension 1 and averaged over the remaining dimensions.
    """
    elementwise = 1 + logvar - mu.pow(2) - torch.exp(logvar)
    per_sample = -0.5 * elementwise.sum(dim=1)
    return per_sample.mean()
@vae_kld_wrapper
def vae_kld_loss(*args):
    """KL-divergence loss routed through ``vae_kld_wrapper``.

    Forwards all positional arguments to ``kld_loss``. How the decorator
    maps its own inputs onto these arguments is defined in ``Utils`` and not
    visible here - presumably it extracts ``mu`` and ``logvar`` from the VAE
    output; confirm there.
    """
    divergence = kld_loss(*args)
    return divergence
def latent_loss(x):
    """Identity pass-through: return ``x`` unchanged.

    Lets a loss value that has already been computed elsewhere be used
    through the same function-style interface as the other losses.
    """
    return x
@vq_vae_loss_wrapper
def vq_vae_latent_loss(*args):
    """Latent loss routed through ``vq_vae_loss_wrapper``.

    Forwards all positional arguments to ``latent_loss``, which returns its
    single argument unchanged. How the decorator maps its own inputs onto
    that argument is defined in ``Utils`` and not visible here - presumably
    it picks the latent loss out of the VQ-VAE output; confirm there.
    """
    passthrough = latent_loss(*args)
    return passthrough