From 68f1ae76dd3271a095b79c49fcc634f121de9216 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 31 Aug 2022 09:41:29 -0500 Subject: [PATCH] [Hexagon] Validate 2-d physical shapes for TIR-derived schedules Previously, the test cases only tested TE-based schedules. This commit runs the same tests for equivalent TIR-based schedules as well. This is intended to catch Hexagon-specific regressions, such as the one resolved in https://github.com/apache/tvm/pull/12652. --- .../test_hexagon/test_2d_physical_buffers.py | 60 ++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) mode change 100644 => 100755 tests/python/contrib/test_hexagon/test_2d_physical_buffers.py diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py old mode 100644 new mode 100755 index cebb36edc35d1..9c58d084b8af6 --- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py +++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py @@ -24,6 +24,7 @@ import numpy as np import pytest import tvm +from tvm.script import tir as T # Needed to register the link_shared packedfunc. import tvm.contrib.hexagon @@ -41,6 +42,8 @@ # there as well # pylint: disable=invalid-name +schedule_type = tvm.testing.parameter("TE", "TIR") + dtype = tvm.testing.parameter("int8") batch_size = tvm.testing.parameter( 16, @@ -198,6 +201,7 @@ def output_shape(self, input_shape): @tvm.testing.fixture def schedule_args( self, + schedule_type, input_shape, dtype, input_layout, @@ -206,12 +210,39 @@ def schedule_args( working_scope, ): """Create and return the schedule and input args after applying layout transform""" + if schedule_type == "TE": + + return self._te_schedule_args( + input_shape, dtype, input_layout, output_layout, working_layout, working_scope + ) + elif schedule_type == "TIR": + return self._tir_schedule_args( + input_shape, dtype, input_layout, output_layout, working_layout, working_scope + ) + + else: + raise ValueError(f"Unknown schedule type: {schedule_type}") + + def _te_tensors(self, input_shape, dtype): input_tensor = te.placeholder(input_shape, dtype, name="Input") output_tensor = te.compute( shape=input_tensor.shape, fcompute=lambda *indices: (2 * input_tensor[indices]).astype(dtype), name="Output", ) + return input_tensor, output_tensor + + def _te_schedule_args( + self, + input_shape, + dtype, + input_layout, + output_layout, + working_layout, + working_scope, + ): + input_tensor, output_tensor = self._te_tensors(input_shape, dtype) + schedule = te.create_schedule(output_tensor.op) write_cache = schedule.cache_write(output_tensor, working_scope) @@ -235,6 +266,33 @@ def apply_transform(tensor, layout): return [schedule, [input_tensor, output_tensor]] + def _tir_schedule_args( + self, input_shape, dtype, input_layout, output_layout, working_layout, working_scope + ): + tensors = self._te_tensors(input_shape, dtype) + + sch = tvm.tir.Schedule(te.create_prim_func(tensors)) + + cache_read_block = sch.cache_read("Output", 0, working_scope) + cache_write_block = sch.cache_write("Output", 0, working_scope) + + def apply_transform(block, buffer_name, layout): + if layout == "nhwc": + pass + elif layout == "nchw-8h8w32c-1d": + sch.transform_layout(block, buffer_name, layout_transform_1d) + elif layout == "nchw-8h8w32c-2d": + sch.transform_layout(block, buffer_name, layout_transform_2d) + else: + raise RuntimeError(f"Unexpected layout '{layout}'") + + apply_transform(cache_read_block, ("read", 0), input_layout) + apply_transform(cache_read_block, ("write", 0), working_layout) + apply_transform(cache_write_block, ("read", 0), working_layout) + apply_transform(cache_write_block, ("write", 0), output_layout) + + return [sch.mod] + @tvm.testing.fixture def ir_module(self, schedule_args): # If the two buffers are accessed with the same indices, CSE @@ -272,7 +330,7 @@ def test_cache_shape(self, ir_module, input_layout, working_layout, output_layou "Input.global.vtcm": working_layout, "Output.global.vtcm": working_layout, "Output": output_layout, - }[buffer.name] + }[buffer.name.replace("_", ".")] expected_physical_dimensions = { "nhwc": 1,