Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] [Tensorflow2] Added test infrastructure for TF2 frozen models #8074

Merged
merged 10 commits into from
May 25, 2021
42 changes: 42 additions & 0 deletions python/tvm/relay/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

import tvm
from tvm.contrib.download import download_testdata

try:
Expand Down Expand Up @@ -73,6 +75,46 @@ def convert_to_list(x):
return x


def vmobj_to_list(o):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this copied from somewhere? We can add this to testing to remove the other places

Copy link
Contributor Author

@rohanmukh rohanmukh May 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copied from tests/python/frontend/tensorflow/test_forward.py
Do you mean I should remove it from there and import its invocations from here (i.e. relay.testing.tf ) in this PR? Or do you mean that should be done in another PR? Similar code exists in other places as well like test_vm.py, test_tensorrt.py etc.

"""Converts TVM objects returned by VM execution to Python List.

Parameters
----------
o : Obj
VM Object as output from VM runtime executor.

Returns
-------
result : list
Numpy objects as list with equivalent values to the input object.

"""

if isinstance(o, tvm.nd.NDArray):
result = [o.asnumpy()]
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == "Cons":
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
result = hd
elif o.constructor.name_hint == "Nil":
result = []
elif "tensor_nil" in o.constructor.name_hint:
result = [0]
elif "tensor" in o.constructor.name_hint:
result = [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
else:
raise RuntimeError("Unknown object type: %s" % type(o))
return result


def AddShapesToGraphDef(session, out_node):
"""Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Expand Down
106 changes: 106 additions & 0 deletions tests/python/frontend/tensorflow2/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=import-outside-toplevel, redefined-builtin
"""TF2 to relay converter test utilities"""

import tvm
from tvm import relay

from tvm.runtime.vm import VirtualMachine
import tvm.contrib.graph_executor as runtime
from tvm.relay.frontend.tensorflow import from_tensorflow

import tvm.testing
from tvm.relay.testing.tf import vmobj_to_list as vmobj_to_list

import tensorflow as tf
from tensorflow.python.eager.def_function import Function


def run_tf_code(func, input_):
print(type(func))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove or change to log

if type(func) is Function:
out = func(input_)
if isinstance(out, list):
a = [x.numpy() for x in out]
else:
a = [out.numpy()]
else:
a = func(tf.constant(input_))
if type(a) is dict:
a = [x.numpy() for x in a.values()]
elif type(a) is list:
a = [x.numpy() for x in a]
else:
a = a.numpy()
return a


def compile_graph_executor(mod, params, target="llvm", target_host="llvm", opt_level=3):
with tvm.transform.PassContext(opt_level):
lib = relay.build(mod, target=target, target_host=target_host, params=params)
return lib


def compile_vm(mod, params, target="llvm", target_host="llvm", opt_level=3, disabled_pass=None):
with tvm.transform.PassContext(opt_level, disabled_pass=disabled_pass):
vm_exec = relay.vm.compile(mod, target, target_host, params=params)
return vm_exec


def run_vm(vm_exec, input_, ctx=tvm.cpu(0)):
vm = VirtualMachine(vm_exec, ctx)
_out = vm.invoke("main", input_)
return vmobj_to_list(_out)


def run_graph_executor(lib, input_, ctx=tvm.cpu(0)):
mod = runtime.GraphModule(lib["default"](ctx))
mod.set_input(0, input_)
mod.run()
return [mod.get_output(0).asnumpy()]


def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None):
"""compare tf and tvm execution for the same input.

Parameters
----------
gdef: TF2 graph def extracted to be fed into from_tensorflow parser.
(https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)

input_: a single numpy array object

output_: the expected output from TF to match TVM output with

runtime: choose TVM runtime; either "vm" for VirtualMachine or "graph" for GraphExecutor

output_tensors : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
"""
mod, params = from_tensorflow(gdef, outputs=output_tensors)
if runtime == "vm":
exec_ = compile_vm(mod, params)
tvm_out = run_vm(exec_, input_)
elif runtime == "graph":
lib = compile_graph_executor(mod, params)
tvm_out = run_graph_executor(lib, input_)
else:
raise RuntimeError("Runtime input not supported: %s" % runtime)

tvm.testing.assert_allclose(output_, tvm_out, atol=1e-5)
Loading