Skip to content

Commit

Permalink
do not reset NN layers in online (streaming) mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Böck committed Dec 8, 2016
1 parent 99846e9 commit e64fc55
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
7 changes: 5 additions & 2 deletions madmom/ml/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ class NeuralNetwork(Processor):
"""

def __init__(self, layers, online=False):
    """
    Instantiate a neural network.

    Parameters
    ----------
    layers : list
        Layers of the network, applied in order during processing.
    online : bool, optional
        Online (streaming) mode; if set, the layers are not reset
        between consecutive calls to `process`, i.e. they keep their
        internal state.
    """
    self.layers = layers
    self.online = online

def process(self, data):
"""
Expand All @@ -89,13 +90,15 @@ def process(self, data):
Network predictions for this data.
"""
# reset the layers? (online: do not reset, keep the state)
reset = not self.online
# check the dimensions of the data
if data.ndim == 1:
data = np.atleast_2d(data).T
# loop over all layers
for layer in self.layers:
# activate the layer and feed the output into the next one
data = layer(data)
data = layer(data, reset=reset)
# ravel the predictions if needed
if data.ndim == 2 and data.shape[1] == 1:
data = data.ravel()
Expand Down
12 changes: 6 additions & 6 deletions madmom/ml/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class Layer(object):
"""

def __call__(self, *args):
def __call__(self, *args, **kwargs):
# this magic method makes a Layer callable
return self.activate(*args)
return self.activate(*args, **kwargs)

def activate(self, data):
"""
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self, weights, bias, activation_fn):
self.bias = bias
self.activation_fn = activation_fn

def activate(self, data):
def activate(self, data, **kwargs):
"""
Activate the layer.
Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(self, fwd_layer, bwd_layer):
self.fwd_layer = fwd_layer
self.bwd_layer = bwd_layer

def activate(self, data, **kwargs):
    """
    Activate the layer in both directions.

    Parameters
    ----------
    data : numpy array
        Activate with this data.
    **kwargs : dict, optional
        Additional keyword arguments forwarded to both wrapped layers
        (e.g. ``reset`` to control state resetting).

    Returns
    -------
    numpy array
        Forward and (time-aligned) backward activations, stacked
        horizontally.
    """
    # activate in forward direction
    fwd = self.fwd_layer(data, **kwargs)
    # also activate with reversed input
    bwd = self.bwd_layer(data[::-1], **kwargs)
    # reverse the backward result back into forward time order and stack
    return np.hstack((fwd, bwd[::-1]))

Expand Down
1 change: 1 addition & 0 deletions madmom/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,4 @@ def io_arguments(parser, output_suffix='.txt', pickle=True, online=False):
sp.set_defaults(origin='future')
sp.set_defaults(num_frames=1)
sp.set_defaults(stream=None)
sp.set_defaults(online=True)

0 comments on commit e64fc55

Please sign in to comment.