From 8ecf2905aed6e887f0cebdfd2b3dcce1cb140b44 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Mon, 24 May 2021 21:39:17 -0700 Subject: [PATCH] [Frontend] [Tensorflow2] Added test infrastructure for TF2 frozen models (#8074) * added test infrastructure for frozen TF2 models * linting with black * removing some comments * change in comment in sequential test * addressed the comments * refactored to place vmobj_to_list in a common file * Added helper function in python/tvm/relay/testing/tf.py Co-authored-by: David Huang Co-authored-by: Rohan Mukherjee Co-authored-by: Xiao * Refactor tf according to CI error Co-authored-by: David Huang Co-authored-by: Rohan Mukherjee Co-authored-by: Xiao * Added docstring Co-authored-by: David Huang Co-authored-by: Rohan Mukherjee Co-authored-by: Xiao * removing print Co-authored-by: David Huang Co-authored-by: Xiao --- python/tvm/relay/testing/tf.py | 42 ++ tests/python/frontend/tensorflow2/common.py | 105 +++++ .../tensorflow2/test_functional_models.py | 361 ++++++++++++++++++ .../tensorflow2/test_sequential_models.py | 113 ++++++ 4 files changed, 621 insertions(+) create mode 100644 tests/python/frontend/tensorflow2/common.py create mode 100644 tests/python/frontend/tensorflow2/test_functional_models.py create mode 100644 tests/python/frontend/tensorflow2/test_sequential_models.py diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index d20c0e0ab9dd..9fb3f1102137 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -28,6 +28,8 @@ # Tensorflow imports import tensorflow as tf from tensorflow.core.framework import graph_pb2 + +import tvm from tvm.contrib.download import download_testdata try: @@ -73,6 +75,46 @@ def convert_to_list(x): return x +def vmobj_to_list(o): + """Converts TVM objects returned by VM execution to Python List. + + Parameters + ---------- + o : Obj + VM Object as output from VM runtime executor. + + Returns + ------- + result : list + Numpy objects as list with equivalent values to the input object. + + """ + + if isinstance(o, tvm.nd.NDArray): + result = [o.asnumpy()] + elif isinstance(o, tvm.runtime.container.ADT): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == "Cons": + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + result = hd + elif o.constructor.name_hint == "Nil": + result = [] + elif "tensor_nil" in o.constructor.name_hint: + result = [0] + elif "tensor" in o.constructor.name_hint: + result = [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + return result + + def AddShapesToGraphDef(session, out_node): """Add shapes attribute to nodes of the graph. Input graph here is the default graph in context. diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py new file mode 100644 index 000000000000..e30ee7b0c993 --- /dev/null +++ b/tests/python/frontend/tensorflow2/common.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=import-outside-toplevel, redefined-builtin +"""TF2 to relay converter test utilities""" + +import tvm +from tvm import relay + +from tvm.runtime.vm import VirtualMachine +import tvm.contrib.graph_executor as runtime +from tvm.relay.frontend.tensorflow import from_tensorflow + +import tvm.testing +from tvm.relay.testing.tf import vmobj_to_list as vmobj_to_list + +import tensorflow as tf +from tensorflow.python.eager.def_function import Function + + +def run_tf_code(func, input_): + if type(func) is Function: + out = func(input_) + if isinstance(out, list): + a = [x.numpy() for x in out] + else: + a = [out.numpy()] + else: + a = func(tf.constant(input_)) + if type(a) is dict: + a = [x.numpy() for x in a.values()] + elif type(a) is list: + a = [x.numpy() for x in a] + else: + a = a.numpy() + return a + + +def compile_graph_executor(mod, params, target="llvm", target_host="llvm", opt_level=3): + with tvm.transform.PassContext(opt_level): + lib = relay.build(mod, target=target, target_host=target_host, params=params) + return lib + + +def compile_vm(mod, params, target="llvm", target_host="llvm", opt_level=3, disabled_pass=None): + with tvm.transform.PassContext(opt_level, disabled_pass=disabled_pass): + vm_exec = relay.vm.compile(mod, target, target_host, params=params) + return vm_exec + + +def run_vm(vm_exec, input_, ctx=tvm.cpu(0)): + vm = VirtualMachine(vm_exec, ctx) + _out = vm.invoke("main", input_) + return vmobj_to_list(_out) + + +def run_graph_executor(lib, input_, ctx=tvm.cpu(0)): + mod = runtime.GraphModule(lib["default"](ctx)) + mod.set_input(0, input_) + mod.run() + return [mod.get_output(0).asnumpy()] + + +def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): + """compare tf and tvm execution for the same input. + + Parameters + ---------- + gdef: TF2 graph def extracted to be fed into from_tensorflow parser. + (https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) + + input_: a single numpy array object + + output_: the expected output from TF to match TVM output with + + runtime: choose TVM runtime; either "vm" for VirtualMachine or "graph" for GraphExecutor + + output_tensors : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + """ + mod, params = from_tensorflow(gdef, outputs=output_tensors) + if runtime == "vm": + exec_ = compile_vm(mod, params) + tvm_out = run_vm(exec_, input_) + elif runtime == "graph": + lib = compile_graph_executor(mod, params) + tvm_out = run_graph_executor(lib, input_) + else: + raise RuntimeError("Runtime input not supported: %s" % runtime) + + tvm.testing.assert_allclose(output_, tvm_out, atol=1e-5) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py new file mode 100644 index 000000000000..40d42a28025a --- /dev/null +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -0,0 +1,361 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=import-outside-toplevel, redefined-builtin +"""TF2 to relay converter test: tests basic examples""" + +import tempfile +import tensorflow as tf +import numpy as np +import pytest +from common import compare_tf_tvm +from common import run_tf_code + + +def _function_graph(TestClass): + f = TestClass().func + gdef = f.get_concrete_function().graph.as_graph_def() + gdef_ops = list(set([n.op for n in gdef.node])) + input_ = TestClass().get_input() + output = run_tf_code(f, input_) + return gdef, input_, output + + +def _model_graph(TestClass): + model = TestClass() + with tempfile.TemporaryDirectory() as model_path: + tf.saved_model.save(model, model_path) + imported = tf.saved_model.load(model_path) + + f = imported.signatures["serving_default"] + gdef = f.graph.as_graph_def(add_shapes=True) + + input_ = model.get_input() + output = run_tf_code(f, input_) + return gdef, input_, output + + +def run_all(TestClass): + def run_func_graph(TestClass, runtime="vm"): + compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) + + def run_model_graph(TestClass): + compare_tf_tvm(*_model_graph(TestClass), runtime="vm") + + run_model_graph(TestClass) + for runtime_ in ["vm", "graph"]: + run_func_graph(TestClass, runtime=runtime_) + + +def test_add_one(): + class AddOne(tf.Module): + """ simple function to test x=x+1; scalar as input""" + + def get_input(self): + return np.array(1.0, dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)]) + def func(self, x): + return x + 1 + + run_all(AddOne) + + +def test_add_one_2d(): + class AddOne2D(tf.Module): + """2D array as input""" + + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + return x + 1 + + run_all(AddOne2D) + + +def test_add_one_2d_constant(): + class AddOne2DConstant(tf.Module): + """2D array as input with 2D constant as well; 2D constant stored in params after convert""" + + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + return x + np.ones((2, 2), dtype="float32") + + run_all(AddOne2DConstant) + + +def test_sub_one_2d_constant(): + class SubOne2DConstant(tf.Module): + """2D array as input with 2D constant as well; 2D constant stored in params after convert""" + + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + return x - np.ones((2, 2), dtype="float32") + + run_all(SubOne2DConstant) + + +def test_mul_one_2d_constant(): + class MulOne2DConstant(tf.Module): + """2D array as input with 2D constant as well; 2D constant stored in params after convert""" + + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + return x * np.ones((2, 2), dtype="float32") + + run_all(MulOne2DConstant) + + +def test_div_one_2d_constant(): + class DivOne2DConstant(tf.Module): + """2D array as input with 2D constant as well; 2D constant stored in params after convert""" + + def get_input(self): + return np.ones((2, 2), dtype="float32") + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def func(self, x): + return x / np.ones((2, 2), dtype="float32") + + run_all(DivOne2DConstant) + + +def test_strided_slice(): + class StridedSlice(tf.Module): + def get_input(self): + return np.ones((3, 2, 3), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(3, 2, 3), dtype=tf.float32)]) + def func(self, x): + return tf.strided_slice(x, [1, 0, 0], [2, 1, 3], [1, 1, 1]) + + run_all(StridedSlice) + + +def test_split(): + class Split(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + a, b, c = tf.split(x, 3, axis=1) + return tf.raw_ops.Pack(values=[a, b, c], axis=1) + + run_all(Split) + + +def test_shape(): + class Shape(tf.Module): + def get_input(self): + return np.ones((3, 2, 3), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(3, 2, 3), dtype=tf.float32)]) + def func(self, x): + a = tf.ones_like(tf.raw_ops.Shape(input=x), dtype=tf.float32) + return a + x + + run_all(Shape) + + +def test_pack(): + class Pack(tf.Module): + def get_input(self): + return np.ones((2, 3), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + return tf.raw_ops.Pack(values=[x, x], axis=0) + + run_all(Pack) + + +def test_max(): + class Maximum(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + a, b = tf.split(x, 2, axis=1) + return tf.math.maximum(a, b, name=None) + + run_all(Maximum) + + +def test_less(): + class Less(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + a, b = tf.split(x, 2, axis=1) + return tf.math.less(a, b, name=None) + + run_all(Less) + + +def test_equal(): + class Equal(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + a, b = tf.split(x, 2, axis=1) + return tf.math.equal(a, b, name=None) + + run_all(Equal) + + +def test_cast(): + class Cast(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.cast(x, tf.int32) + + run_all(Cast) + + +def test_expand_dims(): + class ExpandDims(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.expand_dims(x, axis=2) + + run_all(ExpandDims) + + +def test_transpose(): + class Transpose(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + x = tf.expand_dims(x, axis=2) + return tf.transpose(x, perm=[0, 2, 1]) + + run_all(Transpose) + + +def test_reshape(): + class Reshape(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.reshape(x, (1, 2, 15)) + + run_all(Reshape) + + +def test_tanh(): + class Tanh(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.math.tanh(x) + + run_all(Tanh) + + +def test_sigmoid(): + class Sigmoid(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.math.sigmoid(x) + + run_all(Sigmoid) + + +def test_relu(): + class Relu(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.nn.relu(x) + + run_all(Relu) + + +def test_floor(): + class Floor(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.math.floor(x) + + run_all(Floor) + + +def test_floor_mod(): + class FloorMod(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + a, b = tf.split(x, 2, axis=1) + return tf.math.floormod(a, b) + + run_all(FloorMod) + + +def test_concat_v2(): + class ConcatV2(tf.Module): + def get_input(self): + return np.ones((1, 30), dtype=np.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + a, b, c = tf.split(x, 3, axis=1) + return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1) + + run_all(ConcatV2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py new file mode 100644 index 000000000000..394a49d0f2e9 --- /dev/null +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=import-outside-toplevel, redefined-builtin +"""TF2 to relay converter test: testing models built with tf.keras.Sequential()""" + +import tempfile +import numpy as np +import pytest +import tensorflow as tf +from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 + +from common import compare_tf_tvm +from common import run_tf_code + + +def run_sequential_model(model_fn, input_shape): + def get_input(shape): + _input = np.random.uniform(0, 1, shape).astype(dtype="float32") + return _input + + def save_and_reload(_model): + with tempfile.TemporaryDirectory() as model_path: + tf.saved_model.save(_model, model_path) + loaded = tf.saved_model.load(model_path) + func = loaded.signatures["serving_default"] + frozen_func = convert_variables_to_constants_v2(func) + return frozen_func + + def model_graph(model, input_shape): + _input = get_input(input_shape) + f = save_and_reload(model(input_shape)) + _output = run_tf_code(f, _input) + gdef = f.graph.as_graph_def(add_shapes=True) + return gdef, _input, _output + + compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm") + + +def test_dense_model(): + def dense_model(input_shape, num_units=128): + return tf.keras.Sequential( + [tf.keras.layers.Flatten(input_shape=input_shape[1:]), tf.keras.layers.Dense(num_units)] + ) + + run_sequential_model(dense_model, input_shape=(1, 28, 28)) + + +def test_mnist_model(): + def mnist_model(input_shape): + return tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=input_shape[1:]), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + + run_sequential_model(mnist_model, input_shape=(1, 28, 28)) + + +def test_conv2d_model(): + def conv2d_model(input_shape, kernel=(3, 3), filters=16): + model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=input_shape[1:], batch_size=1), + tf.keras.layers.Conv2D(filters, kernel), + ] + ) + return model + + run_sequential_model(conv2d_model, input_shape=(1, 32, 32, 3)) + + +def test_maxpool_model(): + def maxpool_model(input_shape, pool_size=(2, 2)): + model = tf.keras.Sequential( + [tf.keras.layers.MaxPool2D(pool_size=pool_size, input_shape=input_shape[1:])] + ) + return model + + run_sequential_model(maxpool_model, input_shape=(1, 32, 32, 3)) + + +def test_maxpool_batchnorm_model(): + def maxpool_batchnorm_model(input_shape, pool_size=(2, 2)): + model = tf.keras.Sequential( + [ + tf.keras.layers.MaxPool2D(pool_size=pool_size, input_shape=input_shape[1:]), + tf.keras.layers.BatchNormalization(), + ] + ) + return model + + run_sequential_model(maxpool_batchnorm_model, input_shape=(1, 32, 32, 3)) + + +if __name__ == "__main__": + pytest.main([__file__])