Commit d8f7ddc: update test
Archermmt committed Mar 23, 2024 (1 parent: 25201ef)
Showing 4 changed files with 77 additions and 66 deletions.
@@ -15,14 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.

-""" Test Managers in MSC. """
+""" Test Pipeline in MSC. """

 import json
 import pytest
 import torch

 import tvm.testing
-from tvm.contrib.msc.pipeline import MSCManager
+from tvm.contrib.msc.pipeline import MSCManager, TorchDynamic
 from tvm.contrib.msc.core.utils.namespace import MSCFramework
 from tvm.contrib.msc.core import utils as msc_utils

@@ -32,13 +32,13 @@
 )


-def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1):
+def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):
     """Get msc config"""

-    path = "test_manager_{}_{}".format(model_type, compile_type)
+    path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static")
     return {
         "workspace": msc_utils.msc_dir(path),
-        "verbose": "critical",
+        "verbose": "info",
         "model_type": model_type,
         "inputs": inputs,
         "outputs": outputs,
@@ -95,23 +95,29 @@ def _get_tf_graph():
     return None


-def _check_manager(manager, expected_info):
-    """Check the manager results"""
+def _check_pipeline(pipeline, expected_info, dynamic=False):
+    """Check the pipeline results"""

-    model_info = manager.runner.model_info
     passed, err = True, ""
-    if not manager.report["success"]:
+    if not pipeline.report["success"]:
         passed = False
-        err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type)
-    if not msc_utils.dict_equal(model_info, expected_info):
-        passed = False
-        err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
-    manager.destory()
+        err = "Failed to run pipe for {} -> {}".format(pipeline.model_type, pipeline.compile_type)
+    if not dynamic:
+        model_info = pipeline.get_runtime().model_info
+        if not msc_utils.dict_equal(model_info, expected_info):
+            passed = False
+            err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
+    pipeline.destory()
     if not passed:
-        raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2)))
+        raise Exception("{}\nReport:{}".format(err, json.dumps(pipeline.report, indent=2)))


-def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rtol=1e-1):
+def _test_from_torch(
+    compile_type, expected_info, training=False, dynamic=False, atol=1e-1, rtol=1e-1
+):
+    if dynamic and not hasattr(torch, "compile"):
+        return
+
     torch_model = _get_torch_model("resnet50", training)
     if torch_model:
         if torch.cuda.is_available():
@@ -121,12 +127,13 @@ def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rto
             compile_type,
             inputs=[["input_0", [1, 3, 224, 224], "float32"]],
             outputs=["output"],
+            dynamic=dynamic,
             atol=atol,
             rtol=rtol,
         )
-        manager = MSCManager(torch_model, config)
-        manager.run_pipe()
-        _check_manager(manager, expected_info)
+        pipeline = TorchDynamic(torch_model, config) if dynamic else MSCManager(torch_model, config)
+        pipeline.run_pipe()
+        _check_pipeline(pipeline, expected_info, dynamic)


 def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2):
@@ -143,11 +150,12 @@ def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2):
         config["compile"]["profile"]["check"]["err_rate"] = -1
         manager = MSCManager(graphdef, config)
         manager.run_pipe()
-        _check_manager(manager, expected_info)
+        _check_pipeline(manager, expected_info)


-def test_tvm_manager():
-    """Test manager for tvm"""
+@pytest.mark.parametrize("dynamic", [False, True])
+def test_tvm_pipeline(dynamic):
+    """Test pipeline for tvm"""

     model_info = {
         "inputs": [
@@ -168,40 +176,42 @@ def test_tvm_manager():
             "msc.linear_bias": 1,
         },
     }
-    _test_from_torch(MSCFramework.TVM, model_info, training=False)
-
-    model_info = {
-        "inputs": [
-            {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"}
-        ],
-        "outputs": [
-            {
-                "name": "MobilenetV2/Predictions/Reshape_1:0",
-                "shape": [1, 1001],
-                "dtype": "float32",
-                "layout": "NC",
-            }
-        ],
-        "nodes": {
-            "total": 138,
-            "input": 1,
-            "msc.conv2d_bias": 36,
-            "clip": 35,
-            "nn.conv2d": 17,
-            "nn.batch_norm": 17,
-            "get_item": 17,
-            "add": 10,
-            "nn.avg_pool2d": 1,
-            "squeeze": 1,
-            "reshape": 2,
-            "nn.softmax": 1,
-        },
-    }
-    _test_from_tf(MSCFramework.TVM, model_info)
-
-
-def test_torch_manager():
-    """Test manager for torch"""
+    _test_from_torch(MSCFramework.TVM, model_info, training=False, dynamic=dynamic)
+
+    if not dynamic:
+        model_info = {
+            "inputs": [
+                {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"}
+            ],
+            "outputs": [
+                {
+                    "name": "MobilenetV2/Predictions/Reshape_1:0",
+                    "shape": [1, 1001],
+                    "dtype": "float32",
+                    "layout": "NC",
+                }
+            ],
+            "nodes": {
+                "total": 138,
+                "input": 1,
+                "msc.conv2d_bias": 36,
+                "clip": 35,
+                "nn.conv2d": 17,
+                "nn.batch_norm": 17,
+                "get_item": 17,
+                "add": 10,
+                "nn.avg_pool2d": 1,
+                "squeeze": 1,
+                "reshape": 2,
+                "nn.softmax": 1,
+            },
+        }
+        _test_from_tf(MSCFramework.TVM, model_info)
+
+
+@pytest.mark.parametrize("dynamic", [False, True])
+def test_torch_pipeline(dynamic):
+    """Test pipeline for torch"""

     model_info = {
         "inputs": [
@@ -222,10 +232,10 @@ def test_torch_manager():
             "msc.linear_bias": 1,
         },
     }
-    _test_from_torch(MSCFramework.TORCH, model_info, training=False)
+    _test_from_torch(MSCFramework.TORCH, model_info, training=False, dynamic=dynamic)


-def test_tensorflow_manager():
+def test_tensorflow_pipeline():
     """Test manager for tensorflow"""

     model_info = {
@@ -259,8 +269,9 @@


 @requires_tensorrt
-def test_tensorrt_manager():
-    """Test manager for tensorrt"""
+@pytest.mark.parametrize("dynamic", [False, True])
+def test_tensorrt_pipeline(dynamic):
+    """Test pipeline for tensorrt"""

     model_info = {
         "inputs": [
@@ -269,7 +280,7 @@
         "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}],
         "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
     }
-    _test_from_torch(MSCFramework.TENSORRT, model_info, training=False)
+    _test_from_torch(MSCFramework.TENSORRT, model_info, training=False, dynamic=dynamic)


 if __name__ == "__main__":
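
Note: the reworked helper in this file reduces to the pattern sketched below. This is a condensed illustration for readers of the diff, not code from the commit; the wrapper function is hypothetical, but every API it touches (MSCManager, TorchDynamic, run_pipe, report, get_runtime, destory) appears in the changes above.

import torch

from tvm.contrib.msc.pipeline import MSCManager, TorchDynamic


def run_and_check(torch_model, config, dynamic=False):
    """Run the MSC pipe in static or dynamic mode, return model info."""
    # TorchDynamic builds on torch.compile, so dynamic runs are skipped
    # on torch builds that predate it (mirroring _test_from_torch).
    if dynamic and not hasattr(torch, "compile"):
        return None
    pipeline = TorchDynamic(torch_model, config) if dynamic else MSCManager(torch_model, config)
    pipeline.run_pipe()
    assert pipeline.report["success"], "Pipe failed: {}".format(pipeline.report)
    # Model info is only inspected for static pipelines, as in _check_pipeline.
    model_info = None if dynamic else pipeline.get_runtime().model_info
    pipeline.destory()  # sic: the MSC API spells this method "destory"
    return model_info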
tests/python/contrib/test_msc/test_plugin.py: 2 changes (1 addition, 1 deletion)
@@ -313,7 +313,7 @@ def _test_with_manager(plugins, compile_type, expected_info):
     }
     manager = MSCManager(model, config, plugins=plugins)
     report = manager.run_pipe()
-    model_info = manager.runner.model_info
+    model_info = manager.get_runtime().model_info
     manager.destory()
     assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type)
     assert msc_utils.dict_equal(
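
Note: the only change in this file is the accessor rename that recurs throughout the commit. A minimal before/after sketch, assuming manager is an MSCManager whose run_pipe() has already completed:

# Old style, removed in this commit: reach into the runner attribute.
# model_info = manager.runner.model_info

# New style: fetch the runtime through the accessor, then read its info.
model_info = manager.get_runtime().model_info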
tests/python/contrib/test_msc/test_runner.py: 4 changes (2 additions, 2 deletions)
@@ -100,7 +100,7 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1):
         golden = [msc_utils.cast_array(golden)]
     workspace.destory()
     for gol_r, out_r in zip(golden, outputs):
-        tvm.testing.assert_allclose(gol_r, out_r, atol=atol, rtol=rtol)
+        tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol)


 def test_tvm_runner_cpu():
@@ -162,7 +162,7 @@ def test_tensorflow_runner():
     outputs = runner.run([data], ret_type="list")
     workspace.destory()
     for gol_r, out_r in zip(golden, outputs):
-        tvm.testing.assert_allclose(gol_r, out_r, atol=1e-3, rtol=1e-3)
+        tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=1e-3, rtol=1e-3)


 if __name__ == "__main__":
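
Note: both fixes in this file wrap runner outputs in msc_utils.cast_array before comparing, so framework-native tensors are normalized to plain arrays that tvm.testing.assert_allclose can consume. A small sketch of the shared pattern; the helper function is illustrative, not from the commit:

import tvm.testing
from tvm.contrib.msc.core import utils as msc_utils


def check_outputs(golden, outputs, atol=1e-3, rtol=1e-3):
    """Compare runner outputs against golden arrays element-wise."""
    for gol_r, out_r in zip(golden, outputs):
        # cast_array converts a framework tensor (e.g. a torch.Tensor) to numpy.
        tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol)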
tests/python/contrib/test_msc/test_tools.py: 4 changes (2 additions, 2 deletions)
@@ -144,7 +144,7 @@ def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC):
            }
        ],
    }
-    tools.append({"tool_type": ToolType.TRACKER, "tool_config": config, "apply_once": True})
+    tools.append({"tool_type": ToolType.TRACKER, "tool_config": config})
     if use_distill:
         config = {
             "plan_file": "msc_distiller.json",
@@ -180,7 +180,7 @@ def _get_torch_model(name, training=False):
 def _check_manager(manager, expected_info):
     """Check the manager results"""

-    model_info = manager.runner.model_info
+    model_info = manager.get_runtime().model_info
     passed, err = True, ""
     if not manager.report["success"]:
         passed = False
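
Note: the tracker entry above drops its "apply_once": True flag, which previously marked the tool to be applied a single time; without it the tracker follows the default behavior on every run. Sketch of the updated entry, with config as built just above in get_tools (the ToolType import path is an assumption based on the MSC test suite):

from tvm.contrib.msc.core.tools import ToolType  # assumed import, as in test_tools.py

tools = []
# No "apply_once" key: the tracker is no longer restricted to one application.
tools.append({"tool_type": ToolType.TRACKER, "tool_config": config})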
