diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index 2c1845e868de..a78ad294b770 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -63,7 +63,7 @@ def partition_for_arm_compute_lib(mod, params=None): [ transform.InferType(), transform.MergeComposite(arm_compute_lib_pattern_table()), - transform.AnnotateTarget("arm_compute_lib"), + transform.AnnotateTarget("arm_compute_lib", False), transform.PartitionGraph(), ] ) diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index 4efae487f220..898446b32ed9 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -152,7 +152,32 @@ def get_model(): ) +def test_squeezenet(): + Device.load("test_config.json") + + if skip_runtime_test(): + return + + import tvm.relay.testing.tf as tf_testing + + device = Device() + + def get_model(): + model_path = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz", + "squeezenet.tflite", + ) + inputs = {"Placeholder": ((1, 224, 224, 3), "float32")} + mod, params = _get_tflite_model(model_path, inputs_dict=inputs) + return mod, params, inputs + + _build_and_run_network( + *get_model(), device=device, tvm_ops=10, acl_partitions=30, atol=8, rtol=0 + ) + + if __name__ == "__main__": test_vgg16() test_mobilenet() test_quantized_mobilenet() + test_squeezenet()