-
Notifications
You must be signed in to change notification settings - Fork 11
/
yamnet.py
121 lines (104 loc) · 4.75 KB
/
yamnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Core model definition of YAMNet."""
import csv
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers
import y_features as features_lib
def _batch_norm(name, params):
def _bn_layer(layer_input):
return layers.BatchNormalization(
name=name,
center=params.batchnorm_center,
scale=params.batchnorm_scale,
epsilon=params.batchnorm_epsilon)(layer_input)
return _bn_layer
def _conv(name, kernel, stride, filters, params):
def _conv_layer(layer_input):
output = layers.Conv2D(name='{}/conv'.format(name),
filters=filters,
kernel_size=kernel,
strides=stride,
padding=params.conv_padding,
use_bias=False,
activation=None)(layer_input)
output = _batch_norm('{}/conv/bn'.format(name), params)(output)
output = layers.ReLU(name='{}/relu'.format(name))(output)
return output
return _conv_layer
def _separable_conv(name, kernel, stride, filters, params):
def _separable_conv_layer(layer_input):
output = layers.DepthwiseConv2D(name='{}/depthwise_conv'.format(name),
kernel_size=kernel,
strides=stride,
depth_multiplier=1,
padding=params.conv_padding,
use_bias=False,
activation=None)(layer_input)
output = _batch_norm('{}/depthwise_conv/bn'.format(name), params)(output)
output = layers.ReLU(name='{}/depthwise_conv/relu'.format(name))(output)
output = layers.Conv2D(name='{}/pointwise_conv'.format(name),
filters=filters,
kernel_size=(1, 1),
strides=1,
padding=params.conv_padding,
use_bias=False,
activation=None)(output)
output = _batch_norm('{}/pointwise_conv/bn'.format(name), params)(output)
output = layers.ReLU(name='{}/pointwise_conv/relu'.format(name))(output)
return output
return _separable_conv_layer
# Ordered layer table for the YAMNet backbone. Each entry is
# (layer_function, kernel, stride, num_filters).
_YAMNET_LAYER_DEFS = [
    # The first layer is the only full convolution.
    (_conv, [3, 3], 2, 32),
] + [
    # Every remaining layer is a separable convolution with a 3x3 kernel, so
    # only the stride and filter count vary.
    (_separable_conv, [3, 3], stride, num_filters)
    for (stride, num_filters) in [
        (1, 64),
        (2, 128),
        (1, 128),
        (2, 256),
        (1, 256),
        (2, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (2, 1024),
        (1, 1024),
    ]
]
def yamnet(features, params):
    """Define the core YAMNet model in Keras.

    Args:
      features: Input tensor of feature patches; each patch is reshaped to
        (params.patch_frames, params.patch_bands, 1) before the first layer.
      params: Hyperparameter object providing `patch_frames`, `patch_bands`,
        `num_classes` and `classifier_activation`; it is also forwarded to
        every layer function in `_YAMNET_LAYER_DEFS`.

    Returns:
      A `(predictions, embeddings)` pair: `embeddings` is the global average
      pool of the last convolutional stage, and `predictions` is the
      classifier activation applied to a dense layer over the embeddings.
    """
    patch_shape = (params.patch_frames, params.patch_bands)
    # Add a trailing channel axis so the 2-D convolutions can consume the patch.
    hidden = layers.Reshape(
        patch_shape + (1,), input_shape=patch_shape)(features)

    # Layers are named layer1, layer2, ... in table order.
    for index, layer_def in enumerate(_YAMNET_LAYER_DEFS, start=1):
        layer_fun, kernel, stride, filters = layer_def
        build_layer = layer_fun(
            'layer{}'.format(index), kernel, stride, filters, params)
        hidden = build_layer(hidden)

    embeddings = layers.GlobalAveragePooling2D()(hidden)
    logits = layers.Dense(units=params.num_classes, use_bias=True)(embeddings)
    predictions = layers.Activation(
        activation=params.classifier_activation)(logits)
    return predictions, embeddings
def yamnet_frames_model(params):
    """Build the Keras model mapping a waveform to YAMNet class scores.

    Args:
      params: An instance of Params containing hyperparameters.

    Returns:
      A model named 'yamnet_frames' accepting (num_samples,) waveform input
      and emitting, in this order:
        - predictions: (num_patches, num_classes) matrix of class scores per
          time frame
        - embeddings: (num_patches, embedding size) matrix of embeddings per
          time frame
        - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins)
          spectrogram feature matrix
    """
    audio = layers.Input(batch_shape=(None,), dtype=tf.float32)

    # Feature extraction: pad the waveform, then compute the log mel
    # spectrogram and the patches cut from it.
    padded_audio = features_lib.pad_waveform(audio, params)
    spectrogram, patches = (
        features_lib.waveform_to_log_mel_spectrogram_patches(
            padded_audio, params))

    scores, patch_embeddings = yamnet(patches, params)
    return Model(
        name='yamnet_frames',
        inputs=audio,
        outputs=[scores, patch_embeddings, spectrogram])
def class_names(class_map_csv):
    """Read the class map CSV file and return the class display names.

    Args:
      class_map_csv: Path of the class map CSV file. If a TensorFlow tensor is
        given, its `.numpy()` value is used as the path.

    Returns:
      A 1-D `np.ndarray` holding the third field (the display name) of every
      data row, in file order. The first row is treated as a header and
      skipped. The array is empty when the file has no data rows.

    Raises:
      ValueError: If a non-blank data row does not have exactly three fields.
    """
    if tf.is_tensor(class_map_csv):
        class_map_csv = class_map_csv.numpy()
    # newline='' hands newline handling to the csv module, as its
    # documentation requires, so quoted fields with embedded line breaks are
    # parsed correctly on every platform.
    with open(class_map_csv, newline='') as csv_file:
        reader = csv.reader(csv_file)
        # Skip the header; the default keeps an empty file from leaking
        # StopIteration to the caller.
        next(reader, None)
        # csv.reader yields [] for blank lines; drop those instead of failing
        # the three-way unpack on them.
        data_rows = (row for row in reader if row)
        return np.array([display_name for (_, _, display_name) in data_rows])