Skip to content

Commit

Permalink
Please linter
Browse files Browse the repository at this point in the history
  • Loading branch information
gromero committed May 24, 2021
1 parent ef89320 commit 8e5e9e9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
6 changes: 5 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,11 @@ def compile_model(

# Create a new tvmc model package object from the graph definition.
package_path = tvmc_model.export_package(
graph_module, package_path, cross, cross_options, output_format,
graph_module,
package_path,
cross,
cross_options,
output_format,
)

# Write dumps to file.
Expand Down
8 changes: 6 additions & 2 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir):
return os.path.join(temp_dir, model_sub_path)


def get_sample_compiled_module(target_dir, package_filename, output_format='so'):
def get_sample_compiled_module(target_dir, package_filename, output_format="so"):
"""Support function that returns a TFLite compiled module"""
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
Expand All @@ -53,7 +53,9 @@ def get_sample_compiled_module(target_dir, package_filename, output_format='so')

tvmc_model = tvmc.frontends.load_model(model_file)
return tvmc.compiler.compile_model(
tvmc_model, target="llvm", package_path=os.path.join(target_dir, package_filename),
tvmc_model,
target="llvm",
package_path=os.path.join(target_dir, package_filename),
output_format=output_format,
)

Expand Down Expand Up @@ -182,6 +184,7 @@ def tflite_compiled_model(tmpdir_factory):
target_dir = tmpdir_factory.mktemp("data")
return get_sample_compiled_module(target_dir, "mock.tar")


@pytest.fixture(scope="session")
def tflite_compiled_model_mlf(tmpdir_factory):

Expand All @@ -199,6 +202,7 @@ def tflite_compiled_model_mlf(tmpdir_factory):
target_dir = tmpdir_factory.mktemp("data")
return get_sample_compiled_module(target_dir, "mock.tar", "mlf")


@pytest.fixture(scope="session")
def imagenet_cat(tmpdir_factory):
tmpdir_name = tmpdir_factory.mktemp("data")
Expand Down
14 changes: 7 additions & 7 deletions tests/python/driver/tvmc/test_mlf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,20 @@ def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
assert os.path.exists(output_file), "Could not find the exported MLF archive."

# Run the MLF archive. It must fail since it's only supported on micro targets.
tvmc_cmd = (
f"tvmc run {output_file}"
)
tvmc_cmd = f"tvmc run {output_file}"
tvmc_args = tvmc_cmd.split(" ")[1:]
exit_code = _main(tvmc_args)
on_error = "Trying to run a MLF archive must fail because it's only supported on micro targets."
assert exit_code != 0, on_error


def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
pytest.importorskip("tflite")

tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
mod, params = tvmc_model.mod, tvmc_model.params

graph_module = tvm.relay.build(mod, target='llvm', params=params)
graph_module = tvm.relay.build(mod, target="llvm", params=params)

output_dir = tmpdir_factory.mktemp("mlf")
output_file = os.path.join(output_dir, "mock.tar")
Expand All @@ -64,7 +63,7 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
executor_factory=graph_module,
package_path=output_file,
cross=None,
output_format='mlf',
output_format="mlf",
)
assert os.path.exists(output_file), "Could not find the exported MLF archive."

Expand All @@ -75,13 +74,14 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
tvmc_model.export_package(
executor_factory=graph_module,
package_path=output_file,
cross='cc',
output_format='mlf',
cross="cc",
output_format="mlf",
)
expected_reason = "Specifying the MLF output and a cross compiler is not supported."
on_error = "A TVMCException was caught but its reason is not the expected one."
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_mlf(tflite_compiled_model_mlf):
# Compile and export a model to a MLF archive so it can be imported.
exported_tvmc_package = tflite_compiled_model_mlf
Expand Down

0 comments on commit 8e5e9e9

Please sign in to comment.