Skip to content

Commit

Permalink
log report when fail
Browse files Browse the repository at this point in the history
  • Loading branch information
Archermmt committed Jan 1, 2024
1 parent caeb078 commit b7e63f3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
3 changes: 2 additions & 1 deletion tests/python/contrib/test_msc/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

""" Test Managers in MSC. """

import json
import pytest
import torch

Expand Down Expand Up @@ -103,7 +104,7 @@ def _check_manager(manager, expected_info):
err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
manager.destory()
if not passed:
raise Exception(err)
raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2)))


def _test_from_torch(compile_type, expected_info, is_training=False, atol=1e-1, rtol=1e-1):
Expand Down
16 changes: 5 additions & 11 deletions tests/python/contrib/test_msc/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

""" Test Tools in MSC. """

import os
import json
import pytest

import torch

import tvm.testing
Expand Down Expand Up @@ -194,25 +193,20 @@ def _get_torch_model(name, is_training=False):
return None


def _check_manager(manager, tools_config, expected_info):
def _check_manager(manager, expected_info):
"""Check the manager results"""

model_info = manager.runner.model_info
passed, err = True, ""
if not manager.report["success"]:
passed = False
err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type)
for t_type, config in tools_config.items():
if not os.path.isfile(msc_utils.get_config_dir().relpath(config["plan_file"])):
passed = False
err = "Failed to find plan of " + str(t_type)
break
if not msc_utils.dict_equal(model_info, expected_info):
passed = False
err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
manager.destory()
if not passed:
raise Exception(err)
raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2)))


def _test_from_torch(
Expand Down Expand Up @@ -240,7 +234,7 @@ def _test_from_torch(
)
manager = MSCManager(torch_model, config)
manager.run_pipe()
_check_manager(manager, tools_config, expected_info)
_check_manager(manager, expected_info)


def get_model_info(compile_type):
Expand Down Expand Up @@ -300,7 +294,7 @@ def test_tvm_distill(tool_type):
@tvm.testing.requires_gpu
@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER])
def test_tvm_gym(tool_type):
"""Test tools for tvm with distiller"""
"""Test tools for tvm with gym"""

tool_config = get_tool_config(tool_type, use_gym=True)
_test_from_torch(
Expand Down

0 comments on commit b7e63f3

Please sign in to comment.