
Commit 7fd4704

fix _convert_simple_rnn (#15723)
* fix _convert_simple_rnn

* fix _convert_simple_rnn

* fix errors in the last pr
haoyang9804 authored Sep 13, 2023
1 parent e88d0d4 commit 7fd4704
Showing 2 changed files with 25 additions and 11 deletions.
25 changes: 14 additions & 11 deletions python/tvm/relay/frontend/keras.py
@@ -1052,23 +1052,26 @@ def _convert_simple_rnn(
         inexpr = [inexpr, prev_op]
     in_data = inexpr[0]
     prev_op = inexpr[1]
+    prev_op = _op.nn.batch_flatten(prev_op)
     weightList = keras_layer.get_weights()
     kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
     recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
-    if keras_layer.use_bias:
-        in_bias = etab.new_const(weightList[2])
     units = list(weightList[0].shape)[1]
     assert units > 0, "The value of units must be a positive integer"
-    in_data = _op.nn.batch_flatten(in_data)
-    ixh = _op.nn.dense(in_data, kernel_weight, units=units)
     if keras_layer.use_bias:
-        ixh = _op.nn.bias_add(ixh, bias=in_bias)
-    prev_op = _op.nn.batch_flatten(prev_op)
-    ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
-    output = ixh + ixh2
-    output = _convert_activation(output, keras_layer, etab, data_layout)
-    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
-    output = _op.reshape(output, newshape=out_shape)
+        in_bias = etab.new_const(weightList[2])
+    assert len(in_data.type_annotation.shape) == 3
+    timeDim = in_data.type_annotation.shape[1].value
+    in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
+    for i in range(len(in_data_split)):
+        in_data_split_i = _op.nn.batch_flatten(in_data_split[i])
+        ixh = _op.nn.dense(in_data_split_i, kernel_weight, units=units)
+        if keras_layer.use_bias:
+            ixh = _op.nn.bias_add(ixh, bias=in_bias)
+        ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
+        output = ixh + ixh2
+        output = _convert_activation(output, keras_layer, etab, data_layout)
+        prev_op = output
     return [output, output]
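The rewritten converter no longer flattens the whole 3-D (batch, time, features) input into a single dense call; it asserts the input is 3-D, splits it along the time axis (axis=1), and unrolls the recurrence, feeding each step's activation back in as prev_op. The NumPy sketch below is an editorial illustration of that recurrence, not TVM code; the function name simple_rnn_reference and the tanh default are assumptions made only for this example.

```python
import numpy as np


def simple_rnn_reference(x, kernel, recurrent_kernel, bias=None):
    """x: (batch, time, features) -> last hidden state, shape (batch, units)."""
    batch, time_dim, _ = x.shape
    units = kernel.shape[1]
    prev = np.zeros((batch, units), dtype=x.dtype)  # zero initial hidden state
    for t in range(time_dim):  # mirrors _op.split(in_data, ..., axis=1) plus the loop
        step = x[:, t, :] @ kernel                  # ixh: input-to-hidden dense
        if bias is not None:
            step = step + bias                      # bias_add
        step = step + prev @ recurrent_kernel       # ixh2: hidden-to-hidden dense
        prev = np.tanh(step)                        # SimpleRNN's default activation
    return prev


# Shapes matching the new test case: batch=2, time=2, features=2, units=4.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 2, 2), dtype=np.float32)
kernel = rng.standard_normal((2, 4), dtype=np.float32)
recurrent = rng.standard_normal((4, 4), dtype=np.float32)
print(simple_rnn_reference(x, kernel, recurrent).shape)  # (2, 4)
```

Because the hidden state is carried across iterations, prev_op is flattened once up front and then reused as the 2-D activation output of each step, which is what the relocated batch_flatten and the final prev_op = output assignment in the diff accomplish.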


11 changes: 11 additions & 0 deletions tests/python/frontend/keras/test_forward.py
@@ -825,6 +825,16 @@ def test_forward_time_distributed(self, keras_mod):
         )
         verify_keras_frontend(dense_model, need_transpose=False)
 
+    def test_simplernn_with_infertype(self, keras_mod):
+        """This test case is from https://github.com/apache/tvm/issues/14868"""
+        input_shape = (2, 2, 2)
+        x = keras_mod.layers.Input(shape=input_shape[1:], dtype="float32")
+        layer = keras_mod.layers.SimpleRNN(units=4)
+        y = layer(x)
+        model = keras_mod.models.Model(x, y)
+        mod, _ = relay.frontend.from_keras(model, {model.input_names[0]: input_shape})
+        relay.transform.InferType()(mod)
+
 
 if __name__ == "__main__":
     for k in [keras, tf_keras]:
@@ -867,3 +877,4 @@ def test_forward_time_distributed(self, keras_mod):
         sut.test_forward_repeat_vector(keras_mod=k)
         sut.test_forward_l2_normalize(keras_mod=k)
         sut.test_forward_time_distributed(keras_mod=k)
+        sut.test_simplernn_with_infertype(keras_mod=k)
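The new test only checks that InferType succeeds on the converted module. As a further, purely hypothetical check (not part of this commit), one could compile and run the converted module and compare it against Keras itself; the sketch below uses standard TVM APIs (relay.build, graph_executor) and assumes the same toy SimpleRNN model as test_simplernn_with_infertype.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor


def run_and_compare_simplernn(keras_model, input_shape=(2, 2, 2)):
    """Compile the Keras model with TVM and compare its output with Keras."""
    data = np.random.uniform(size=input_shape).astype("float32")
    input_name = keras_model.input_names[0]
    mod, params = relay.frontend.from_keras(keras_model, {input_name: input_shape})
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    dev = tvm.cpu(0)
    rt = graph_executor.GraphModule(lib["default"](dev))
    rt.set_input(input_name, data)
    rt.run()
    tvm_out = rt.get_output(0).numpy()
    keras_out = keras_model.predict(data)
    np.testing.assert_allclose(tvm_out, keras_out, rtol=1e-5, atol=1e-5)
    return tvm_out
```

This would exercise the full path (conversion, type inference, codegen, execution) rather than stopping at relay.transform.InferType()(mod).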
