From 7fd4704003dec853abbfc15a47a0d07d941b7a8a Mon Sep 17 00:00:00 2001 From: Haoyang Date: Wed, 13 Sep 2023 12:53:41 +0800 Subject: [PATCH] fix _convert_simple_rnn (#15723) * fix _convert_simple_rnn * fix _convert_simple_rnn * fix errors in the last pr --- python/tvm/relay/frontend/keras.py | 25 ++++++++++++--------- tests/python/frontend/keras/test_forward.py | 11 +++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 205b2be490a0..9e09cb400ab2 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1052,23 +1052,26 @@ def _convert_simple_rnn( inexpr = [inexpr, prev_op] in_data = inexpr[0] prev_op = inexpr[1] + prev_op = _op.nn.batch_flatten(prev_op) weightList = keras_layer.get_weights() kernel_weight = etab.new_const(weightList[0].transpose([1, 0])) recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) - if keras_layer.use_bias: - in_bias = etab.new_const(weightList[2]) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" - in_data = _op.nn.batch_flatten(in_data) - ixh = _op.nn.dense(in_data, kernel_weight, units=units) if keras_layer.use_bias: - ixh = _op.nn.bias_add(ixh, bias=in_bias) - prev_op = _op.nn.batch_flatten(prev_op) - ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units) - output = ixh + ixh2 - output = _convert_activation(output, keras_layer, etab, data_layout) - out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) - output = _op.reshape(output, newshape=out_shape) + in_bias = etab.new_const(weightList[2]) + assert len(in_data.type_annotation.shape) == 3 + timeDim = in_data.type_annotation.shape[1].value + in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) + for i in range(len(in_data_split)): + in_data_split_i = _op.nn.batch_flatten(in_data_split[i]) + ixh = _op.nn.dense(in_data_split_i, kernel_weight, units=units) + if keras_layer.use_bias: + ixh = _op.nn.bias_add(ixh, bias=in_bias) + ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units) + output = ixh + ixh2 + output = _convert_activation(output, keras_layer, etab, data_layout) + prev_op = output return [output, output] diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 80460f6063d7..9d33b15a9179 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -825,6 +825,16 @@ def test_forward_time_distributed(self, keras_mod): ) verify_keras_frontend(dense_model, need_transpose=False) + def test_simplernn_with_infertype(self, keras_mod): + """This test case is from https://github.com/apache/tvm/issues/14868""" + input_shape = (2, 2, 2) + x = keras_mod.layers.Input(shape=input_shape[1:], dtype="float32") + layer = keras_mod.layers.SimpleRNN(units=4) + y = layer(x) + model = keras_mod.models.Model(x, y) + mod, _ = relay.frontend.from_keras(model, {model.input_names[0]: input_shape}) + relay.transform.InferType()(mod) + if __name__ == "__main__": for k in [keras, tf_keras]: @@ -867,3 +877,4 @@ def test_forward_time_distributed(self, keras_mod): sut.test_forward_repeat_vector(keras_mod=k) sut.test_forward_l2_normalize(keras_mod=k) sut.test_forward_time_distributed(keras_mod=k) + sut.test_simplernn_with_infertype(keras_mod=k)