From 533f6fe2d405e83fe3f80504313b89577956fe2a Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Thu, 17 Dec 2020 14:26:47 +0000 Subject: [PATCH] [BYOC] [ACL] include_non_call_ops = False ACL codegen now uses AnnotateTarget pass with include_non_call_ops = False to prevent promoting non-call ops under the target of its arguments. Squeezenet unit test added. --- .../tvm/relay/op/contrib/arm_compute_lib.py | 2 +- .../test_arm_compute_lib/test_network.py | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 2c1845e868de..a78ad294b770 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -63,7 +63,7 @@ def partition_for_arm_compute_lib(mod, params=None): [ transform.InferType(), transform.MergeComposite(arm_compute_lib_pattern_table()), - transform.AnnotateTarget("arm_compute_lib"), + transform.AnnotateTarget("arm_compute_lib", False), transform.PartitionGraph(), ] ) diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index 4efae487f220..898446b32ed9 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -152,7 +152,32 @@ def get_model(): ) +def test_squeezenet(): + Device.load("test_config.json") + + if skip_runtime_test(): + return + + import tvm.relay.testing.tf as tf_testing + + device = Device() + + def get_model(): + model_path = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz", + "squeezenet.tflite", + ) + inputs = {"Placeholder": ((1, 224, 224, 3), "float32")} + mod, params = _get_tflite_model(model_path, inputs_dict=inputs) + return mod, params, inputs + + _build_and_run_network( + *get_model(), device=device, tvm_ops=10, acl_partitions=30, atol=8, rtol=0 + ) + + if __name__ == "__main__": test_vgg16() test_mobilenet() test_quantized_mobilenet() + test_squeezenet()