Skip to content

Commit

Permalink
lint code
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippvK committed Mar 11, 2023
1 parent 8d21064 commit 7f8af41
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tests/python/driver/tvmc/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def test_layout_transform_convert_kernel_layout_pass_args(relay_conv2d, monkeypa
monkeypatch.setattr(relay.transform, "ConvertLayout", mock_convert_layout)

with tvm.transform.PassContext(opt_level=3):
apply_graph_transforms(relay_conv2d, {"desired_layout": [desired_layout], "desired_layout_ops": desired_layout_ops})
apply_graph_transforms(
relay_conv2d,
{"desired_layout": [desired_layout], "desired_layout_ops": desired_layout_ops}
)

mock_convert_layout.assert_called_once_with(
{
Expand All @@ -109,7 +112,10 @@ def test_layout_transform_convert_layout_pass_args_multiple(relay_conv2d, monkey
monkeypatch.setattr(relay.transform, "ConvertLayout", mock_convert_layout)

with tvm.transform.PassContext(opt_level=3):
apply_graph_transforms(relay_conv2d, {"desired_layout": desired_layout, "desired_layout_ops": desired_layout_ops})
apply_graph_transforms(
relay_conv2d,
{"desired_layout": desired_layout, "desired_layout_ops": desired_layout_ops}
)

mock_convert_layout.assert_called_once_with(
{
Expand All @@ -119,10 +125,13 @@ def test_layout_transform_convert_layout_pass_args_multiple(relay_conv2d, monkey
)


@pytest.mark.parametrize("desired", [
(["NHWC", "NCHW"], ["nn.max_pool2d"]),
(["NHWC", "NCHW"], None),
])
@pytest.mark.parametrize(
"desired",
[
(["NHWC", "NCHW"], ["nn.max_pool2d"]),
(["NHWC", "NCHW"], None),
]
)
def test_layout_transform_convert_layout_pass_args_multiple_invalid(relay_conv2d, monkeypatch, desired):
"""
TODO
Expand All @@ -135,7 +144,10 @@ def test_layout_transform_convert_layout_pass_args_multiple_invalid(relay_conv2d

with pytest.raises(TVMCException):
with tvm.transform.PassContext(opt_level=3):
apply_graph_transforms(relay_conv2d, {"desired_layout": desired_layout, "desired_layout_ops": desired_layout_ops})
apply_graph_transforms(
relay_conv2d,
{"desired_layout": desired_layout, "desired_layout_ops": desired_layout_ops}
)


def test_layout_transform_to_mixed_precision_pass_args_mock(relay_conv2d, monkeypatch):
Expand Down

0 comments on commit 7f8af41

Please sign in to comment.