[Keras] Enable Dense operator for any input dims (#16526)
Our dense op expects a 2D input, but Keras places no limitation on the shape of
the input tensor. This commit adds a reshape that collapses all "batch" axes
into one, after which a Dense layer with an ND input tensor can be imported
from Keras into TVM.
echuraev authored Feb 7, 2024
1 parent 6a3fadc commit 2dcf9ec
Showing 2 changed files with 18 additions and 6 deletions.
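
For context, here is a minimal NumPy sketch of the reshape trick the commit message describes. The shapes, the units value, and the (units, in_features) weight layout are illustrative assumptions, not values taken from the commit:

import numpy as np

# Collapse every axis except the last into one "batch" axis, run the 2D dense,
# then restore the leading axes on the output:
# (N, d1, d2, d3) -> (N * d1 * d2, d3) -> (N * d1 * d2, units) -> (N, d1, d2, units)
n, d1, d2, d3, units = 1, 10, 12, 2560, 32
x = np.random.rand(n, d1, d2, d3).astype("float32")
w = np.random.rand(units, d3).astype("float32")  # assumed (units, in_features) layout

x2d = x.reshape(int(np.prod(x.shape[:-1])), x.shape[-1])
y2d = x2d @ w.T
y = y2d.reshape(*x.shape[:-1], units)
assert y.shape == (n, d1, d2, units)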
14 changes: 8 additions & 6 deletions python/tvm/relay/frontend/keras.py
@@ -266,11 +266,12 @@ def _convert_dense(
     # In case of RNN dense, input shape will be (1, 1, n)
     if input_dim > 2:
         input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
-        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
-            raise tvm.error.OpAttributeInvalid(
-                f"Input shape {input_shape} is not valid for operator Dense."
-            )
-        inexpr = _op.squeeze(inexpr, axis=[0])
+        # Keras has no limitations on the shape of the input tensor, but our
+        # dense op expects 2D input. For inputs with more than 2 dimensions,
+        # all "batch" axes are reshaped into one.
+        # For example: (N, d1, d2, d3) -> (N * d1 * d2, d3)
+        new_batch_size = np.prod(input_shape[:-1])
+        inexpr = _op.reshape(inexpr, newshape=(new_batch_size, input_shape[-1]))
     out = _op.nn.dense(data=inexpr, **params)
     if keras_layer.use_bias:
         bias = etab.new_const(weightList[1])
@@ -283,7 +284,8 @@ def _convert_dense(
     if act_type != "linear":
         out = _convert_activation(out, act_type, etab, data_layout)
     if input_dim > 2:
-        out = _op.expand_dims(out, axis=0)
+        out_shape = (*input_shape[:-1], units)
+        out = _op.reshape(out, newshape=out_shape)
     return out


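A hedged sketch of roughly the Relay expression the updated converter should build for one of the new test shapes; the variable names and zero-valued weights below are placeholders, not part of the commit:

import numpy as np
from tvm import relay

x = relay.var("x", shape=(1, 10, 12, 2560), dtype="float32")
w = relay.const(np.zeros((32, 2560), dtype="float32"))  # nn.dense weights: (units, in_features)
flat = relay.reshape(x, newshape=(1 * 10 * 12, 2560))   # collapse the "batch" axes
dense = relay.nn.dense(flat, w, units=32)
out = relay.reshape(dense, newshape=(1, 10, 12, 32))    # restore the leading axes
print(relay.Function([x], out))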
10 changes: 10 additions & 0 deletions tests/python/frontend/keras/test_forward.py
@@ -285,6 +285,16 @@ def test_forward_dense(self, keras_mod):
         keras_model = keras_mod.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)

+        data = keras_mod.layers.Input(shape=(120, 2560), name="image_set")
+        x = keras_mod.layers.Dense(1, activation="linear", name="e")(data)
+        keras_model = keras_mod.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+        data = keras_mod.layers.Input(shape=(10, 12, 2560), name="image_set")
+        x = keras_mod.layers.Dense(32, activation="linear", name="e")(data)
+        keras_model = keras_mod.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
     def test_forward_permute(self, keras_mod):
         data = keras_mod.layers.Input(shape=(2, 3, 4))
         x = keras_mod.layers.Permute([2, 3, 1])(data)
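For an end-to-end check outside the test helper, a hedged sketch of importing such a model through the frontend, assuming a TensorFlow-backed Keras install (the model mirrors the second new test case):

from tensorflow import keras
from tvm import relay

# A Dense layer fed with a 4D tensor, as in the new (10, 12, 2560) test case.
data = keras.layers.Input(shape=(10, 12, 2560), name="image_set")
x = keras.layers.Dense(32, activation="linear", name="e")(data)
keras_model = keras.models.Model(data, x)

# With this commit the frontend flattens the leading axes before nn.dense
# and restores them on the output.
shape_dict = {"image_set": (1, 10, 12, 2560)}
mod, params = relay.frontend.from_keras(keras_model, shape_dict)
print(mod)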
