diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index ce048105ae8b..3a4897ad3166 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -726,12 +726,13 @@ def gru_cell( b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1) b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1) r_gate += b_ir + b_hr + r_gate = rz_act(r_gate) z_gate += b_iz + b_hz i_n += b_in h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn else: + r_gate = rz_act(r_gate) h_n = _op.nn.dense((r_gate * hidden_state), w_hn) - r_gate = rz_act(r_gate) z_gate = rz_act(z_gate) n_gate = n_act(i_n + h_n) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a1d821686ed5..d82148128184 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3311,6 +3311,8 @@ def verify_rnn( use_peep=False, linear_before_reset=False, directions=1, + rtol=1e-5, + atol=1e-5, target=None, dev=None, ): @@ -3433,7 +3435,7 @@ def register(name, shape, proto_type): model = helper.make_model(graph, producer_name="rnn_test") verify_with_ort_with_inputs( - model, input_values, output_shapes, atol=1e-2, rtol=1e-2, target=target, dev=dev + model, input_values, output_shapes, atol=atol, rtol=rtol, target=target, dev=dev ) @@ -3599,10 +3601,12 @@ def test_gru(target, dev): use_bias=False, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, ) - # large batch. + # large batch. linear before reset verify_rnn( seq_length=4, batch_size=8, @@ -3624,6 +3628,8 @@ def test_gru(target, dev): use_bias=True, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, ) @@ -3636,6 +3642,8 @@ def test_gru(target, dev): use_bias=True, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, ) @@ -3648,6 +3656,8 @@ def test_gru(target, dev): use_bias=True, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, ) @@ -3660,6 +3670,8 @@ def test_gru(target, dev): use_bias=True, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, ) @@ -3675,6 +3687,8 @@ def test_gru(target, dev): activations=["HardSigmoid", "Softsign"] * directions, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, ) @@ -3690,6 +3704,8 @@ def test_gru(target, dev): betas=[0.3, 0.0] * directions, rnn_type="GRU", directions=directions, + rtol=1e-8, + atol=1e-8, target=target, dev=dev, ) @@ -3705,6 +3721,8 @@ def test_gru(target, dev): betas=[0.3, 0.1] * directions, rnn_type="GRU", directions=directions, + rtol=1e-8, + atol=1e-8, target=target, dev=dev, ) @@ -3719,6 +3737,8 @@ def test_gru(target, dev): use_initial_state=True, rnn_type="GRU", directions=directions, + rtol=1e-6, + atol=1e-6, target=target, dev=dev, )