Skip to content

Commit

Permalink
Add unidirectional sequence lstm (apache#11183)
Browse files Browse the repository at this point in the history
* UnidirectionalLSTM added

* fixed missing import

* fixed pylint warnings

* black formatted tflite.py

* corrections according to reviewer comments

* fixed black formatting

* just to trigger the CI again

* assertion now tests that there are exactly 24 input tensors.

* black formatted tflite.py

* added explanatory comment regarding unused imports

* removed unused import

* nothing

* nothing

* added some details in a comment about the differences in unbind regarding to the version in common.py

* improved comment on unbind

* fix of black issue
  • Loading branch information
Sebastian Boblest authored and driazati committed May 27, 2022
1 parent ad484c5 commit 40bde2f
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 5 deletions.
180 changes: 179 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ..backend.name_transforms import sanitize_name
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import to_int_list, shape_of
from .common import lstm_cell, to_int_list, shape_of
from .tflite_flexbuffer import FlexBufferDecoder

__all__ = ["from_tflite"]
Expand Down Expand Up @@ -173,6 +173,7 @@ def __init__(self, model, subgraph, exp_tab):
"TRANSPOSE_CONV": self.convert_transpose_conv,
"TRANSPOSE": self.convert_transpose,
"UNPACK": self.convert_unpack,
"UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
"WHERE": self.convert_select,
"ZEROS_LIKE": self.convert_zeros_like,
}
Expand Down Expand Up @@ -220,6 +221,41 @@ def check_unsupported_ops(self):
if len(raise_msg) > 0:
raise tvm.error.OpNotImplemented(raise_msg)

def unbind(self, data, axis=1):
"""
This is a modified version compared to the one in common.py.
The onnx version takes a relay.Expr.Call, the tflite
version a TensorWrapper. Also this version by default splits
along axis 1 and not axis 0 as the onnx version.
Parameters
----------
data : tvm.relay.frontend.tflite.TensorWrapper
Input tensor
axis : int
Axis along which tensor is split.
Returns
-------
result : List[relay.Expr]
The sequence of computed tensors
"""
shape = to_int_list(self.get_tensor_shape(data))
if axis >= len(shape):
msg = "Please check input dim, it shouldn't be greater than or equal to rank."
raise AttributeError(msg)

selections = shape[axis]
shape.pop(axis)
timestep = 0 # Reshape to make time step as the first dim
shape.insert(timestep, selections)
res_split = _op.split(
_op.reshape(self.get_expr(data.tensor_idx), tuple(shape)), selections, timestep
)
ret = []
for i in range(selections):
ret.append(_op.squeeze(res_split[i], axis=[timestep]))
return _expr.TupleWrapper(_expr.Tuple(ret), selections)

def convert_op_to_relay(self):
"""Convert TFLite ops to relay ops"""
for op_idx in range(self.subgraph.OperatorsLength()):
Expand Down Expand Up @@ -2715,6 +2751,148 @@ def convert_unpack(self, op):

return squeezed

def convert_unidirectional_sequence_lstm(self, op):
"""Long Short Term Memory for TFLite implementation."""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not supported yet."
)

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 24, "input tensors length should be == 24"

# Extract input tensor from saved model
input_tensor = input_tensors[0]

# Extract tensors from input tensors from saved model
# Input weights
input_input_weights = input_tensors[1]
input_forget_weights = input_tensors[2]
input_cell_weights = input_tensors[3]
input_output_weights = input_tensors[4]
# Recurrent weights
recurrent_input_weights = input_tensors[5]
recurrent_forget_weights = input_tensors[6]
recurrent_cell_weights = input_tensors[7]
recurrent_output_weights = input_tensors[8]
# inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied
# there locations are -1 in the flatbuffer
# Bias weights
input_gate_bias = input_tensors[12]
forget_gate_bias = input_tensors[13]
cell_gate_bias = input_tensors[14]
output_gate_bias = input_tensors[15]

# State input
output_state_in = input_tensors[18]
cell_state_in = input_tensors[19]

# Extract output tensor from saved model
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
X_steps = self.unbind(input_tensor, axis=1)
weights_dict = {}

# hidden_state_weights is equivalent to output_state_in in tflite model
out_state_in_shape = tuple(self.get_tensor_shape(output_state_in))
out_state_in_dtype = self.get_tensor_type_str(output_state_in.tensor.Type())
out_state_in_expr = _op.zeros(out_state_in_shape, dtype=out_state_in_dtype)
weights_dict["hidden_state"] = _op.split(out_state_in_expr, 1)[0]

# cell_state_weights is equivalent to output_state_in tflite model
cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in))
cell_state_in_dtype = self.get_tensor_type_str(cell_state_in.tensor.Type())
cell_state_in_expr = _op.zeros(cell_state_in_shape, dtype=cell_state_in_dtype)
weights_dict["cell_state"] = _op.split(cell_state_in_expr, 1)[0]

# Process weight matrix of input: w_inp
# Concatenate of [input_input_weight, input_forget_weights,
# input_cell_weights, input_output_weights]
input_input_weights_default_values = self.get_tensor_value(input_input_weights)
input_input_weights_op = _op.split(
_op.const(input_input_weights_default_values.tolist()), 1
)
input_output_weights_default_values = self.get_tensor_value(input_output_weights)
input_output_weights_op = _op.split(
_op.const(input_output_weights_default_values.tolist()), 1
)
input_forget_weights_default_values = self.get_tensor_value(input_forget_weights)
input_forget_weights_op = _op.split(
_op.const(input_forget_weights_default_values.tolist()), 1
)
input_cell_weights_default_values = self.get_tensor_value(input_cell_weights)
input_cell_weights_op = _op.split(_op.const(input_cell_weights_default_values.tolist()), 1)
weights_dict["w_inp"] = _op.concatenate(
[
_op.squeeze(input_input_weights_op[0]),
_op.squeeze(input_forget_weights_op[0]),
_op.squeeze(input_cell_weights_op[0]),
_op.squeeze(input_output_weights_op[0]),
],
axis=0,
)

# Process weight matrix of hidden state:
# w_hid to support lstm_cell function. Not used in tflite
recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights)
recurrent_input_weights_op = _op.split(
_op.const(recurrent_input_weights_values.tolist()), 1
)
recurrent_output_weights_values = self.get_tensor_value(recurrent_output_weights)
recurrent_output_weights_op = _op.split(
_op.const(recurrent_output_weights_values.tolist()), 1
)
recurrent_forget_weights_values = self.get_tensor_value(recurrent_forget_weights)
recurrent_forget_weights_op = _op.split(
_op.const(recurrent_forget_weights_values.tolist()), 1
)
recurrent_cell_weights_values = self.get_tensor_value(recurrent_cell_weights)
recurrent_cell_weights_op = _op.split(_op.const(recurrent_cell_weights_values.tolist()), 1)
weights_dict["w_hid"] = _op.concatenate(
[
recurrent_input_weights_op[0],
recurrent_forget_weights_op[0],
recurrent_cell_weights_op[0],
recurrent_output_weights_op[0],
],
axis=0,
)

# Process weight matrix of bias: b_inp
input_gate_bias_values = self.get_tensor_value(input_gate_bias)
input_gate_bias_op = _op.split(_op.const(input_gate_bias_values.tolist()), 1)
output_gate_bias_values = self.get_tensor_value(output_gate_bias)
output_gate_bias_op = _op.split(_op.const(output_gate_bias_values.tolist()), 1)
forget_gate_bias_values = self.get_tensor_value(forget_gate_bias)
forget_gate_bias_op = _op.split(_op.const(forget_gate_bias_values.tolist()), 1)
cell_gate_bias_values = self.get_tensor_value(cell_gate_bias)
cell_gate_bias_op = _op.split(_op.const(cell_gate_bias_values.tolist()), 1)
weights_dict["b_inp"] = _op.concatenate(
[
input_gate_bias_op[0],
forget_gate_bias_op[0],
cell_gate_bias_op[0],
output_gate_bias_op[0],
],
axis=0,
)

# Process weight matrix of hidden bias:
# b_hid (with the same shape as b_inp)
gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type())
weights_dict["b_hid"] = _op.split(
_op.const(
np.zeros(_infer_shape(weights_dict["b_inp"]), dtype=gate_bias_dtype),
dtype=gate_bias_dtype,
),
1,
)[0]

outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict)

output = _op.stack(outputs, axis=1)
return output

def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation."""

Expand Down
38 changes: 34 additions & 4 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,7 +1867,7 @@ def tf_function(self, x):
model,
export_dir,
signatures=model.tf_function.get_concrete_function(
tf.TensorSpec(data.shape, tf.float32, name="input"),
tf.TensorSpec(data.shape, tf.float32, name="input")
),
)

Expand Down Expand Up @@ -3759,8 +3759,7 @@ def test_forward_prelu():
np.full((32, 3), 0.2, dtype="float32"),
)
_test_prelu(
np.random.uniform(-5, 5, size=(32, 3)).astype("float32"),
np.full((3), 0.2, dtype="float32"),
np.random.uniform(-5, 5, size=(32, 3)).astype("float32"), np.full((3), 0.2, dtype="float32")
)


Expand Down Expand Up @@ -4693,6 +4692,36 @@ def representative_dataset():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


#######################################################################
# Unidirectional Sequence LSTM
# ---------------------
def test_forward_unidirectional_sequence_lstm():
"""Test the UnidirectionalSequenceLSTM TFLite"""
if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
tflite_model_file = download_testdata(
"https://github.com/SebastianBoblestETAS/nn_models/blob/ce49c5de64889493161ca4194a20e0fd5eb707e6/lstm_1_in_3_out_2_ts_4.tflite?raw=true",
"lstm_1_in_3_out_2_ts_4.tflite",
)
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()

data = np.array(
[
[
[0.5488135, 0.71518934, 0.60276335],
[0.5448832, 0.4236548, 0.6458941],
[0.4375872, 0.891773, 0.96366274],
[0.3834415, 0.79172504, 0.5288949],
]
],
dtype="float32",
)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input_1:0")
tvm.testing.assert_allclose(tflite_output, tvm_output)


#######################################################################
# Quantized SSD Mobilenet
# -----------------------
Expand Down Expand Up @@ -4930,10 +4959,11 @@ def test_prevent_tensorflow_dynamic_range():
test_forward_leaky_relu()
test_forward_relu_n1_to_1()
test_forward_log_softmax()
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()
test_forward_local_response_normalization()
test_forward_prelu()
test_forward_unidirectional_sequence_lstm()

# Elemwise
test_all_elemwise()
Expand Down

0 comments on commit 40bde2f

Please sign in to comment.