Skip to content

Commit

Permalink
use unique path for test
Browse files Browse the repository at this point in the history
  • Loading branch information
Archermmt committed Jan 2, 2024
1 parent b7e63f3 commit 1fb2793
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/tvm/contrib/msc/core/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class MSCDirectory(object):

def __init__(self, path: str = None, keep_history: bool = True, cleanup: bool = False):
if not path:
path = "msc_" + str(datetime.datetime.now().strftime("%Y%m%d_%H-%M-%S"))
path = "msc_" + str(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
self._path = os.path.abspath(path)
self._cleanup = cleanup
self._cwd = os.getcwd()
Expand Down
9 changes: 8 additions & 1 deletion tests/python/contrib/test_msc/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
""" Test Managers in MSC. """

import json
import datetime
import pytest
import torch

Expand All @@ -34,8 +35,12 @@

def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1):
"""Get msc config"""

path = "test_manager_{}_{}_{}".format(
model_type, compile_type, datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
return {
"workspace": msc_utils.msc_dir(),
"workspace": msc_utils.msc_dir(path),
"verbose": "critical",
"model_type": model_type,
"inputs": inputs,
Expand All @@ -55,6 +60,7 @@ def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1)

def _get_torch_model(name, is_training=False):
"""Get model from torch vision"""

# pylint: disable=import-outside-toplevel
try:
import torchvision
Expand All @@ -72,6 +78,7 @@ def _get_torch_model(name, is_training=False):

def _get_tf_graph():
"""Get graph from tensorflow"""

# pylint: disable=import-outside-toplevel
try:
from tvm.contrib.msc.framework.tensorflow import tf_v1
Expand Down
10 changes: 8 additions & 2 deletions tests/python/contrib/test_msc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

""" Test Runners in MSC. """

import datetime
import pytest
import numpy as np

Expand All @@ -41,6 +42,7 @@

def _get_torch_model(name, is_training=False):
"""Get model from torch vision"""

# pylint: disable=import-outside-toplevel
try:
import torchvision
Expand Down Expand Up @@ -82,7 +84,10 @@ def _test_from_torch(runner_cls, device, is_training=False, atol=1e-1, rtol=1e-1

torch_model = _get_torch_model("resnet50", is_training)
if torch_model:
workspace = msc_utils.set_workspace(msc_utils.msc_dir())
path = "test_runner_torch_{}_{}_{}".format(
runner_cls.__name__, device, datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
workspace = msc_utils.set_workspace(msc_utils.msc_dir(path))
log_path = workspace.relpath("MSC_LOG", keep_history=False)
msc_utils.set_global_logger("info", log_path)
input_info = [([1, 3, 224, 224], "float32")]
Expand Down Expand Up @@ -139,7 +144,8 @@ def test_tensorflow_runner():

tf_graph, graph_def = _get_tf_graph()
if tf_graph and graph_def:
workspace = msc_utils.set_workspace(msc_utils.msc_dir())
path = "test_runner_tf_" + str(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
workspace = msc_utils.set_workspace(msc_utils.msc_dir(path))
log_path = workspace.relpath("MSC_LOG", keep_history=False)
msc_utils.set_global_logger("info", log_path)
data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
Expand Down
13 changes: 12 additions & 1 deletion tests/python/contrib/test_msc/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
""" Test Tools in MSC. """

import json
import datetime
import pytest
import torch

Expand All @@ -44,8 +45,15 @@ def _get_config(
optimize_type=None,
):
"""Get msc config"""

path = "test_tool_{}_{}".format(model_type, compile_type)
for t_type, config in tools_config.items():
path = path + "_" + str(t_type)
if "gym_configs" in config:
path = path + "_gym"
path = path + "_" + str(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
return {
"workspace": msc_utils.msc_dir(),
"workspace": msc_utils.msc_dir(path),
"verbose": "critical",
"model_type": model_type,
"inputs": inputs,
Expand All @@ -70,6 +78,7 @@ def _get_config(

def get_tool_config(tool_type, use_distill=False, use_gym=False):
"""Get config for the tool"""

config = {}
if tool_type == ToolType.PRUNER:
config = {
Expand Down Expand Up @@ -178,6 +187,7 @@ def get_tool_config(tool_type, use_distill=False, use_gym=False):

def _get_torch_model(name, is_training=False):
"""Get model from torch vision"""

# pylint: disable=import-outside-toplevel
try:
import torchvision
Expand Down Expand Up @@ -239,6 +249,7 @@ def _test_from_torch(

def get_model_info(compile_type):
"""Get the model info"""

if compile_type == MSCFramework.TVM:
return {
"inputs": [
Expand Down

0 comments on commit 1fb2793

Please sign in to comment.