From 8e5e9e9c44ab9b0740685e71f88dbca42fa90b74 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Mon, 24 May 2021 03:13:33 +0000 Subject: [PATCH] Please linter --- python/tvm/driver/tvmc/compiler.py | 6 +++++- tests/python/driver/tvmc/conftest.py | 8 ++++++-- tests/python/driver/tvmc/test_mlf.py | 14 +++++++------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 560c3bf48a59..5e4dde91b6fa 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -271,7 +271,11 @@ def compile_model( # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( - graph_module, package_path, cross, cross_options, output_format, + graph_module, + package_path, + cross, + cross_options, + output_format, ) # Write dumps to file. diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index ad81912c4eb3..9c0d8fa8911e 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -41,7 +41,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir): return os.path.join(temp_dir, model_sub_path) -def get_sample_compiled_module(target_dir, package_filename, output_format='so'): +def get_sample_compiled_module(target_dir, package_filename, output_format="so"): """Support function that returns a TFLite compiled module""" base_url = "https://storage.googleapis.com/download.tensorflow.org/models" model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" @@ -53,7 +53,9 @@ def get_sample_compiled_module(target_dir, package_filename, output_format='so') tvmc_model = tvmc.frontends.load_model(model_file) return tvmc.compiler.compile_model( - tvmc_model, target="llvm", package_path=os.path.join(target_dir, package_filename), + tvmc_model, + target="llvm", + package_path=os.path.join(target_dir, package_filename), output_format=output_format, ) @@ -182,6 +184,7 @@ def tflite_compiled_model(tmpdir_factory): target_dir = tmpdir_factory.mktemp("data") return get_sample_compiled_module(target_dir, "mock.tar") + @pytest.fixture(scope="session") def tflite_compiled_model_mlf(tmpdir_factory): @@ -199,6 +202,7 @@ def tflite_compiled_model_mlf(tmpdir_factory): target_dir = tmpdir_factory.mktemp("data") return get_sample_compiled_module(target_dir, "mock.tar", "mlf") + @pytest.fixture(scope="session") def imagenet_cat(tmpdir_factory): tmpdir_name = tmpdir_factory.mktemp("data") diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 329cf743b85a..fb7162723b51 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -40,21 +40,20 @@ def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): assert os.path.exists(output_file), "Could not find the exported MLF archive." # Run the MLF archive. It must fail since it's only supported on micro targets. - tvmc_cmd = ( - f"tvmc run {output_file}" - ) + tvmc_cmd = f"tvmc run {output_file}" tvmc_args = tvmc_cmd.split(" ")[1:] exit_code = _main(tvmc_args) on_error = "Trying to run a MLF archive must fail because it's only supported on micro targets." assert exit_code != 0, on_error + def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): pytest.importorskip("tflite") tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) mod, params = tvmc_model.mod, tvmc_model.params - graph_module = tvm.relay.build(mod, target='llvm', params=params) + graph_module = tvm.relay.build(mod, target="llvm", params=params) output_dir = tmpdir_factory.mktemp("mlf") output_file = os.path.join(output_dir, "mock.tar") @@ -64,7 +63,7 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): executor_factory=graph_module, package_path=output_file, cross=None, - output_format='mlf', + output_format="mlf", ) assert os.path.exists(output_file), "Could not find the exported MLF archive." @@ -75,13 +74,14 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): tvmc_model.export_package( executor_factory=graph_module, package_path=output_file, - cross='cc', - output_format='mlf', + cross="cc", + output_format="mlf", ) expected_reason = "Specifying the MLF output and a cross compiler is not supported." on_error = "A TVMCException was caught but its reason is not the expected one." assert str(exp.value) == expected_reason, on_error + def test_tvmc_import_package_mlf(tflite_compiled_model_mlf): # Compile and export a model to a MLF archive so it can be imported. exported_tvmc_package = tflite_compiled_model_mlf