-
Notifications
You must be signed in to change notification settings - Fork 1
/
losses.py
291 lines (244 loc) · 9.71 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# Copyright 2020 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file has been modified from the original
# Lint as: python3
"""Library of loss functions."""
import functools
import crepe
import spectral_ops
import gin
import tensorflow.compat.v2 as tf
tfkl = tf.keras.layers
# ---------------------- Losses ------------------------------------------------
def mean_difference(target, value, loss_type='L1', weights=None):
  """Common loss functions.

  Args:
    target: Target tensor.
    value: Value tensor.
    loss_type: One of 'L1', 'L2', or 'COSINE' (case-insensitive).
    weights: A weighting mask for the per-element differences. Defaults to a
      scalar 1.0, i.e. no weighting.

  Returns:
    The average loss. For 'L1'/'L2' this is the mean of the weighted absolute /
    squared element differences. For 'COSINE' it is the mean of
    `(1 - cosine_similarity) * weights`, where the similarity is taken between
    L2-normalized vectors along the last axis.

  Raises:
    ValueError: If loss_type is not an allowed value.
  """
  weights = 1.0 if weights is None else weights
  loss_type = loss_type.upper()
  if loss_type == 'L1':
    return tf.reduce_mean(tf.abs((target - value) * weights))
  elif loss_type == 'L2':
    return tf.reduce_mean((target - value)**2 * weights)
  elif loss_type == 'COSINE':
    # `tf.losses.cosine_distance` only exists in the TF1 API; under
    # `tensorflow.compat.v2` it is missing, so compute the distance directly.
    target_unit = tf.math.l2_normalize(target, axis=-1)
    value_unit = tf.math.l2_normalize(value, axis=-1)
    # keepdims=True leaves a trailing axis of size 1 so that per-element
    # `weights` (same shape as `target`) still broadcast.
    similarity = tf.reduce_sum(target_unit * value_unit, axis=-1,
                               keepdims=True)
    return tf.reduce_mean((1.0 - similarity) * weights)
  else:
    raise ValueError('Loss type ({}), must be '
                     '"L1", "L2", or "COSINE"'.format(loss_type))
@gin.register
class SpectralLoss(tfkl.Layer):
  """Multi-scale spectrogram loss.

  For each size in `fft_sizes`, spectrograms of the target and generated audio
  are computed with `spectral_ops.compute_mag` and compared with
  `mean_difference`. Each comparison term is only evaluated when its weight is
  positive. An optional loudness term compares `spectral_ops.compute_loudness`
  of the two signals; it does not depend on the FFT size, so it is added once
  rather than once per scale.
  """

  def __init__(self,
               fft_sizes=(8192, 4096, 2048, 1024, 512, 256, 128, 64),
               loss_type='L1',
               mag_weight=1.0,
               delta_time_weight=0.0,
               delta_delta_time_weight=0.0,
               delta_freq_weight=0.0,
               delta_delta_freq_weight=0.0,
               cumsum_freq_weight=0.0,
               logmag_weight=1.0,
               loudness_weight=0.0,
               name='spectral_loss',
               loudness_n_fft=6144):
    """Creates the loss layer.

    NOTE(review): axis 1 / axis 2 of the spectrograms are treated as time /
    frequency, following the weight names; confirm against
    `spectral_ops.compute_mag`.

    Args:
      fft_sizes: FFT sizes; one spectrogram is computed per size.
      loss_type: Distance passed to `mean_difference` ('L1', 'L2', 'COSINE').
      mag_weight: Weight of the difference between the spectrograms.
      delta_time_weight: Weight of the first difference along axis 1.
      delta_delta_time_weight: Weight of the second difference along axis 1.
      delta_freq_weight: Weight of the first difference along axis 2.
      delta_delta_freq_weight: Weight of the second difference along axis 2.
      cumsum_freq_weight: Weight of the cumulative sum along axis 2.
      logmag_weight: Weight of the difference of `spectral_ops.safe_log` of
        the spectrograms.
      loudness_weight: Weight of the loudness difference.
      name: Layer name.
      loudness_n_fft: `n_fft` passed to `spectral_ops.compute_loudness`.
    """
    super().__init__(name=name)
    self.fft_sizes = fft_sizes
    self.loss_type = loss_type
    self.mag_weight = mag_weight
    self.delta_time_weight = delta_time_weight
    self.delta_delta_time_weight = delta_delta_time_weight
    self.delta_freq_weight = delta_freq_weight
    self.delta_delta_freq_weight = delta_delta_freq_weight
    self.cumsum_freq_weight = cumsum_freq_weight
    self.logmag_weight = logmag_weight
    self.loudness_weight = loudness_weight
    self.loudness_n_fft = loudness_n_fft
    # One spectrogram function per FFT size.
    self.spectrogram_ops = [
        functools.partial(spectral_ops.compute_mag, size=size)
        for size in self.fft_sizes
    ]

  def call(self, target_audio, audio):
    """Returns the weighted sum of all enabled loss terms.

    Args:
      target_audio: Reference audio.
      audio: Generated audio to compare against `target_audio`.

    Returns:
      The scalar total loss (0.0 when every weight is non-positive).
    """
    loss = 0.0
    diff = spectral_ops.diff

    # (weight, transform) pairs applied identically to both spectrograms, in
    # the order the terms are accumulated.
    terms = [
        (self.mag_weight, lambda mag: mag),
        (self.delta_time_weight, lambda mag: diff(mag, axis=1)),
        (self.delta_delta_time_weight,
         lambda mag: diff(diff(mag, axis=1), axis=1)),
        (self.delta_freq_weight, lambda mag: diff(mag, axis=2)),
        (self.delta_delta_freq_weight,
         lambda mag: diff(diff(mag, axis=2), axis=2)),
        # TODO(kyriacos) normalize cumulative spectrogram
        (self.cumsum_freq_weight, lambda mag: tf.math.cumsum(mag, axis=2)),
        # Log-magnitude loss, reusing the spectrogram.
        (self.logmag_weight, spectral_ops.safe_log),
    ]

    # Compute loss for each fft size.
    for loss_op in self.spectrogram_ops:
      target_mag = loss_op(target_audio)
      value_mag = loss_op(audio)
      for weight, transform in terms:
        if weight > 0:
          loss += weight * mean_difference(
              transform(target_mag), transform(value_mag), self.loss_type)

    # Loudness is independent of the FFT size: compute it once, outside the
    # per-scale loop, so it is neither recomputed nor counted once per scale.
    if self.loudness_weight > 0:
      target = spectral_ops.compute_loudness(
          target_audio, n_fft=self.loudness_n_fft, use_tf=True)
      value = spectral_ops.compute_loudness(
          audio, n_fft=self.loudness_n_fft, use_tf=True)
      loss += self.loudness_weight * mean_difference(target, value,
                                                     self.loss_type)
    return loss
@gin.register
class EmbeddingLoss(tfkl.Layer):
  """Embedding loss for a given pretrained model.

  Calculates the embedding loss given a pretrained model.
  You may also define a trivial pretrained model to apply any function that
  computes the embedding.
  """

  def __init__(self,
               weight=1.0,
               loss_type='L1',
               pretrained_model=None,
               name='embedding_loss'):
    """Creates the loss layer.

    Args:
      weight: Multiplier on the embedding distance; when not positive, `call`
        returns 0.0 without running the model.
      loss_type: Distance passed to `mean_difference` ('L1', 'L2', 'COSINE').
      pretrained_model: Callable mapping audio to an embedding. Required
        whenever `weight` is positive.
      name: Layer name.
    """
    super().__init__(name=name)
    self.weight = weight
    self.loss_type = loss_type
    self.pretrained_model = pretrained_model

  def call(self, target_audio, audio):
    """Returns `weight * mean_difference(emb(target_audio), emb(audio))`.

    Both inputs are converted to float32 tensors before being embedded.

    Raises:
      ValueError: If `weight` is positive but no `pretrained_model` was given.
    """
    loss = 0.0
    if self.weight > 0.0:
      if self.pretrained_model is None:
        # Fail with an actionable message instead of
        # "'NoneType' object is not callable".
        raise ValueError(
            '{} has a positive weight but no pretrained_model.'.format(
                self.name))
      audio, target_audio = tf_float32(audio), tf_float32(target_audio)
      target_emb = self.pretrained_model(target_audio)
      synth_emb = self.pretrained_model(audio)
      loss = self.weight * mean_difference(
          target_emb, synth_emb, self.loss_type)
    return loss
@gin.register
class PretrainedCREPEEmbeddingLoss(EmbeddingLoss):
  """Embedding loss of the CREPE model."""

  def __init__(self,
               weight=1.0,
               loss_type='L1',
               model_capacity='tiny',
               activation_layer='classifier',
               name='pretrained_crepe_embedding_loss'):
    """Creates the loss layer.

    Args:
      weight: User-facing loss weight. It is further multiplied by 20.0 and by
        a per-layer scale so that losses from different layers are comparable.
      loss_type: Distance passed to `mean_difference` ('L1', 'L2', 'COSINE').
      model_capacity: CREPE model capacity passed to `PretrainedCREPE`.
      activation_layer: Name of the CREPE layer whose activations are compared.
      name: Layer name.

    Raises:
      ValueError: If `activation_layer` has no entry in the scale table.
    """
    # Scale each layer activation loss to comparable scales.
    scales = {
        'conv1-BN': 1.3,
        'conv1-maxpool': 1.0,
        'conv2-BN': 1.4,
        'conv2-maxpool': 1.1,
        'conv3-BN': 1.9,
        'conv3-maxpool': 1.6,
        'conv4-BN': 1.5,
        'conv4-maxpool': 1.4,
        'conv5-BN': 1.9,
        'conv5-maxpool': 1.7,
        'conv6-BN': 30,
        'conv6-maxpool': 25,
        'classifier': 130,
    }
    if activation_layer not in scales:
      # Same error type as PretrainedCREPE.build for an unknown layer.
      raise ValueError(
          'activation layer {} not found, valid names are {}'.format(
              activation_layer, sorted(scales)))
    scale = scales[activation_layer]
    super().__init__(
        weight=20.0 * scale * weight,
        loss_type=loss_type,
        name=name,
        pretrained_model=PretrainedCREPE(model_capacity=model_capacity,
                                         activation_layer=activation_layer))
class PretrainedCREPE(tfkl.Layer):
  """Pretrained CREPE model with frozen weights."""

  def __init__(self,
               model_capacity='tiny',
               activation_layer='conv5-maxpool',
               name='pretrained_crepe',
               trainable=False):
    """Loads the CREPE model.

    Args:
      model_capacity: Passed to `crepe.core.build_and_load_model`.
      activation_layer: Name of the model layer whose output `call` returns.
        Validated in `build`.
      name: Layer name.
      trainable: Whether the CREPE weights are trainable (frozen by default).
    """
    super().__init__(name=name, trainable=trainable)
    self._model_capacity = model_capacity
    self._activation_layer = activation_layer
    spectral_ops.reset_crepe()
    self._model = crepe.core.build_and_load_model(self._model_capacity)
    # Number of samples in each frame fed to the model in `call`.
    self.frame_length = 1024

  def build(self, unused_x_shape):
    """Builds a sub-model that outputs the requested layer's activations.

    Raises:
      ValueError: If the activation layer is not a layer of the model.
    """
    self.layer_names = [l.name for l in self._model.layers]
    if self._activation_layer not in self.layer_names:
      raise ValueError(
          'activation layer {} not found, valid names are {}'.format(
              self._activation_layer, self.layer_names))
    self._activation_model = tf.keras.Model(
        inputs=self._model.input,
        outputs=self._model.get_layer(self._activation_layer).output)
    # Variables are not to be trained.
    self._model.trainable = self.trainable
    self._activation_model.trainable = self.trainable

  def frame_audio(self, audio, hop_length=1024, center=True):
    """Slice audio into frames for crepe.

    Args:
      audio: Tensor of shape [batch, length]. The padding spec below requires
        rank 2.
      hop_length: Step between consecutive frames, in samples.
      center: If True, pad half a frame on both sides so that frames are
        centered around their timestamps.

    Returns:
      Frames of shape [batch, n_frames, frame_length]. Each frame is shifted to
      zero mean and divided by (its standard deviation + 1e-5).
    """
    if center:
      # Pad so that frames are centered around their timestamps
      # (i.e. the first frame is zero centered).
      pad = self.frame_length // 2
      audio = tf.pad(audio, ((0, 0), (pad, pad)))
    frames = tf.signal.frame(audio,
                             frame_length=self.frame_length,
                             frame_step=hop_length)
    # Normalize each frame -- this is expected by the model.
    mean, var = tf.nn.moments(frames, [-1], keepdims=True)
    frames -= mean
    frames /= (var**0.5 + 1e-5)
    return frames

  def call(self, audio):
    """Returns the embeddings.

    Args:
      audio: tensors of shape [batch, length]. Length must be divisible by 1024.

    Returns:
      Activations of shape [batch, n_frames, depth], where depth is the
      flattened per-frame size of the chosen activation layer's output.
    """
    frames = self.frame_audio(audio)
    # Dynamic shapes: `int(frames.shape[0])` fails when the batch size is
    # unknown while tracing a graph.
    frames_shape = tf.shape(frames)
    batch_size, n_frames = frames_shape[0], frames_shape[1]
    # Run every frame of every batch element through the model at once.
    frames = tf.reshape(frames, [-1, self.frame_length])
    outputs = self._activation_model(frames)
    return tf.reshape(outputs, [batch_size, n_frames, -1])
def tf_float32(x):
  """Return `x` as a float32 `tf.Tensor`.

  Existing tensors are cast; any other value is converted with float32 as the
  requested dtype.
  """
  if not isinstance(x, tf.Tensor):
    return tf.convert_to_tensor(x, tf.float32)
  # Casting is a no-op when the tensor is already float32.
  return tf.cast(x, dtype=tf.float32)