From 132d81ff0190b726e6d0e99caf6ea72f3333641f Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Mon, 28 Jun 2021 09:33:06 -0700 Subject: [PATCH] [Onnx] Support Bidirectional RNNs (#8337) * modify lstm to be easily bidirectional * make it obvious some matriciies are packed via prime notation * fix var name * more var names * add op split * keyword arg names * missing implicit cls arg * deal with extra dimensions * last of the fixes * refactor rnn tests to support directions * bidirectional tests * test forward results * go backwards * more fixes * reverse tokens on reverse pass * parameterized directions * double up activations in bidirect * slow attribute forgetting * lstm interface is v. confus * test forward complete * add GRU outline * revisiion2 * why was tehre a not * gru tests * missing bounds, copy pasta! * add comment * ensure all args fp --- python/tvm/relay/frontend/onnx.py | 331 ++++++++----- tests/python/frontend/onnx/test_forward.py | 551 ++++++++++++--------- 2 files changed, 543 insertions(+), 339 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5d07102f2c3f..b38ad332af82 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2058,159 +2058,101 @@ class LSTM(RNN): """Operator converter for LSTM""" @classmethod - def _impl_v7(cls, inputs, attr, params): - # Unpack inputs, note that if optional and not provided then value will be None. - X = inputs[0] - W = inputs[1] - R = inputs[2] - B = inputs[3] - # Sequence length currently unused as it can be inferred from shapes. - # sequence_lens = inputs['sequence_lens'] - h_0 = inputs[5] - c_0 = inputs[6] - P = inputs[7] - - num_directions = infer_shape(W)[0] - W_dtype = infer_type(W).checked_type.dtype - - if num_directions != 1: - raise NotImplementedError("Bidirectional LSTMs not yet supported.") - # Remove num_directions axis from weights. - W = _op.squeeze(W, axis=[0]) - R = _op.squeeze(R, axis=[0]) - if B is not None: - B = _op.squeeze(B, axis=[0]) - - X_shape = infer_shape(X) - hidden_size = infer_shape(R)[-1] - batch_size = X_shape[1] - - # Initialize state if not provided. - # Otherwise remove bidirectional axis. - if h_0 is None: - h_0 = _op.zeros((batch_size, hidden_size), W_dtype) - else: - h_0 = _op.squeeze(h_0, axis=[0]) - if c_0 is None: - c_0 = _op.zeros((batch_size, hidden_size), W_dtype) - else: - c_0 = _op.squeeze(c_0, axis=[0]) + def generate_lstm( + cls, X_steps, H_t, C_t, W, R, B, p_i, p_f, p_o, f_act, g_act, h_act, backwards=False + ): + """Create an unrolled lstm loop. - if P is not None: - P = _op.squeeze(P, axis=[0]) - p_i, p_o, p_f = _op.split(P, 3) - H_t = h_0 - C_t = c_0 + See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. 
+ """ h_list = [] - - if "activations" in attr: - activations = attr["activations"] - if len(activations) != 3: - raise NotImplementedError("LSTM assumes 3 activation functions are provided") - alpha_loc = 0 - alphas = attr.get("activation_alpha", []) - if isinstance(alphas, float): - alphas = [alphas] - beta_loc = 0 - betas = attr.get("activation_beta", []) - if isinstance(betas, float): - betas = [betas] - acts = [] - for i in range(3): - alpha = None - beta = None - activation = activations[i] - if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: - alpha = alphas[alpha_loc] - alpha_loc += 1 - if cls._activation_needs_beta(activation) and len(betas) > beta_loc: - beta = betas[beta_loc] - beta_loc += 1 - acts.append(cls._activation_helper(activation, alpha, beta)) - f_act, g_act, h_act = acts - else: - f_act = _op.sigmoid - g_act = _op.tanh - h_act = _op.tanh - - X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) - for step in X_steps: + seq_length = len(X_steps) + for i in range(seq_length): + step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] step = _op.squeeze(step, axis=[0]) gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R) if B is not None: WB, RB = _op.split(B, 2) gates += WB + RB i, o, f, c = _op.split(gates, 4, axis=-1) - if P is not None: - i = f_act(i + p_i * C_t) - f = f_act(f + p_f * C_t) + if p_i != 0: + i = f_act(i + p_i * C_t) else: i = f_act(i) + + if p_f != 0: + f = f_act(f + p_f * C_t) + else: f = f_act(f) + c = g_act(c) C = f * C_t + i * c - if P is not None: + if p_o != 0: o = f_act(o + p_o * C) else: o = f_act(o) + H = o * h_act(C) + H_t = H C_t = C h_list.append(_op.expand_dims(H, axis=0)) + + if backwards: + # Canonical view is hidden states from the first token not last + h_list = h_list[::-1] + # Concatenate outputs and add back in direction axis. concatenated = _op.concatenate(h_list, 0) output = _op.expand_dims(concatenated, axis=1) H_t = _op.expand_dims(H_t, axis=0) C_t = _op.expand_dims(C_t, axis=0) - return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3) - - -class GRU(RNN): - """Operator convert for GRU""" + return output, H_t, C_t @classmethod def _impl_v7(cls, inputs, attr, params): # Unpack inputs, note that if optional and not provided then value will be None. X = inputs[0] - W = inputs[1] - R = inputs[2] - B = inputs[3] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] # Sequence length currently unused as it can be inferred from shapes. # sequence_lens = inputs['sequence_lens'] - h_0 = inputs[5] - linear_before_reset = attr.get("linear_before_reset", 0) + Hp_0 = inputs[5] + Cp_0 = inputs[6] + Pp = inputs[7] - num_directions = infer_shape(W)[0] - W_dtype = infer_type(W).checked_type.dtype + num_directions = infer_shape(Wp)[0] + W_dtype = infer_type(Wp).checked_type.dtype - if num_directions != 1: - raise NotImplementedError("Bidirectional GRUs not yet supported.") - # Remove num_directions axis from weights. - W = _op.squeeze(W, axis=[0]) - R = _op.squeeze(R, axis=[0]) - if B is not None: - B = _op.squeeze(B, axis=[0]) + if num_directions not in [1, 2]: + raise ValueError("num_directions must be either 1 or 2!") X_shape = infer_shape(X) - hidden_size = infer_shape(R)[-1] + hidden_size = infer_shape(Rp)[-1] batch_size = X_shape[1] # Initialize state if not provided. # Otherwise remove bidirectional axis. 
- if h_0 is None: - h_0 = _op.zeros((batch_size, hidden_size), W_dtype) + if Hp_0 is None: + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + if Cp_0 is None: + Cp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + if Bp is None: + Bp = _op.zeros((num_directions, hidden_size * 8), W_dtype) + if Pp is not None: + p_i, p_o, p_f = _op.split(Pp, 3, axis=1) else: - h_0 = _op.squeeze(h_0, axis=[0]) - - H_t = h_0 - h_list = [] + p_i = p_o = p_f = _op.zeros((num_directions, hidden_size), W_dtype) if "activations" in attr: activations = attr["activations"] - if len(activations) != 2: - raise NotImplementedError("GRU assumes 2 activation functions are provided") + if len(activations) != 3 * num_directions: + raise NotImplementedError( + f"LSTM assumes 3 * num_directions activation functions are provided" + ) alpha_loc = 0 alphas = attr.get("activation_alpha", []) if isinstance(alphas, float): @@ -2220,7 +2162,7 @@ def _impl_v7(cls, inputs, attr, params): if isinstance(betas, float): betas = [betas] acts = [] - for i in range(2): + for i in range(3 * num_directions): alpha = None beta = None activation = activations[i] @@ -2231,13 +2173,75 @@ def _impl_v7(cls, inputs, attr, params): beta = betas[beta_loc] beta_loc += 1 acts.append(cls._activation_helper(activation, alpha, beta)) - f_act, g_act = acts else: - f_act = _op.sigmoid - g_act = _op.tanh + acts = [_op.sigmoid, _op.tanh, _op.tanh] * num_directions X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) - for step in X_steps: + result_output = [] + result_H = [] + result_C = [] + + H_ts = _op.split(Hp_0, num_directions) + C_ts = _op.split(Cp_0, num_directions) + Ws = _op.split(Wp, num_directions) + Rs = _op.split(Rp, num_directions) + Bs = _op.split(Bp, num_directions) + p_is = _op.split(p_i, num_directions) + p_fs = _op.split(p_f, num_directions) + p_os = _op.split(p_o, num_directions) + for i in range(num_directions): + H_t = _op.squeeze(H_ts[i], axis=[0]) + C_t = _op.squeeze(C_ts[i], axis=[0]) + W = _op.squeeze(Ws[i], axis=[0]) + R = _op.squeeze(Rs[i], axis=[0]) + B = _op.squeeze(Bs[i], axis=[0]) + p_i = _op.squeeze(p_is[i], axis=[0]) + p_f = _op.squeeze(p_fs[i], axis=[0]) + p_o = _op.squeeze(p_os[i], axis=[0]) + + f_act, g_act, h_act = acts[i * 3 : (i + 1) * 3] + output, H, C = LSTM.generate_lstm( + X_steps=X_steps, + H_t=H_t, + C_t=C_t, + W=W, + R=R, + B=B, + p_i=p_i, + p_f=p_f, + p_o=p_o, + f_act=f_act, + g_act=g_act, + h_act=h_act, + backwards=i == 1, + ) + + result_output.append(output) + result_H.append(H) + result_C.append(C) + + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) + C = _op.concatenate(result_C, axis=0) + + return _expr.TupleWrapper(_expr.Tuple((output, H, C)), 3) + + +class GRU(RNN): + """Operator convert for GRU""" + + @classmethod + def generate_gru( + cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False + ): + """Create an unrolled gru loop. + + See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. 
+ """ + h_list = [] + seq_length = len(X_steps) + for i in range(seq_length): + step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] step = _op.squeeze(step, axis=[0]) current = _op.nn.dense(step, W) cz, cr, ch = _op.split(current, 3, axis=1) @@ -2266,12 +2270,113 @@ def _impl_v7(cls, inputs, attr, params): H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t) h_list.append(_op.expand_dims(H_t, axis=0)) + + if backwards: + # Canonical view is hidden states from the first token not last + h_list = h_list[::-1] + # Concatenate outputs and add back in direction axis. concatenated = _op.concatenate(h_list, 0) output = _op.expand_dims(concatenated, axis=1) H_t = _op.expand_dims(H_t, axis=0) - return _expr.TupleWrapper(_expr.Tuple((output, H_t)), 2) + return output, H_t + + @classmethod + def _impl_v7(cls, inputs, attr, params): + # Unpack inputs, note that if optional and not provided then value will be None. + X = inputs[0] + Wp = inputs[1] + Rp = inputs[2] + Bp = inputs[3] + # Sequence length currently unused as it can be inferred from shapes. + # sequence_lens = inputs['sequence_lens'] + Hp_0 = inputs[5] + linear_before_reset = attr.get("linear_before_reset", 0) + + num_directions = infer_shape(Wp)[0] + W_dtype = infer_type(Wp).checked_type.dtype + + if num_directions not in [1, 2]: + raise NotImplementedError( + f"Directions for GRUs should be either 1 or 2 got {num_directions}" + ) + + X_shape = infer_shape(X) + hidden_size = infer_shape(Rp)[-1] + batch_size = X_shape[1] + + # Initialize state if not provided. + # Otherwise remove bidirectional axis. + if Hp_0 is None: + Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) + if Bp is None: + Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype) + + if "activations" in attr: + activations = attr["activations"] + if len(activations) != 2 * num_directions: + raise NotImplementedError( + "GRU assumes 2 * num_directions activation functions are provided" + ) + alpha_loc = 0 + alphas = attr.get("activation_alpha", []) + if isinstance(alphas, float): + alphas = [alphas] + beta_loc = 0 + betas = attr.get("activation_beta", []) + if isinstance(betas, float): + betas = [betas] + acts = [] + for i in range(2 * num_directions): + alpha = None + beta = None + activation = activations[i] + if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc: + alpha = alphas[alpha_loc] + alpha_loc += 1 + if cls._activation_needs_beta(activation) and len(betas) > beta_loc: + beta = betas[beta_loc] + beta_loc += 1 + acts.append(cls._activation_helper(activation, alpha, beta)) + else: + acts = [_op.sigmoid, _op.tanh] * 2 + + result_output = [] + result_H = [] + + X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) + H_ts = _op.split(Hp_0, num_directions) + Ws = _op.split(Wp, num_directions) + Rs = _op.split(Rp, num_directions) + Bs = _op.split(Bp, num_directions) + + for i in range(num_directions): + H_t = _op.squeeze(H_ts[i], axis=[0]) + W = _op.squeeze(Ws[i], axis=[0]) + R = _op.squeeze(Rs[i], axis=[0]) + B = _op.squeeze(Bs[i], axis=[0]) + f_act, g_act = acts[i * 2 : (i + 1) * 2] + output, H = GRU.generate_gru( + X_steps=X_steps, + H_t=H_t, + W=W, + R=R, + B=B, + linear_before_reset=linear_before_reset, + f_act=f_act, + g_act=g_act, + W_dtype=W_dtype, + backwards=i == 1, + ) + + result_output.append(output) + result_H.append(H) + + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) + + return _expr.TupleWrapper(_expr.Tuple((output, H)), 2) class 
Resize(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2f92f2d51994..db71855fd80f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import re + import numpy as np import pytest import scipy @@ -3175,92 +3176,112 @@ def verify_rnn( use_initial_state=False, use_peep=False, linear_before_reset=False, + directions=1, ): if rnn_type == "LSTM": multiplier = 4 elif rnn_type == "GRU": multiplier = 3 else: - raise NotImplementedError("%s RNNs not yet supported." % rnn_type) - x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") - w_np = np.random.uniform(size=(1, multiplier * hidden_size, input_size)).astype("float32") - r_np = np.random.uniform(size=(1, multiplier * hidden_size, hidden_size)).astype("float32") - input_names = ["X", "W", "R"] - input_tensors = [ - helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)), - helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)), - helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape)), - ] - input_values = [x_np, w_np, r_np] - - if use_bias: - b_np = np.random.uniform(size=(1, multiplier * 2 * hidden_size)).astype("float32") - input_names.append("B") - input_tensors.append( - helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, multiplier * 2 * hidden_size]) + raise NotImplementedError(f"{rnn_type} RNNs not yet supported.") + + if directions not in [1, 2]: + raise ValueError(f"Direction should be either 1 or 2 (for bidirectional LSTMs)") + + def get_inputs(): + input_names = [] + input_values = [] + input_tensors = [] + + def register(np_arr, name, shape=None): + input_values.append(np_arr) + input_names.append(name) + + # Map of numpy dtypes to the protobuf equivalent + dtype_map = { + "float32": TensorProto.FLOAT, + "int32": TensorProto.INT32, + "int8": TensorProto.INT8, + } + + if np_arr.dtype.name not in dtype_map: + raise ValueError(f"Unknown dtype we don't know how to handle {np.dtype.name}") + if shape is None: + shape = list(np_arr.shape) + proto_type = dtype_map[np_arr.dtype.name] + input_tensors.append(helper.make_tensor_value_info(name, proto_type, shape)) + + x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") + w_np = np.random.uniform(size=(directions, multiplier * hidden_size, input_size)).astype( + "float32" ) - input_values.append(b_np) - - if use_initial_state: - assert use_bias == True, "Initial states must have bias specified." 
- sequence_np = np.repeat(seq_length, batch_size).astype("int32") - input_names.append("sequence_lens") - input_tensors.append( - helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size]) + r_np = np.random.uniform(size=(directions, multiplier * hidden_size, hidden_size)).astype( + "float32" ) - input_values.append(sequence_np) + register(x_np, "X") + register(w_np, "W") + register(r_np, "R") - initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32") - input_names.append("initial_h") - input_tensors.append( - helper.make_tensor_value_info( - "initial_h", TensorProto.FLOAT, [1, batch_size, hidden_size] + if use_bias: + b_np = np.random.uniform(size=(directions, multiplier * 2 * hidden_size)).astype( + "float32" ) - ) - input_values.append(initial_h_np) + register(b_np, "B") - if rnn_type == "LSTM": - initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32") - input_names.append("initial_c") - input_tensors.append( - helper.make_tensor_value_info( - "initial_c", TensorProto.FLOAT, [1, batch_size, hidden_size] - ) + if use_initial_state: + assert use_bias == True, "Initial states must have bias specified." + sequence_np = np.repeat(seq_length, batch_size).astype("int32") + register(sequence_np, "sequence_lens") + + initial_h_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( + "float32" ) - input_values.append(initial_c_np) - - if use_peep and rnn_type == "LSTM": - assert use_initial_state == True, "Peepholes require initial state to be specified." - p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype("float32") - input_names.append("P") - input_tensors.append( - helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size]) - ) - input_values.append(p_np) - - Y_shape = [seq_length, 1, batch_size, hidden_size] - Y_h_shape = [1, batch_size, hidden_size] - outputs = ["Y", "Y_h"] - graph_outputs = [ - helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(Y_shape)), - helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, list(Y_h_shape)), - ] - output_shapes = [Y_shape, Y_h_shape] + register(initial_h_np, "initial_h") - if rnn_type == "LSTM": - Y_c_shape = [1, batch_size, hidden_size] - outputs.append("Y_c") - graph_outputs.append( - helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, list(Y_c_shape)) - ) - output_shapes.append(Y_c_shape) + if rnn_type == "LSTM": + initial_c_np = np.random.uniform(size=(directions, batch_size, hidden_size)).astype( + "float32" + ) + register(initial_c_np, "initial_c") + + if use_peep and rnn_type == "LSTM": + assert use_initial_state == True, "Peepholes require initial state to be specified." 
+ p_np = np.random.uniform(size=(directions, 3 * hidden_size)).astype("float32") + register(p_np, "P") + + return input_names, input_tensors, input_values + + input_names, input_tensors, input_values = get_inputs() + + def get_outputs(): + output_names = [] + graph_outputs = [] + output_shapes = [] + + def register(name, shape, proto_type): + output_names.append(name) + graph_outputs.append(helper.make_tensor_value_info(name, proto_type, list(shape))) + output_shapes.append(list(shape)) + + register("Y", [seq_length, directions, batch_size, hidden_size], TensorProto.FLOAT) + register("Y_h", [directions, batch_size, hidden_size], TensorProto.FLOAT) + + if rnn_type == "LSTM": + register("Y_c", [directions, batch_size, hidden_size], TensorProto.FLOAT) + + return output_names, graph_outputs, output_shapes + + output_names, graph_outputs, output_shapes = get_outputs() rnn_node = helper.make_node( - rnn_type, inputs=input_names, outputs=outputs, hidden_size=hidden_size + rnn_type, inputs=input_names, outputs=output_names, hidden_size=hidden_size ) if activations is not None: activations_attr = helper.make_attribute("activations", activations) rnn_node.attribute.append(activations_attr) + if directions == 2: + direction_attr = helper.make_attribute("direction", "bidirectional") + rnn_node.attribute.append(direction_attr) if alphas is not None: alphas_attr = helper.make_attribute("activation_alpha", alphas) rnn_node.attribute.append(alphas_attr) @@ -3280,169 +3301,247 @@ def verify_rnn( @tvm.testing.uses_gpu def test_lstm(): - # No bias. - verify_rnn( - seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, rnn_type="LSTM" - ) - # large batch. - verify_rnn( - seq_length=4, batch_size=8, input_size=16, hidden_size=32, use_bias=True, rnn_type="LSTM" - ) - # Non power of two. - verify_rnn( - seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True, rnn_type="LSTM" - ) - # Long sequence. - verify_rnn( - seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True, rnn_type="LSTM" - ) - # Large hidden. - verify_rnn( - seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True, rnn_type="LSTM" - ) - # Large input. - verify_rnn( - seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True, rnn_type="LSTM" - ) - - # Different activation testing. - # Default value hardsigmoid. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Tanh", "Tanh"], - rnn_type="LSTM", - ) - # Multiple parameterized activations. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "LeakyRelu", "Tanh"], - alphas=[2.0, 0.5], - betas=[0.3], - rnn_type="LSTM", - ) - # All parameterized with new Affine activation. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "LeakyRelu", "Affine"], - alphas=[2.0, 0.5, 0.8], - betas=[0.3, 0.1], - rnn_type="LSTM", - ) - - # Testing with initial state and peepholes - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - rnn_type="LSTM", - ) - - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - use_peep=True, - rnn_type="LSTM", - ) + for directions in [1, 2]: + # No bias. 
+ verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + rnn_type="LSTM", + directions=directions, + ) + # large batch. + verify_rnn( + seq_length=4, + batch_size=8, + input_size=16, + hidden_size=32, + use_bias=True, + rnn_type="LSTM", + directions=directions, + ) + # Non power of two. + verify_rnn( + seq_length=3, + batch_size=3, + input_size=16, + hidden_size=40, + use_bias=True, + rnn_type="LSTM", + directions=directions, + ) + # Long sequence. + verify_rnn( + seq_length=8, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + rnn_type="LSTM", + directions=directions, + ) + # Large hidden. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=128, + use_bias=True, + rnn_type="LSTM", + directions=directions, + ) + # Large input. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=64, + hidden_size=32, + use_bias=True, + rnn_type="LSTM", + directions=directions, + ) + + # Different activation testing. + # Default value hardsigmoid. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "Tanh", "Tanh"] * directions, + rnn_type="LSTM", + directions=directions, + ) + # Multiple parameterized activations. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "LeakyRelu", "Tanh"] * directions, + alphas=[2.0, 0.5, 0.0] * directions, + betas=[0.3, 0.0, 0.0] * directions, + rnn_type="LSTM", + directions=directions, + ) + # All parameterized with new Affine activation. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "LeakyRelu", "Affine"] * directions, + alphas=[2.0, 0.5, 0.8] * directions, + betas=[0.3, 0.1, 0.0] * directions, + rnn_type="LSTM", + directions=directions, + ) + + # Testing with initial state and peepholes + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + use_initial_state=True, + rnn_type="LSTM", + directions=directions, + ) + + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + use_initial_state=True, + use_peep=True, + rnn_type="LSTM", + directions=directions, + ) @tvm.testing.uses_gpu def test_gru(): - # No bias. - verify_rnn( - seq_length=2, batch_size=1, input_size=16, hidden_size=32, use_bias=False, rnn_type="GRU" - ) - # large batch. - verify_rnn( - seq_length=4, - batch_size=8, - input_size=16, - hidden_size=32, - use_bias=True, - rnn_type="GRU", - linear_before_reset=True, - ) - # Non power of two. - verify_rnn( - seq_length=3, batch_size=3, input_size=16, hidden_size=40, use_bias=True, rnn_type="GRU" - ) - # Long sequence. - verify_rnn( - seq_length=8, batch_size=1, input_size=16, hidden_size=32, use_bias=True, rnn_type="GRU" - ) - # Large hidden. - verify_rnn( - seq_length=2, batch_size=1, input_size=16, hidden_size=128, use_bias=True, rnn_type="GRU" - ) - # Large input. - verify_rnn( - seq_length=2, batch_size=1, input_size=64, hidden_size=32, use_bias=True, rnn_type="GRU" - ) - - # Different activation testing. - # Default value hardsigmoid. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Softsign"], - rnn_type="GRU", - ) - # Multiple parameterized activations. 
- verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "LeakyRelu"], - alphas=[2.0, 0.5], - betas=[0.3], - rnn_type="GRU", - ) - # All parameterized with new Affine activation. - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=False, - activations=["HardSigmoid", "Affine"], - alphas=[2.0, 0.8], - betas=[0.3, 0.1], - rnn_type="GRU", - ) - - # Testing with initial state - verify_rnn( - seq_length=2, - batch_size=1, - input_size=16, - hidden_size=32, - use_bias=True, - use_initial_state=True, - rnn_type="GRU", - ) + for directions in [1, 2]: + # No bias. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + rnn_type="GRU", + directions=directions, + ) + # large batch. + verify_rnn( + seq_length=4, + batch_size=8, + input_size=16, + hidden_size=32, + use_bias=True, + rnn_type="GRU", + linear_before_reset=True, + directions=directions, + ) + # Non power of two. + verify_rnn( + seq_length=3, + batch_size=3, + input_size=16, + hidden_size=40, + use_bias=True, + rnn_type="GRU", + directions=directions, + ) + # Long sequence. + verify_rnn( + seq_length=8, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + rnn_type="GRU", + directions=directions, + ) + # Large hidden. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=128, + use_bias=True, + rnn_type="GRU", + directions=directions, + ) + # Large input. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=64, + hidden_size=32, + use_bias=True, + rnn_type="GRU", + directions=directions, + ) + + # Different activation testing. + # Default value hardsigmoid. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "Softsign"] * directions, + rnn_type="GRU", + directions=directions, + ) + # Multiple parameterized activations. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "LeakyRelu"] * directions, + alphas=[2.0, 0.5] * directions, + betas=[0.3, 0.0] * directions, + rnn_type="GRU", + directions=directions, + ) + # All parameterized with new Affine activation. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "Affine"] * directions, + alphas=[2.0, 0.8] * directions, + betas=[0.3, 0.1] * directions, + rnn_type="GRU", + directions=directions, + ) + + # Testing with initial state + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=True, + use_initial_state=True, + rnn_type="GRU", + directions=directions, + ) @tvm.testing.uses_gpu
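For reference, the sketch below is illustrative only and is not part of the patch or its test suite: it hand-builds a small bidirectional ONNX LSTM and pushes it through relay.frontend.from_onnx much like verify_rnn does. The tensor shapes, graph names, and the choice of the "vm" executor are assumptions. The leading comments restate the ONNX LSTM equations that generate_lstm unrolls (default activations f = sigmoid, g = tanh, h = tanh), per the Operators.md reference in the docstrings.

# --- Usage sketch (illustrative; not part of this patch) --------------------
# ONNX LSTM math unrolled by generate_lstm, gates packed in i, o, f, c order:
#   i_t = f(X_t W_i^T + H_{t-1} R_i^T + P_i . C_{t-1} + Wb_i + Rb_i)
#   f_t = f(X_t W_f^T + H_{t-1} R_f^T + P_f . C_{t-1} + Wb_f + Rb_f)
#   c_t = g(X_t W_c^T + H_{t-1} R_c^T + Wb_c + Rb_c)
#   C_t = f_t . C_{t-1} + i_t . c_t
#   o_t = f(X_t W_o^T + H_{t-1} R_o^T + P_o . C_t + Wb_o + Rb_o)
#   H_t = o_t . h(C_t)
import numpy as np
from onnx import TensorProto, helper

import tvm
from tvm import relay

# Assumed toy dimensions; num_directions = 2 selects the bidirectional path.
seq_length, batch_size, input_size, hidden_size = 4, 1, 8, 16
directions = 2

x_shape = [seq_length, batch_size, input_size]
w_shape = [directions, 4 * hidden_size, input_size]   # W packed along axis 0
r_shape = [directions, 4 * hidden_size, hidden_size]  # R packed along axis 0

# B, sequence_lens, initial_h, initial_c, and P are optional and omitted here.
lstm_node = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R"],
    outputs=["Y", "Y_h", "Y_c"],
    hidden_size=hidden_size,
    direction="bidirectional",
)
graph = helper.make_graph(
    [lstm_node],
    "bidirectional_lstm_example",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, w_shape),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, r_shape),
    ],
    outputs=[
        helper.make_tensor_value_info(
            "Y", TensorProto.FLOAT, [seq_length, directions, batch_size, hidden_size]
        ),
        helper.make_tensor_value_info(
            "Y_h", TensorProto.FLOAT, [directions, batch_size, hidden_size]
        ),
        helper.make_tensor_value_info(
            "Y_c", TensorProto.FLOAT, [directions, batch_size, hidden_size]
        ),
    ],
)
model = helper.make_model(graph, producer_name="bidirectional_lstm_example")

# W and R are plain graph inputs (not initializers), so they stay free vars.
mod, params = relay.frontend.from_onnx(
    model, {"X": x_shape, "W": w_shape, "R": r_shape}
)

x_np = np.random.uniform(size=x_shape).astype("float32")
w_np = np.random.uniform(size=w_shape).astype("float32")
r_np = np.random.uniform(size=r_shape).astype("float32")
results = relay.create_executor(
    "vm", mod=mod, device=tvm.cpu(0), target="llvm"
).evaluate()(x_np, w_np, r_np, **params)

# Y: (seq_length, 2, batch, hidden); Y_h, Y_c: (2, batch, hidden)
print(results[0].shape, results[1].shape, results[2].shape)

Y comes back as (seq_length, num_directions, batch_size, hidden_size), and the reverse direction's slice of Y is already re-ordered so that index 0 corresponds to the first input token, matching the h_list[::-1] flip in generate_lstm and generate_gru.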