diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 7a1e98402996..5f961f1ae0e8 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -686,6 +686,46 @@ def unbind(data, axis=0):
     return _expr.TupleWrapper(_expr.Tuple(ret), selections)
 
 
+def rnn_cell(
+    input_seqs, hidden_state, w_inp, w_hid, b_inp=None, b_hid=None, backwards=False, act=_op.tanh
+):
+    """
+    Common implementation of the RNN cell for all TVM frontends
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensors should be 2D until issue #8412 is resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch_size, hidden_size)
+    w_inp, w_hid : relay.Expr
+        Weight matrices. shape = (hidden_size, feature_size), (hidden_size, hidden_size)
+    b_inp, b_hid : relay.Expr
+        Bias vectors. shape = (hidden_size,)
+    backwards : bool
+        Flag for the reverse pass of the RNN
+    act : relay.op
+        Activation function. It is tanh by default.
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr
+        The sequence of computed outputs and the final hidden state
+    """
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        xwt = _op.nn.dense(x_t, w_inp)
+        hwt = _op.nn.dense(hidden_state, w_hid)
+        if b_inp is not None and b_hid is not None:
+            xwt += b_inp
+            hwt += b_hid
+        hidden_state = act(xwt + hwt)
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+    return outputs_list, hidden_state
+
+
 def gru_cell(
     input_seqs,
     hidden_state,
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b1a760886037..d7e1a5dd1ddb 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -40,7 +40,7 @@
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op, gru_cell, logger
+from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
@@ -2630,6 +2630,191 @@ def flip(self, inputs, input_types):
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
+    def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
+        """
+        Bidirectional RNN cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = rnn_cell(input_seqs, **weights_dicts[0], backwards=False, act=act)
+
+        reverse_outputs, rev_H_t = rnn_cell(input_seqs, **weights_dicts[1], backwards=True, act=act)
+
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
+            )
+
+        return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
+
+    def rnn_layers(self, input_data, layer_weights_dicts, bidirectional, act, dropout_p=0.0):
+        """
+        Method iterates over the layers of a stacked RNN
+        """
+        layers_num = len(layer_weights_dicts)
+        # split the input sequence into per-step samples
+        input_seqs = unbind(input_data, 0)  # [seq_num, (batch, feature_size)]
+        output_hiddens = []
+        for i in range(layers_num):
+            weights_dicts = layer_weights_dicts[i]
+            # input_seqs shape = [seq_num, (batch, feature_size)] or
+            # [seq_num, (batch, 2*hidden_size)] for bidirectional
+            if bidirectional:
+                input_seqs, H_t = self.bidir_rnn_cell(input_seqs, weights_dicts, act=act)
+            else:
+                input_seqs, H_t = rnn_cell(input_seqs, **weights_dicts[0], act=act)
+
+            output_hiddens.append(H_t)
+
+            # TODO (yuanfz98): in pytorch implementation train is also checked
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
+            if dropout_p != 0 and i < layers_num - 1:
+                # for input in input_seqs:
+                #     input = _op.dropout(input, dropout_p)
+                raise NotImplementedError("Dropout for RNN has not been supported yet!")
+        output_hiddens = (
+            _op.concatenate(output_hiddens, 0) if bidirectional else _op.stack(output_hiddens, 0)
+        )
+        return _op.stack(input_seqs, 0), output_hiddens
+
+    def rnn(self, inputs, input_types, nonlinearity):
+        """
+        Description of RNN in pytorch:
+        https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN
+        Description of inputs:
+        https://github.com/pytorch/pytorch/blob/736fb7d22cc948b739db2c35aeb5ad4d19aea4f4/torch/overrides.py#L937
+        """
+        # TODO (yuanfz98): support dropout
+        assert len(inputs) == 9, "Input of size 9 is expected"
+        # Unpack inputs; note that an optional input that was not provided will be None.
+        _X = inputs[0]
+        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+        hidden_state = inputs[1]
+        # Hidden state shape (hidden_layers_num, batch, hidden_size)
+
+        _weights = inputs[2]
+        # Wi layer[0] shape (hidden_size, feature_size)
+        # Wh layer[0] shape (hidden_size, hidden_size)
+        # Bi layer[0] shape (hidden_size)
+        # Bh layer[0] shape (hidden_size)
+
+        # Wi layer[>0] shape (hidden_size, hidden_size * num_directions)
+        # Wh layer[>0] shape (hidden_size, hidden_size)
+        # Bi layer[>0] shape (hidden_size)
+        # Bh layer[>0] shape (hidden_size)
+
+        # Scalar inputs
+        has_biases = inputs[3]
+        num_layers = inputs[4]
+        dropout_p = inputs[5]  # dropout probability, if 0.0 it means there is no dropout
+        # train = inputs[6]
+        bidirectional = inputs[7]
+        batch_first = inputs[8]
+
+        num_directions = 1
+        if bidirectional:
+            num_directions = 2
+
+        rsd = len(_weights) % num_layers
+        assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
+        rsd = (len(_weights) / num_layers) % num_directions
+        assert (
+            rsd == 0
+        ), "The number of weights per layer must be a multiple of the number of directions!"
+
+        weights_num = int(len(_weights) / num_layers / num_directions)
+        if has_biases:
+            assert weights_num == 4, "The number of weights per layer is expected to be 4"
+        else:
+            assert weights_num == 2, "The number of weights per layer is expected to be 2"
+        if nonlinearity == "tanh":
+            act = _op.tanh
+        else:
+            assert nonlinearity == "relu", "The nonlinearity is unknown"
+            act = _op.nn.relu
+        X = (
+            _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+        )  # always (seq_num, batch, feature_size)
+        # TODO (yuanfz98): Which data type should be used? from input or weights?
+        # Alternatively, _infer_type(X).checked_type.dtype could be used
+        X_dtype = input_types[0]
+        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
+
+        hidden_size = int(_infer_shape(_weights[0])[0])
+        batch_size = X_shape[1]
+
+        # Initialize hidden states if not provided.
+        layers_h = []
+        hidden_layers_num = num_directions * num_layers
+        if hidden_state is None:
+            h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+            for i in range(hidden_layers_num):
+                layers_h.append(h_0)
+        else:
+            layers_h = unbind(hidden_state, 0)
+
+        layer_weights_dicts = []
+        k = 0  # layer counter
+        if has_biases:
+            names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        else:
+            names = ["hidden_state", "w_inp", "w_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        assert (
+            len(layer_weights_dicts) == num_layers and k == num_layers
+        ), "For a stacked RNN, the number of weight sets should equal the number of layers!"
+        output, out_hidden_state = self.rnn_layers(
+            X,
+            layer_weights_dicts,
+            bidirectional,
+            act,
+            dropout_p=dropout_p,
+        )
+
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2*hidden_size) for bidirectional
+        if batch_first:
+            output = _op.transpose(output, (1, 0, 2))
+
+        return (output, out_hidden_state)
+
     def bidir_gru_cell(
         self,
         input_seqs,
@@ -3442,6 +3627,8 @@ def create_convert_map(self):
             "aten::l1_loss": self.l1_loss,
             "aten::mse_loss": self.mse_loss,
             "aten::flip": self.flip,
+            "aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
+            "aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
             "aten::gru": self.gru,
             "aten::lstm": self.lstm,
             "aten::all": functools.partial(self.all_any_common, _op.all),
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
index b0180a7a99d4..fba55b9c4c8f 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -40,6 +40,10 @@
 seqs_length = 2
 batch_size = 2
 
+##RNN parameters
+rnn_feature_size = 8
+rnn_hidden_size = 16
+
 
 class RNN_Model(nn.Module):
     """
@@ -93,6 +97,72 @@ def get_tvm_inputs(self, dtype):
         raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")
 
 
+class RNN_Model_Impl(RNN_Model):
+    def __init__(
+        self,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=rnn_feature_size,
+        hidden_size=rnn_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        use_bias=True,
+        rnd_weights_init=False,
+        nonlinearity="tanh",
+        dropout=0.0,
+    ):
+        super().__init__()
+        # Shapes
+        self.shape = [seq_len, batch_size, feature_size]
+        if batch_first:
+            self.shape = [batch_size, seq_len, feature_size]
+        layers_num = 2 * layer_num if bidirectional else layer_num
+        self.h0_shape = [layers_num, batch_size, hidden_size]
+        # Dummy inputs
+        self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))
+
+        self.model = nn.RNN(
+            input_size=feature_size,
+            hidden_size=hidden_size,
+            num_layers=layer_num,
+            nonlinearity=nonlinearity,
+            bias=use_bias,
+            batch_first=batch_first,
+            dropout=dropout,
+            bidirectional=bidirectional,
+        )
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def gen_rnd_weights(self):
+        super().gen_rnd_weights()
+
+    def get_dummy_inputs(self):
+        return self.dummy_inputs
+
+    def get_input_names(self):
+        return ["input", "h0"]
+
+    def get_shape_desc(self, frontend_type):
+        shape_desc = None
+        if frontend_type == "pt":  # PyTorch
+            shape_desc = [("input", self.shape)]
+        elif frontend_type == "onnx":  # ONNX
+            shape_desc = {
+                "input": self.shape,
+                "h0": self.h0_shape,
+            }
+        return shape_desc
+
+    def get_tvm_inputs(self, dtype):
+        return {
+            "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+            "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
+        }
+
+
 class GRU_Model(RNN_Model):
     def __init__(
         self,
@@ -331,6 +401,10 @@ def get_model(
         args["bidirectional"] = True
     if "s" in rnn_mod:
         args["layer_num"] = num_layers
+    if "tanh" in rnn_mod:
+        args["nonlinearity"] = "tanh"
+    if "relu" in rnn_mod:
+        args["nonlinearity"] = "relu"
 
     if rnn_type == "GRU":
         RNN_Model_selector = GRU_Model
@@ -338,6 +412,8 @@ def get_model(
         RNN_Model_selector = LSTM_Model
         if "p" in rnn_mod:
            args["proj_size"] = lstm_projection_size
+    elif rnn_type == "RNN":
+        RNN_Model_selector = RNN_Model_Impl
 
     return RNN_Model_selector(**args)
 
@@ -425,6 +501,9 @@ def test_rnns():
         for mod_type in ["uni", "s", "b", "sb"]:
             check_rnn("LSTM", mod_type, target, dev)
 
+        for mod_type in ["uni", "s", "b", "sb", "tanh", "relu"]:
+            check_rnn("RNN", mod_type, target, dev)
+
 
 if __name__ == "__main__":
     test_rnns()