Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Add Conv1D support for MXNet backend (#44)
Browse files Browse the repository at this point in the history
* Add Conv1D support for MXNet backend

* Fix CR comments
  • Loading branch information
sandeep-krishnamurthy authored Mar 6, 2018
1 parent 8325604 commit c97096e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 15 deletions.
60 changes: 50 additions & 10 deletions keras/backend/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
_REENTRY = False
NAME_SCOPE_STACK = []

# MXNet requires 'channels_first' format for efficient performance
if image_data_format() == 'channels_last':
warnings.warn('MXNet Backend performs best with `channels_first` format. Using `channels_last` will '
'significantly reduce performance due to the Transpose operations.', stacklevel=2)


class name_scope(object):
def __init__(self, name):
Expand Down Expand Up @@ -2915,7 +2910,54 @@ def conv1d(x, kernel, strides=1, padding='valid',
# Returns
A tensor, result of 1D convolution.
"""
raise NotImplementedError('MXNet Backend: conv1d is not supported yet.')
if data_format is None:
data_format = image_data_format()
_validate_data_format(data_format)

# Causal requires temporal padding.
# MXNet backend does not support temporal padding on 3D tensor.
if padding is 'causal':
raise ValueError('MXNet Backend: conv1d does not support "causal" padding mode')

if padding not in {'same', 'valid'}:
raise ValueError('`padding` should be either `same` or `valid`.')

if hasattr(x, '_keras_shape'):
shape = x._keras_shape
else:
shape = None

if data_format == 'channels_last':
# X original shape (batch, length, input_dim)
# Add a dimension to X to Make it (batch, length, 1, input_dim)
x = expand_dims(x, axis=2)
# update x._keras_shape
if shape is not None:
x._keras_shape = (shape[0], shape[1], 1, shape[2])
elif data_format == 'channels_first':
# X original shape (batch, input_dim, length)
# Add a dimension to X to make it (batch, input_dim, length, 1)
x = expand_dims(x, axis=3)
if shape is not None:
x._keras_shape = (shape[0], shape[1], shape[2], 1)

# update dilation rate, strides
dilation_rate = (dilation_rate, 1)
strides = (strides, 1)
# add dim to kernel (always same format independently of data_format)
# i.e. (rows, 1, input_depth, depth)
kernel = expand_dims(kernel, axis=1)

output = _convnd(x, kernel, name='conv1d', strides=strides, filter_dilation=dilation_rate,
padding_mode=padding, data_format=data_format)

# Remove added extra dimension
# remove added dim
if data_format == 'channels_last':
output = squeeze(output, axis=2)
else:
output = squeeze(output, axis=3)
return output


def conv2d(x, kernel, strides=(1, 1), padding='valid',
Expand Down Expand Up @@ -3915,7 +3957,6 @@ def _preprocess_convnd_input(data_var, data_format):
axes = list(range(ndim(data_var)))
axes.insert(1, axes.pop(-1)) # make it channels_first format
data_var = KerasSymbol(mx.sym.transpose(data=data_var.symbol, axes=axes))

return data_var


Expand All @@ -3930,7 +3971,7 @@ def _postprocess_convnd_output(x, data_format):


@keras_mxnet_symbol
def _preprocess_convnd_kernel(kernel, data_format):
def _preprocess_convnd_kernel(kernel):
# Kernel is always provided in TF kernel shape:
# 2-D: (rows, cols, input_depth, depth)
# 3-D: (kernel_depth, kernel_rows, kernel_cols, input_depth, depth)
Expand All @@ -3942,7 +3983,6 @@ def _preprocess_convnd_kernel(kernel, data_format):
kernel = KerasSymbol(mx.sym.transpose(data=kernel.symbol, axes=(4, 3, 0, 1, 2)))
elif len(kernel.shape) > 3:
kernel = KerasSymbol(mx.sym.transpose(data=kernel.symbol, axes=(3, 2, 0, 1)))

return kernel


Expand Down Expand Up @@ -4001,7 +4041,7 @@ def _convnd(x, kernel, strides, filter_dilation, name=None, padding_mode='valid'

# Handle Data Format
x = _preprocess_convnd_input(x, data_format)
kernel = _preprocess_convnd_kernel(kernel, data_format)
kernel = _preprocess_convnd_kernel(kernel)

# We have already converted kernel to match MXNet required shape:
# (depth, input_depth, rows, cols)
Expand Down
4 changes: 2 additions & 2 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,14 +807,14 @@ def test_conv1d(self):
# MXNet backend does not support conv1d yet.
for strides in [1, 2]:
check_two_tensor_operation('conv1d', input_shape, kernel_shape,
BACKENDS_WITHOUT_MXNET, cntk_dynamicity=True,
BACKENDS, cntk_dynamicity=True,
strides=strides,
data_format='channels_last')

xval = np.random.random(input_shape)
kernel_val = np.random.random(kernel_shape) - 0.5
# Test invalid use cases
for k in BACKENDS_WITHOUT_MXNET:
for k in BACKENDS:
with pytest.raises(ValueError):
k.conv1d(k.variable(xval), k.variable(kernel_val), data_format='channels_middle')

Expand Down
4 changes: 1 addition & 3 deletions tests/keras/layers/convolutional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@keras_test
@pytest.mark.skipif((K.backend() == 'cntk' or K.backend() == 'mxnet'),
reason='cntk/mxnet do not support dilated conv')
reason='cntk/mxnet do not support Causal padding in conv1d')
def test_causal_dilated_conv():
# Causal:
layer_test(convolutional.Conv1D,
Expand Down Expand Up @@ -65,8 +65,6 @@ def test_causal_dilated_conv():
)


@pytest.mark.skipif((K.backend() == 'mxnet'),
reason='MXNet backend does not support conv1d yet.')
@keras_test
def test_conv_1d():
batch_size = 2
Expand Down

0 comments on commit c97096e

Please sign in to comment.