From 0122bd8bf384ebbb740f1d4808c3eef6940e1603 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 21 Jan 2022 10:05:19 -0800 Subject: [PATCH] [frontend][keras] Add support for TimeDistributed (#7006) * First pass on modifying Keras importer to handle TimeDistributed * Use squeeze inside TimeDistributed, add tests * linter fixes * More linting * Even more linting * Fix unused argument annotations * Forgot one pylint annotation * Forgot to set up data layout in _convert_activation * Decouple data_layout from etab * Linting fix * Forgot to set data_layout argument * Missed an etab.data_format, also test_conv1d was not in the test file's main * Rebase fixes * Linting fix * _convert_lambda needs a data layout argument too * linting fix too * Lint the test file too * Redundant variables * Simplify further * Another simplification Co-authored-by: Steven Lyubomirsky --- python/tvm/relay/frontend/keras.py | 347 ++++++++++++++------ tests/python/frontend/keras/test_forward.py | 17 + 2 files changed, 264 insertions(+), 100 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 818a8ba1cffa..98ccb509fb4d 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -20,7 +20,7 @@ import sys import numpy as np import tvm -from tvm.ir import IRModule +from tvm.ir import IRModule, TensorType, TupleType from .. import analysis from .. import expr as _expr @@ -62,10 +62,12 @@ def _as_list(arr): def _convert_recurrent_activation(inexpr, keras_layer): act_type = keras_layer.recurrent_activation.__name__ - return _convert_activation(inexpr, act_type, None) + return _convert_activation(inexpr, act_type, None, None, None) -def _convert_activation(inexpr, keras_layer, etab): +def _convert_activation( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument if isinstance(keras_layer, str): act_type = keras_layer else: @@ -82,7 +84,7 @@ def _convert_activation(inexpr, keras_layer, etab): beta = _expr.const(beta, dtype="float32") return _op.add(_op.multiply(inexpr, alpha), beta) if act_type == "softmax": - axis = 1 if etab.data_layout == "NCHW" else -1 + axis = 1 if data_layout == "NCHW" else -1 return _op.nn.softmax(inexpr, axis) if act_type == "sigmoid": return _op.sigmoid(inexpr) @@ -124,17 +126,19 @@ def _convert_activation(inexpr, keras_layer, etab): ) -def _convert_advanced_activation(inexpr, keras_layer, etab): +def _convert_advanced_activation(inexpr, keras_layer, etab, data_layout, input_shape=None): act_type = type(keras_layer).__name__ + if input_shape is None: + input_shape = keras_layer.input_shape if act_type == "Softmax": axis = keras_layer.axis - dims = len(keras_layer.input_shape) + dims = len(input_shape) if isinstance(axis, list): raise tvm.error.OpAttributeUnImplemented( "Softmax with axes {} is not supported.".format(axis) ) - if etab.data_layout == "NCHW": + if data_layout == "NCHW": if axis == -1: axis = 1 else: @@ -161,7 +165,7 @@ def _convert_advanced_activation(inexpr, keras_layer, etab): assert hasattr(keras_layer, "alpha"), "alpha required for PReLU." 
_check_data_format(keras_layer) size = len(keras_layer.alpha.shape) - if etab.data_layout == "NCHW": + if data_layout == "NCHW": alpha = etab.new_const(keras_layer.get_weights()[0].transpose(np.roll(range(size), 1))) else: alpha = etab.new_const(keras_layer.get_weights()[0]) @@ -177,7 +181,9 @@ def _convert_advanced_activation(inexpr, keras_layer, etab): ) -def _convert_merge(inexpr, keras_layer, _): +def _convert_merge( + inexpr, keras_layer, _, input_shape=None, data_layout=None +): # pylint: disable=unused-argument merge_type = type(keras_layer).__name__ ret = inexpr[0] if merge_type == "Dot": @@ -225,11 +231,15 @@ def _convert_merge(inexpr, keras_layer, _): return ret -def _convert_permute(inexpr, keras_layer, _): +def _convert_permute( + inexpr, keras_layer, _, input_shape=None, data_layout=None +): # pylint: disable=unused-argument return _op.transpose(inexpr, axes=(0,) + keras_layer.dims) -def _convert_embedding(inexpr, keras_layer, etab): +def _convert_embedding( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument indices = inexpr weightList = keras_layer.get_weights() weight = etab.new_const(weightList[0]) @@ -238,11 +248,14 @@ def _convert_embedding(inexpr, keras_layer, etab): return out -def _convert_dense(inexpr, keras_layer, etab): +def _convert_dense( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument weightList = keras_layer.get_weights() weight = etab.new_const(weightList[0].transpose([1, 0])) params = {"weight": weight, "units": weightList[0].shape[1]} - input_shape = keras_layer.input_shape + if input_shape is None: + input_shape = keras_layer.input_shape input_dim = len(input_shape) # In case of RNN dense, input shape will be (1, 1, n) if input_dim > 2: @@ -262,18 +275,20 @@ def _convert_dense(inexpr, keras_layer, etab): else: act_type = keras_layer.activation.__name__ if act_type != "linear": - out = _convert_activation(out, act_type, etab) + out = _convert_activation(out, act_type, etab, data_layout) if input_dim > 2: out = _op.expand_dims(out, axis=0) return out -def _convert_convolution1d(inexpr, keras_layer, etab): +def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None): + if input_shape is None: + input_shape = keras_layer.input_shape _check_data_format(keras_layer) weightList = keras_layer.get_weights() weight = weightList[0] - if etab.data_layout == "NWC": + if data_layout == "NWC": kernel_layout = "WIO" else: kernel_layout = "OIW" @@ -281,7 +296,7 @@ def _convert_convolution1d(inexpr, keras_layer, etab): "Kernel layout with {} is not supported for operator Convolution1D " "in frontend Keras." 
) - raise tvm.error.OpAttributeUnImplemented(msg.format(etab.data_layout)) + raise tvm.error.OpAttributeUnImplemented(msg.format(data_layout)) is_deconv = type(keras_layer).__name__ == "Conv1DTranspose" @@ -306,7 +321,7 @@ def _convert_convolution1d(inexpr, keras_layer, etab): "strides": [stride_w], "dilation": dilation, "padding": [0], - "data_layout": etab.data_layout, + "data_layout": data_layout, "kernel_layout": kernel_layout, } params["channels"] = n_filters @@ -315,7 +330,7 @@ def _convert_convolution1d(inexpr, keras_layer, etab): pass # calculate the padding values elif keras_layer.padding == "same": - in_w = keras_layer.input_shape[1] + in_w = input_shape[1] pad_w = _get_pad_pair(in_w, dilated_kernel_w, stride_w) params["padding"] = [pad_w[0], pad_w[1]] else: @@ -327,7 +342,7 @@ def _convert_convolution1d(inexpr, keras_layer, etab): else: out = _op.nn.conv1d(data=inexpr, **params) - channel_axis = -1 if etab.data_layout == "NWC" else 1 + channel_axis = -1 if data_layout == "NWC" else 1 if keras_layer.use_bias: bias = etab.new_const(weightList[1]) out = _op.nn.bias_add(out, bias, channel_axis) @@ -338,18 +353,21 @@ def _convert_convolution1d(inexpr, keras_layer, etab): else: act_type = keras_layer.activation.__name__ if act_type != "linear": - out = _convert_activation(out, act_type, etab) + out = _convert_activation(out, act_type, etab, data_layout) return out -def _convert_convolution(inexpr, keras_layer, etab): +def _convert_convolution(inexpr, keras_layer, etab, data_layout, input_shape=None): _check_data_format(keras_layer) is_deconv = type(keras_layer).__name__ == "Conv2DTranspose" is_depthconv = type(keras_layer).__name__ == "DepthwiseConv2D" weightList = keras_layer.get_weights() weight = weightList[0] - if etab.data_layout == "NHWC": + if input_shape is None: + input_shape = keras_layer.input_shape + + if data_layout == "NHWC": if is_depthconv: kernel_layout = "HWOI" else: @@ -368,7 +386,7 @@ def _convert_convolution(inexpr, keras_layer, etab): kernel_h, kernel_w, in_channels, depth_mult = weight.shape if kernel_layout == "OIHW": weight = weight.transpose([2, 3, 0, 1]) - elif etab.data_layout == "NCHW": + elif data_layout == "NCHW": kernel_h, kernel_w, in_channels, n_filters = weight.shape weight = weight.transpose([3, 2, 0, 1]) else: @@ -386,7 +404,7 @@ def _convert_convolution(inexpr, keras_layer, etab): "strides": [stride_h, stride_w], "dilation": dilation, "padding": [0, 0], - "data_layout": etab.data_layout, + "data_layout": data_layout, "kernel_layout": kernel_layout, } if is_depthconv: @@ -398,8 +416,8 @@ def _convert_convolution(inexpr, keras_layer, etab): pass # we insert a separate pad operator elif keras_layer.padding == "same": - in_h = keras_layer.input_shape[1] - in_w = keras_layer.input_shape[2] + in_h = input_shape[1] + in_w = input_shape[2] pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w) params["padding"] = (pad_t, pad_l, pad_b, pad_r) @@ -413,7 +431,7 @@ def _convert_convolution(inexpr, keras_layer, etab): if keras_layer.use_bias: bias = etab.new_const(weightList[1]) - if etab.data_layout == "NCHW": + if data_layout == "NCHW": out = _op.nn.bias_add(out, bias) else: out = _op.nn.bias_add(out, bias, axis=-1) @@ -423,16 +441,19 @@ def _convert_convolution(inexpr, keras_layer, etab): else: act_type = keras_layer.activation.__name__ if act_type != "linear": - out = _convert_activation(out, act_type, etab) + out = _convert_activation(out, act_type, etab, data_layout) + return out -def 
_convert_convolution3d(inexpr, keras_layer, etab): +def _convert_convolution3d(inexpr, keras_layer, etab, data_layout, input_shape=None): _check_data_format(keras_layer) weightList = keras_layer.get_weights() weight = weightList[0] + if input_shape is None: + input_shape = keras_layer.input_shape - if etab.data_layout == "NDHWC": + if data_layout == "NDHWC": kernel_layout = "DHWIO" else: kernel_layout = "OIDHW" @@ -440,7 +461,7 @@ def _convert_convolution3d(inexpr, keras_layer, etab): "Kernel layout with {} is not supported for operator Convolution3D " "in frontend Keras." ) - raise tvm.error.OpAttributeUnImplemented(msg.format(etab.data_layout)) + raise tvm.error.OpAttributeUnImplemented(msg.format(data_layout)) is_deconv = type(keras_layer).__name__ == "Conv3DTranspose" @@ -467,7 +488,7 @@ def _convert_convolution3d(inexpr, keras_layer, etab): "strides": [stride_d, stride_h, stride_w], "dilation": dilation, "padding": [0, 0, 0], - "data_layout": etab.data_layout, + "data_layout": data_layout, "kernel_layout": kernel_layout, } params["channels"] = n_filters @@ -476,9 +497,9 @@ def _convert_convolution3d(inexpr, keras_layer, etab): pass # calculate the padding values elif keras_layer.padding == "same": - in_d = keras_layer.input_shape[1] - in_h = keras_layer.input_shape[2] - in_w = keras_layer.input_shape[3] + in_d = input_shape[1] + in_h = input_shape[2] + in_w = input_shape[3] pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d) pad_h = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_w = _get_pad_pair(in_w, dilated_kernel_w, stride_w) @@ -491,7 +512,7 @@ def _convert_convolution3d(inexpr, keras_layer, etab): else: out = _op.nn.conv3d(data=inexpr, **params) - channel_axis = -1 if etab.data_layout == "NDHWC" else 1 + channel_axis = -1 if data_layout == "NDHWC" else 1 if keras_layer.use_bias: bias = etab.new_const(weightList[1]) out = _op.nn.bias_add(out, bias, channel_axis) @@ -502,17 +523,22 @@ def _convert_convolution3d(inexpr, keras_layer, etab): else: act_type = keras_layer.activation.__name__ if act_type != "linear": - out = _convert_activation(out, act_type, etab) + out = _convert_activation(out, act_type, etab, None) return out -def _convert_separable_convolution(inexpr, keras_layer, etab): +def _convert_separable_convolution(inexpr, keras_layer, etab, data_layout, input_shape=None): _check_data_format(keras_layer) - if etab.data_layout == "NHWC": + + if data_layout == "NHWC": kernel_layout = "HWOI" else: kernel_layout = "OIHW" + + if input_shape is None: + input_shape = keras_layer.input_shape + weightList = keras_layer.get_weights() # depthwise conv kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape @@ -529,15 +555,15 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): "strides": [stride_h, stride_w], "dilation": [1, 1], "padding": [0, 0], - "data_layout": etab.data_layout, + "data_layout": data_layout, "kernel_layout": kernel_layout, } if keras_layer.padding == "valid": pass # we insert a separate pad operator elif keras_layer.padding == "same": - in_h = keras_layer.input_shape[1] - in_w = keras_layer.input_shape[2] + in_h = input_shape[1] + in_w = input_shape[2] pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h) pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w) params0["padding"] = (pad_t, pad_l, pad_b, pad_r) @@ -561,13 +587,13 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): "kernel_size": [1, 1], "strides": [1, 1], "dilation": [1, 1], - "data_layout": etab.data_layout, + "data_layout": data_layout, "kernel_layout": 
kernel_layout, } out = _op.nn.conv2d(data=depthconv, **params1) if keras_layer.use_bias: bias = etab.new_const(weightList[2]) - if etab.data_layout == "NCHW": + if data_layout == "NCHW": out = _op.nn.bias_add(out, bias) else: out = _op.nn.bias_add(out, bias, axis=-1) @@ -577,30 +603,40 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): else: act_type = keras_layer.activation.__name__ if act_type != "linear": - out = _convert_activation(out, act_type, etab) + out = _convert_activation(out, act_type, etab, data_layout) return out -def _convert_flatten(inexpr, keras_layer, etab): +def _convert_flatten( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) + # NCHW -> NHWC so that dense can be correctly converted - if etab.data_layout == "NCHW": + if data_layout == "NCHW": inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1]) return _op.nn.batch_flatten(inexpr) -def _convert_pooling(inexpr, keras_layer, etab): +def _convert_pooling( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) + pool_type = type(keras_layer).__name__ # global pool in keras = global pool + flatten in relay - global_pool_params = {"layout": etab.data_layout} + global_pool_params = {"layout": data_layout} + + if input_shape is None: + input_shape = keras_layer.input_shape + if pool_type == "GlobalMaxPooling2D": return _convert_flatten( - _op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab + _op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout ) if pool_type == "GlobalAveragePooling2D": return _convert_flatten( - _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab + _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout ) pool_h, pool_w = keras_layer.pool_size stride_h, stride_w = keras_layer.strides @@ -608,13 +644,13 @@ def _convert_pooling(inexpr, keras_layer, etab): "pool_size": [pool_h, pool_w], "strides": [stride_h, stride_w], "padding": [0, 0], - "layout": etab.data_layout, + "layout": data_layout, } if keras_layer.padding == "valid": pass elif keras_layer.padding == "same": - in_h = keras_layer.input_shape[1] - in_w = keras_layer.input_shape[2] + in_h = input_shape[1] + in_w = input_shape[2] pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h) pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params["padding"] = [pad_t, pad_l, pad_b, pad_r] @@ -632,9 +668,13 @@ def _convert_pooling(inexpr, keras_layer, etab): ) -def _convert_pooling3d(inexpr, keras_layer, etab): +def _convert_pooling3d( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) pool_type = type(keras_layer).__name__ + if input_shape is None: + input_shape = keras_layer.input_shape if pool_type not in ["MaxPooling3D", "AveragePooling3D"]: raise tvm.error.OpNotImplemented( @@ -647,15 +687,15 @@ def _convert_pooling3d(inexpr, keras_layer, etab): "pool_size": [pool_d1, pool_d2, pool_d3], "strides": [stride_d1, stride_d2, stride_d3], "padding": [0, 0, 0], - "layout": etab.data_layout, + "layout": data_layout, } if keras_layer.padding == "valid": pass elif keras_layer.padding == "same": - in_d1 = keras_layer.input_shape[1] - in_d2 = keras_layer.input_shape[2] - in_d3 = keras_layer.input_shape[3] + in_d1 = input_shape[1] + in_d2 = input_shape[2] + in_d3 = input_shape[3] pad_d1 = _get_pad_pair(in_d1, pool_d1, 
stride_d1) pad_d2 = _get_pad_pair(in_d2, pool_d2, stride_d2) pad_d3 = _get_pad_pair(in_d3, pool_d3, stride_d3) @@ -675,11 +715,13 @@ def _convert_pooling3d(inexpr, keras_layer, etab): return _op.transpose(out, axes=(0, 2, 3, 4, 1)) -def _convert_global_pooling3d(inexpr, keras_layer, etab): +def _convert_global_pooling3d( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) pool_type = type(keras_layer).__name__ - global_pool_params = {"layout": etab.data_layout} + global_pool_params = {"layout": data_layout} if pool_type == "GlobalMaxPooling3D": out = _op.nn.global_max_pool3d(inexpr, **global_pool_params) elif pool_type == "GlobalAveragePooling3D": @@ -689,10 +731,12 @@ def _convert_global_pooling3d(inexpr, keras_layer, etab): "Operator {} is not supported for frontend Keras.".format(keras_layer) ) - return _convert_flatten(out, keras_layer, etab) + return _convert_flatten(out, keras_layer, etab, input_shape, data_layout) -def _convert_upsample(inexpr, keras_layer, etab): +def _convert_upsample( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) upsample_type = type(keras_layer).__name__ params = {} @@ -716,29 +760,36 @@ def _convert_upsample(inexpr, keras_layer, etab): raise tvm.error.OpNotImplemented( "Operator {} is not supported for frontend Keras.".format(upsample_type) ) - params["layout"] = etab.data_layout + params["layout"] = data_layout out = _op.nn.upsampling(inexpr, **params) return out -def _convert_upsample3d(inexpr, keras_layer, etab): +def _convert_upsample3d( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) + params = {} d, h, w = keras_layer.size params["scale_d"] = d params["scale_h"] = h params["scale_w"] = w - params["layout"] = etab.data_layout + params["layout"] = data_layout params["coordinate_transformation_mode"] = "asymmetric" out = _op.nn.upsampling3d(inexpr, **params) return out -def _convert_cropping(inexpr, keras_layer, _): +def _convert_cropping( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ + if input_shape is None: + input_shape = keras_layer.input_shape if crop_type == "Cropping2D": - (_, in_h, in_w, _) = keras_layer.input_shape + (_, in_h, in_w, _) = input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: raise tvm.error.OpNotImplemented( @@ -752,8 +803,10 @@ def _convert_cropping(inexpr, keras_layer, _): ) -def _convert_batchnorm(inexpr, keras_layer, etab): - if etab.data_layout == "NCHW" or len(keras_layer.input_shape) < 4: +def _convert_batchnorm(inexpr, keras_layer, etab, data_layout, input_shape=None): + if input_shape is None: + input_shape = keras_layer.input_shape + if data_layout == "NCHW" or len(input_shape) < 4: axis = 1 else: axis = 3 @@ -785,8 +838,11 @@ def _convert_batchnorm(inexpr, keras_layer, etab): return result -def _convert_padding(inexpr, keras_layer, etab): +def _convert_padding( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) + padding_type = type(keras_layer).__name__ padding = keras_layer.padding top = left = bottom = right = 0 @@ -809,13 +865,16 @@ def _convert_padding(inexpr, keras_layer, etab): else: msg = "Operator {} is not supported in frontend Keras." 
raise tvm.error.OpNotImplemented(msg.format(padding_type)) - if etab.data_layout == "NCHW": + if data_layout == "NCHW": return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) return _op.nn.pad(data=inexpr, pad_width=((0, 0), (top, bottom), (left, right), (0, 0))) -def _convert_padding3d(inexpr, keras_layer, etab): +def _convert_padding3d( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) + padding = keras_layer.padding d_pad = h_pad = w_pad = [0, 0] @@ -831,7 +890,7 @@ def _convert_padding3d(inexpr, keras_layer, etab): msg = 'Value {} in attribute "padding" of operator ZeroPadding3D is ' "not valid." raise tvm.error.OpAttributeInvalid(msg.format(str(padding))) - if etab.data_layout == "NCDHW": + if data_layout == "NCDHW": out = _op.nn.pad( data=inexpr, pad_width=( @@ -856,22 +915,32 @@ def _convert_padding3d(inexpr, keras_layer, etab): return out -def _convert_concat(inexpr, keras_layer, etab): +def _convert_concat( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) - if etab.data_layout == "NHWC" or len(keras_layer.input_shape[0]) < 4: + if input_shape is None: + input_shape = keras_layer.input_shape + + if data_layout == "NHWC" or len(input_shape[0]) < 4: axis = -1 else: axis = 1 return _op.concatenate(_as_list(inexpr), axis=axis) -def _convert_reshape(inexpr, keras_layer, etab): +def _convert_reshape( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) - inshape = keras_layer.input_shape # includes batch + if input_shape is None: + input_shape = keras_layer.input_shape + + inshape = input_shape # includes batch tshape = keras_layer.target_shape # no batch shape = (-1,) + tshape - if etab.data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2): + if data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2): # Perform reshape in original NHWC format. 
inexpr = _op.transpose(inexpr, [0] + list(range(2, len(inshape))) + [1]) inexpr = _op.reshape(inexpr, newshape=shape) @@ -880,8 +949,12 @@ def _convert_reshape(inexpr, keras_layer, etab): return _op.reshape(inexpr, newshape=shape) -def _convert_lstm(inexpr, keras_layer, etab): +def _convert_lstm( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) + if input_shape is None: + input_shape = keras_layer.input_shape if not isinstance(inexpr, list): buf = np.zeros((1, keras_layer.units), "float32") c_op = etab.new_const(buf) @@ -891,7 +964,7 @@ def _convert_lstm(inexpr, keras_layer, etab): next_h = inexpr[1] next_c = inexpr[2] weightList = keras_layer.get_weights() - in_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.input_shape)[0]) + in_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) kernel_weight = etab.new_const(weightList[0].transpose([1, 0])) recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) in_bias = etab.new_const(weightList[2]) @@ -908,9 +981,11 @@ def _convert_lstm(inexpr, keras_layer, etab): gates = _op.split(gate, indices_or_sections=4, axis=1) in_gate = _convert_recurrent_activation(gates[0], keras_layer) in_transform = _convert_recurrent_activation(gates[1], keras_layer) - next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None) + next_c = in_transform * next_c + in_gate * _convert_activation( + gates[2], keras_layer, etab, data_layout + ) out_gate = _convert_recurrent_activation(gates[3], keras_layer) - next_h = out_gate * _convert_activation(next_c, keras_layer, None) + next_h = out_gate * _convert_activation(next_c, keras_layer, etab, data_layout) if keras_layer.return_sequences: out_list.append(_op.expand_dims(next_h, axis=1)) out = _op.concatenate(out_list, axis=1) if keras_layer.return_sequences else next_h @@ -919,7 +994,9 @@ def _convert_lstm(inexpr, keras_layer, etab): return [out, next_h, next_c] -def _convert_simple_rnn(inexpr, keras_layer, etab): +def _convert_simple_rnn( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) if not isinstance(inexpr, list): buf = np.zeros((1, keras_layer.units), "float32") @@ -937,13 +1014,15 @@ def _convert_simple_rnn(inexpr, keras_layer, etab): prev_op = _op.nn.batch_flatten(prev_op) ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units) output = ixh + ixh2 - output = _convert_activation(output, keras_layer, None) + output = _convert_activation(output, keras_layer, etab, data_layout) out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) output = _op.reshape(output, newshape=out_shape) return [output, output] -def _convert_gru(inexpr, keras_layer, etab): +def _convert_gru( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument _check_data_format(keras_layer) if not isinstance(inexpr, list): buf = np.zeros((1, keras_layer.units), "float32") @@ -978,7 +1057,7 @@ def _convert_gru(inexpr, keras_layer, etab): rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer) units = keras_layer.units recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units) - act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None) + act_hh = _convert_activation(x_h + recurrent_h, keras_layer, etab, data_layout) # previous and candidate state mixed by update gate output = rec_act_z * h_tm1_op + 
(_expr.const(1.0, dtype="float32") - rec_act_z) * act_hh out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) @@ -986,8 +1065,12 @@ def _convert_gru(inexpr, keras_layer, etab): return [output, output] -def _convert_repeat_vector(inexpr, keras_layer, _): - input_shape = list(keras_layer.input_shape) +def _convert_repeat_vector( + inexpr, keras_layer, etab, data_layout, input_shape=None +): # pylint: disable=unused-argument + if input_shape is None: + input_shape = keras_layer.input_shape + input_shape = list(input_shape) repeats = keras_layer.n out_shape = [-1, repeats] + input_shape[1:] out = _op.repeat(inexpr, repeats=repeats, axis=0) @@ -995,7 +1078,7 @@ def _convert_repeat_vector(inexpr, keras_layer, _): return out -def _convert_l2_normalize(inexpr, keras_layer, etab): +def _convert_l2_normalize(inexpr, keras_layer, data_layout): l2_normalize_is_loaded = False param_list = [] for i in dis.get_instructions(keras_layer.function): @@ -1066,7 +1149,7 @@ def is_int_or_tuple_of_ints(v): if isinstance(axis, int): axis = [axis] - if etab.data_layout == "NCHW": + if data_layout == "NCHW": dims = len(keras_layer.input_shape) def fix_axis_for_nchw(axis): @@ -1080,7 +1163,7 @@ def fix_axis_for_nchw(axis): return _op.nn.l2_normalize(inexpr, eps=1e-12, axis=axis) -def _convert_lambda(inexpr, keras_layer, etab): +def _convert_lambda(inexpr, keras_layer, _, data_layout): fcode = keras_layer.function.__code__ # Convert l2_normalize if ( @@ -1088,7 +1171,7 @@ def _convert_lambda(inexpr, keras_layer, etab): and len(fcode.co_names) > 0 and fcode.co_names[-1] == "l2_normalize" ): - return _convert_l2_normalize(inexpr, keras_layer, etab) + return _convert_l2_normalize(inexpr, keras_layer, data_layout) raise tvm.error.OpNotImplemented( "Function {} used in Lambda layer is not supported in frontend Keras.".format( fcode.co_names @@ -1096,7 +1179,65 @@ def _convert_lambda(inexpr, keras_layer, etab): ) -def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument +def _convert_time_distributed(inexpr, keras_layer, etab, data_layout, input_shape=None): + # TimeDistributed: split input tensor along the second dimension (assumed to be time), + # apply inner layer to each split individually, + # and then combine the results + if input_shape is None: + input_shape = keras_layer.input_shape + + assert len(input_shape) >= 2, "Input to TimeDistributed must have at least two dimensions" + + inner_layer = keras_layer.layer + inner_input_shape = [d for (i, d) in enumerate(input_shape) if i != 1] + + # for NDHWC, inner data layout will drop the D + inner_data_layout = data_layout + if data_layout == "NDHWC": + inner_data_layout = "NHWC" + + # some code duplication from keras_op_to_relay + # but it's useful to avoid cluttering the etab + inner_layer_op_name = type(keras_layer.layer).__name__ + if inner_layer_op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + "The inner layer for TimeDistributed {} is not supported for frontend Keras.".format( + inner_layer_op_name + ) + ) + + conversion_func = lambda expr: _convert_map[inner_layer_op_name]( + expr, inner_layer, etab, inner_data_layout, input_shape=inner_input_shape + ) + + split_dim = input_shape[1] + split_input = _op.split(inexpr, split_dim, 1) + + split_shape = list(input_shape) + if split_shape[0] is None: + split_shape[0] = 1 + split_shape[1] = 1 + + split_var = new_var( + "time_distributed_split", + type_annotation=TupleType( + [TensorType(split_shape, dtype="float32") for i in range(split_dim)] + ), + ) 
+ + # For each split, squeeze away the second dimension, + # apply the inner layer. + # Afterwards, combine the transformed splits back along + # the second dimension using stack + splits = [ + conversion_func(_op.squeeze(_expr.TupleGetItem(split_var, i), axis=[1])) + for i in range(split_dim) + ] + + return _expr.Let(split_var, split_input.astuple(), _op.stack(splits, axis=1)) + + +def _default_skip(inexpr, keras_layer, etab, data_layout): # pylint: disable=unused-argument """Layers that can be skipped because they are train time only.""" return inexpr @@ -1152,7 +1293,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument "LSTM": _convert_lstm, "GRU": _convert_gru, # 'Bidirectional' : _convert_bidirectional, - # 'TimeDistributed' : _default_skip, + "TimeDistributed": _convert_time_distributed, "Average": _convert_merge, "Minimum": _convert_merge, "Maximum": _convert_merge, @@ -1184,7 +1325,7 @@ def _check_unsupported_layers(model): ) -def keras_op_to_relay(inexpr, keras_layer, outname, etab): +def keras_op_to_relay(inexpr, keras_layer, outname, etab, data_layout): """Convert a Keras layer to a Relay expression and update the expression table. Parameters @@ -1200,13 +1341,16 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab): etab : relay.frontend.common.ExprTable The global expression table to be updated. + + data_layout : str + The input data layout """ op_name = type(keras_layer).__name__ if op_name not in _convert_map: raise tvm.error.OpNotImplemented( "Operator {} is not supported for frontend Keras.".format(op_name) ) - outs = _convert_map[op_name](inexpr, keras_layer, etab) + outs = _convert_map[op_name](inexpr, keras_layer, etab, data_layout) outs = _as_list(outs) for t_idx, out in enumerate(outs): name = outname + ":" + str(t_idx) @@ -1326,7 +1470,11 @@ def _convert_layer(keras_layer, etab, scope=""): inexpr = inexpr[0] outs.extend( keras_op_to_relay( - inexpr, keras_layer, scope + keras_layer.name + ":" + str(node_idx), etab + inexpr, + keras_layer, + scope + keras_layer.name + ":" + str(node_idx), + etab, + layout, ) ) return outs @@ -1368,7 +1516,6 @@ def _convert_layer(keras_layer, etab, scope=""): "NHWC", "NDHWC", ], "Layout must be one of 'NWC', 'NCHW', NHWC or NDHWC" - etab.data_layout = layout for keras_layer in model.layers: if isinstance(keras_layer, input_layer_class): _convert_input_layer(keras_layer) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 114b8f961374..2cfb93dcc9b4 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -624,6 +624,21 @@ def test_forward_l2_normalize(self, keras): verify_keras_frontend(keras_model, layout="NCHW") verify_keras_frontend(keras_model, layout="NHWC") + def test_forward_time_distributed(self, keras): + conv2d_inputs = keras.Input(shape=(10, 128, 128, 3)) + conv_2d_layer = keras.layers.Conv2D(64, (3, 3)) + conv2d_model = keras.models.Model( + conv2d_inputs, keras.layers.TimeDistributed(conv_2d_layer)(conv2d_inputs) + ) + verify_keras_frontend(conv2d_model, layout="NDHWC") + + dense_inputs = keras.Input(shape=(5, 1)) + dense_layer = keras.layers.Dense(1) + dense_model = keras.models.Model( + dense_inputs, keras.layers.TimeDistributed(dense_layer)(dense_inputs) + ) + verify_keras_frontend(dense_model, need_transpose=False) + if __name__ == "__main__": for k in [keras, tf_keras]: @@ -636,6 +651,7 @@ def test_forward_l2_normalize(self, keras): sut.test_forward_sequential(keras=k) 
sut.test_forward_pool(keras=k) sut.test_forward_conv(keras=k) + sut.test_forward_conv1d(keras=k) sut.test_forward_batch_norm(keras=k) sut.test_forward_upsample(keras=k, interpolation="nearest") sut.test_forward_upsample(keras=k, interpolation="bilinear") @@ -662,3 +678,4 @@ def test_forward_l2_normalize(self, keras): sut.test_forward_embedding(keras=k) sut.test_forward_repeat_vector(keras=k) sut.test_forward_l2_normalize(keras=k) + sut.test_forward_time_distributed(keras=k)
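
Illustrative note (not part of the patch): the new _convert_time_distributed handler lowers TimeDistributed by splitting the input along axis 1 (the time axis), squeezing that axis out of each slice, converting the wrapped layer once per time step, and stacking the per-step results back along axis 1. The following is a minimal standalone Relay sketch of that same strategy for a shared dense transform, assuming a local TVM build; the names (time_steps, features, units, x, w) are placeholders for illustration and this is not the frontend code itself.

    import tvm
    from tvm import relay

    # Emulate the TimeDistributed lowering on a (batch, time, features) tensor:
    # split along the time axis, squeeze each slice, apply a shared dense
    # transform, and stack the results back along axis 1.
    time_steps, features, units = 5, 3, 4
    x = relay.var("x", shape=(1, time_steps, features), dtype="float32")
    w = relay.var("w", shape=(units, features), dtype="float32")

    slices = relay.split(x, indices_or_sections=time_steps, axis=1)
    outs = [
        relay.nn.dense(relay.squeeze(slices[i], axis=[1]), w, units=units)
        for i in range(time_steps)
    ]
    y = relay.stack(outs, axis=1)

    func = relay.Function([x, w], y)
    mod = tvm.IRModule.from_expr(func)
    print(mod)

The patch itself goes one step further than this sketch: it Let-binds the split tuple to a fresh variable (new_var("time_distributed_split", ...)) with an explicit TupleType annotation, so the split expression is shared across all time steps instead of being duplicated in each TupleGetItem.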
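
For end-to-end usage, the new tests exercise TimeDistributed(Conv2D) with layout="NDHWC" and TimeDistributed(Dense) without transposition. A hedged sketch of importing such a model through relay.frontend.from_keras is shown below, assuming tensorflow.keras is installed and that model.input_names[0] matches the input name the frontend expects (that key is an assumption about the Keras-generated name, not something the patch specifies):

    from tensorflow import keras
    from tvm import relay

    # TimeDistributed(Dense) over 5 time steps of a single feature.
    inputs = keras.Input(shape=(5, 1))
    outputs = keras.layers.TimeDistributed(keras.layers.Dense(1))(inputs)
    model = keras.models.Model(inputs, outputs)

    # Keras reports the batch dimension as None; give the importer a concrete one.
    shape_dict = {model.input_names[0]: (1, 5, 1)}
    mod, params = relay.frontend.from_keras(model, shape_dict, layout="NCHW")
    print(mod)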