Skip to content

Commit

Permalink
[Frontend] [Tensorflow2] Added test infrastructure for TF2 frozen mod…
Browse files Browse the repository at this point in the history
…els (#8074)

* added test infrastructure for frozen TF2 models

* linting with black

* removing some comments

* change in comment in sequential test

* addressed the comments

* refactored to place vmobj_to_list in a common file

* Added helper function in python/tvm/relay/testing/tf.py

Co-authored-by: David Huang <davhuan@amazon.com>
Co-authored-by: Rohan Mukherjee <mukrohan@amazon.com>
Co-authored-by: Xiao <weix@amazon.com>

* Refactor tf according to CI error

Co-authored-by: David Huang <davhuan@amazon.com>
Co-authored-by: Rohan Mukherjee <mukrohan@amazon.com>
Co-authored-by: Xiao <weix@amazon.com>

* Added docstring

Co-authored-by: David Huang <davhuan@amazon.com>
Co-authored-by: Rohan Mukherjee <mukrohan@amazon.com>
Co-authored-by: Xiao <weix@amazon.com>

* removing print

Co-authored-by: David Huang <davhuan@amazon.com>
Co-authored-by: Xiao <weix@amazon.com>
  • Loading branch information
3 people authored May 25, 2021
1 parent b64466a commit 65cd19f
Show file tree
Hide file tree
Showing 4 changed files with 621 additions and 0 deletions.
42 changes: 42 additions & 0 deletions python/tvm/relay/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

import tvm
from tvm.contrib.download import download_testdata

try:
Expand Down Expand Up @@ -73,6 +75,46 @@ def convert_to_list(x):
return x


def vmobj_to_list(o):
"""Converts TVM objects returned by VM execution to Python List.
Parameters
----------
o : Obj
VM Object as output from VM runtime executor.
Returns
-------
result : list
Numpy objects as list with equivalent values to the input object.
"""

if isinstance(o, tvm.nd.NDArray):
result = [o.asnumpy()]
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == "Cons":
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
result = hd
elif o.constructor.name_hint == "Nil":
result = []
elif "tensor_nil" in o.constructor.name_hint:
result = [0]
elif "tensor" in o.constructor.name_hint:
result = [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
else:
raise RuntimeError("Unknown object type: %s" % type(o))
return result


def AddShapesToGraphDef(session, out_node):
"""Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Expand Down
105 changes: 105 additions & 0 deletions tests/python/frontend/tensorflow2/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=import-outside-toplevel, redefined-builtin
"""TF2 to relay converter test utilities"""

import tvm
from tvm import relay

from tvm.runtime.vm import VirtualMachine
import tvm.contrib.graph_executor as runtime
from tvm.relay.frontend.tensorflow import from_tensorflow

import tvm.testing
from tvm.relay.testing.tf import vmobj_to_list as vmobj_to_list

import tensorflow as tf
from tensorflow.python.eager.def_function import Function


def run_tf_code(func, input_):
if type(func) is Function:
out = func(input_)
if isinstance(out, list):
a = [x.numpy() for x in out]
else:
a = [out.numpy()]
else:
a = func(tf.constant(input_))
if type(a) is dict:
a = [x.numpy() for x in a.values()]
elif type(a) is list:
a = [x.numpy() for x in a]
else:
a = a.numpy()
return a


def compile_graph_executor(mod, params, target="llvm", target_host="llvm", opt_level=3):
with tvm.transform.PassContext(opt_level):
lib = relay.build(mod, target=target, target_host=target_host, params=params)
return lib


def compile_vm(mod, params, target="llvm", target_host="llvm", opt_level=3, disabled_pass=None):
with tvm.transform.PassContext(opt_level, disabled_pass=disabled_pass):
vm_exec = relay.vm.compile(mod, target, target_host, params=params)
return vm_exec


def run_vm(vm_exec, input_, ctx=tvm.cpu(0)):
vm = VirtualMachine(vm_exec, ctx)
_out = vm.invoke("main", input_)
return vmobj_to_list(_out)


def run_graph_executor(lib, input_, ctx=tvm.cpu(0)):
mod = runtime.GraphModule(lib["default"](ctx))
mod.set_input(0, input_)
mod.run()
return [mod.get_output(0).asnumpy()]


def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None):
"""compare tf and tvm execution for the same input.
Parameters
----------
gdef: TF2 graph def extracted to be fed into from_tensorflow parser.
(https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
input_: a single numpy array object
output_: the expected output from TF to match TVM output with
runtime: choose TVM runtime; either "vm" for VirtualMachine or "graph" for GraphExecutor
output_tensors : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.
"""
mod, params = from_tensorflow(gdef, outputs=output_tensors)
if runtime == "vm":
exec_ = compile_vm(mod, params)
tvm_out = run_vm(exec_, input_)
elif runtime == "graph":
lib = compile_graph_executor(mod, params)
tvm_out = run_graph_executor(lib, input_)
else:
raise RuntimeError("Runtime input not supported: %s" % runtime)

tvm.testing.assert_allclose(output_, tvm_out, atol=1e-5)
Loading

0 comments on commit 65cd19f

Please sign in to comment.