From cd4002d8a30472d4b6c6de5e23c6ea6dc0e7bb61 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Mon, 11 Sep 2023 00:14:41 +0800 Subject: [PATCH 1/3] fix _convert_simple_rnn --- python/tvm/relay/frontend/keras.py | 20 ++++++++++++++------ tests/python/frontend/keras/test_forward.py | 11 +++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 205b2be490a0..7ed4debf3143 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1053,22 +1053,30 @@ def _convert_simple_rnn( in_data = inexpr[0] prev_op = inexpr[1] weightList = keras_layer.get_weights() - kernel_weight = etab.new_const(weightList[0].transpose([1, 0])) + weightList0 = weightList[0].transpose([1, 0]) + assert len(in_data.type_annotation.shape) == 3 + for i in range(in_data.type_annotation.shape[1].value - 1): + weightList0 = np.hstack((weightList0, weightList[0].transpose([1, 0]))) + kernel_weight = etab.new_const(weightList0) recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" + dim = weightList0.shape[0] in_data = _op.nn.batch_flatten(in_data) ixh = _op.nn.dense(in_data, kernel_weight, units=units) if keras_layer.use_bias: ixh = _op.nn.bias_add(ixh, bias=in_bias) + split_list = [] + for i in range(1, dim): + split_list.append(i) + ixh_tuple = _op.split(ixh, split_list, 1) prev_op = _op.nn.batch_flatten(prev_op) - ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units) - output = ixh + ixh2 - output = _convert_activation(output, keras_layer, etab, data_layout) - out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0]) - output = _op.reshape(output, newshape=out_shape) + for i in range(dim): + ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units) + prev_op = ixh_tuple[0] + ixh2 + output = prev_op return [output, output] diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 80460f6063d7..902fe40f2f44 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -825,6 +825,16 @@ def test_forward_time_distributed(self, keras_mod): ) verify_keras_frontend(dense_model, need_transpose=False) + def test_SimpleRNN_with_InferType(self, keras_mod): + """This test case is from https://github.com/apache/tvm/issues/14868""" + input_shape = (2, 2, 2) + x = keras_mod.layers.Input(shape=input_shape[1:], dtype="float32") + layer = keras_mod.layers.SimpleRNN(units=4) + y = layer(x) + model = keras_mod.models.Model(x, y) + mod, _ = relay.frontend.from_keras(model, {"input_1": input_shape}) + relay.transform.InferType()(mod) + if __name__ == "__main__": for k in [keras, tf_keras]: @@ -867,3 +877,4 @@ def test_forward_time_distributed(self, keras_mod): sut.test_forward_repeat_vector(keras_mod=k) sut.test_forward_l2_normalize(keras_mod=k) sut.test_forward_time_distributed(keras_mod=k) + sut.test_SimpleRNN_with_InferType(keras_mod=k) From 6050fc14b4949c6dd440bc05a9d4fcf3da486058 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Mon, 11 Sep 2023 12:30:37 +0800 Subject: [PATCH 2/3] fix _convert_simple_rnn --- tests/python/frontend/keras/test_forward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 902fe40f2f44..e5fd988c166a 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -825,7 +825,7 @@ def test_forward_time_distributed(self, keras_mod): ) verify_keras_frontend(dense_model, need_transpose=False) - def test_SimpleRNN_with_InferType(self, keras_mod): + def test_simplernn_with_infertype(self, keras_mod): """This test case is from https://github.com/apache/tvm/issues/14868""" input_shape = (2, 2, 2) x = keras_mod.layers.Input(shape=input_shape[1:], dtype="float32") @@ -877,4 +877,4 @@ def test_SimpleRNN_with_InferType(self, keras_mod): sut.test_forward_repeat_vector(keras_mod=k) sut.test_forward_l2_normalize(keras_mod=k) sut.test_forward_time_distributed(keras_mod=k) - sut.test_SimpleRNN_with_InferType(keras_mod=k) + sut.test_simplernn_with_infertype(keras_mod=k) From a2f143fd6d35445f40b59abe24e886ae71d8521b Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Tue, 12 Sep 2023 19:57:21 +0800 Subject: [PATCH 3/3] fix errors in the last pr --- python/tvm/relay/frontend/keras.py | 33 +++++++++------------ tests/python/frontend/keras/test_forward.py | 2 +- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 7ed4debf3143..9e09cb400ab2 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1052,31 +1052,26 @@ def _convert_simple_rnn( inexpr = [inexpr, prev_op] in_data = inexpr[0] prev_op = inexpr[1] + prev_op = _op.nn.batch_flatten(prev_op) weightList = keras_layer.get_weights() - weightList0 = weightList[0].transpose([1, 0]) - assert len(in_data.type_annotation.shape) == 3 - for i in range(in_data.type_annotation.shape[1].value - 1): - weightList0 = np.hstack((weightList0, weightList[0].transpose([1, 0]))) - kernel_weight = etab.new_const(weightList0) + kernel_weight = etab.new_const(weightList[0].transpose([1, 0])) recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) - if keras_layer.use_bias: - in_bias = etab.new_const(weightList[2]) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" - dim = weightList0.shape[0] - in_data = _op.nn.batch_flatten(in_data) - ixh = _op.nn.dense(in_data, kernel_weight, units=units) if keras_layer.use_bias: - ixh = _op.nn.bias_add(ixh, bias=in_bias) - split_list = [] - for i in range(1, dim): - split_list.append(i) - ixh_tuple = _op.split(ixh, split_list, 1) - prev_op = _op.nn.batch_flatten(prev_op) - for i in range(dim): + in_bias = etab.new_const(weightList[2]) + assert len(in_data.type_annotation.shape) == 3 + timeDim = in_data.type_annotation.shape[1].value + in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) + for i in range(len(in_data_split)): + in_data_split_i = _op.nn.batch_flatten(in_data_split[i]) + ixh = _op.nn.dense(in_data_split_i, kernel_weight, units=units) + if keras_layer.use_bias: + ixh = _op.nn.bias_add(ixh, bias=in_bias) ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units) - prev_op = ixh_tuple[0] + ixh2 - output = prev_op + output = ixh + ixh2 + output = _convert_activation(output, keras_layer, etab, data_layout) + prev_op = output return [output, output] diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index e5fd988c166a..9d33b15a9179 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -832,7 +832,7 @@ def test_simplernn_with_infertype(self, keras_mod): layer = keras_mod.layers.SimpleRNN(units=4) y = layer(x) model = keras_mod.models.Model(x, y) - mod, _ = relay.frontend.from_keras(model, {"input_1": input_shape}) + mod, _ = relay.frontend.from_keras(model, {model.input_names[0]: input_shape}) relay.transform.InferType()(mod)