[CUDA][Schedule] Better Layout Transform Schedules (#14167)
* initial basis

* Generated all the tile sizes

* is this all you need?

* linting

lint

move schedule rule to own file

lint p2

layout transform fixes

* forgot to forward arg

* fix tests

* reduce search space

* lint

* schedule rule documentation

* add a note

* fix wording

* handle implicit reshape case v1

* clean up comments

* address comments

* testing harness

* more progress on testing harness

* fix case where shape changes in mod

* inline after schedule generation to help analysis

* proper autoinlining INTO layout transform block to maintain extents

* clean up

* reindex for introducing cache block

* reorganize testing

* more cleanup

* remove forced false

* use the proper dispatcher

* update test, make default schedule rule None

* linting

* fix mypy errors

* clean up

* manual test cases

* manual tests

* add comment, fix improper implicit reshape handling

* fix

* remove extra comments

* more lints

* refactor

* remove extraneous check

* lint again :/

* remove unneeded newline

* remove leading spaces
AndrewZhaoLuo authored Mar 23, 2023
1 parent 3f56a95 commit e5ae434
Showing 10 changed files with 1,116 additions and 16 deletions.
10 changes: 9 additions & 1 deletion include/tvm/topi/transform.h
@@ -1592,10 +1592,12 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \param schedule_rule name of specialized schedule rule to use.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
const std::string& dst_layout,
const std::string schedule_rule = "None",
const std::string name = "T_layout_trans",
const std::string tag = kInjective) {
Layout src_layout_struct(src_layout);
@@ -1614,6 +1616,12 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,

Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

Map<String, ObjectRef> attrs = {{"schedule_rule", String(schedule_rule)},
// Information about layouts needed for the schedule rule
{"src_layout", String(src_layout)},
{"dst_layout", String(dst_layout)},
{"input_shape", src->shape}};

return compute(
dst_shape,
[&](const Array<Var>& dst_indices) {
@@ -1625,7 +1633,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
}
return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
},
name, tag);
name, tag, attrs);
}

/*! \brief Utility function for auto_scheduler_layout_transform */
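Note (illustrative, not part of the diff): a minimal sketch of how the new attrs surface once the compute is lowered to TIR, which is what the MetaSchedule rule keys off. Shapes and layouts below are arbitrary examples, and it assumes the updated Python wrapper shown further down.

import tvm
from tvm import te, topi

a = te.placeholder((1, 3, 16, 16), name="a")
b = topi.layout_transform(a, "NCHW", "NCHW3c", schedule_rule="layout_transform")

# Lowering should carry the attrs through as block annotations, e.g.
# T.block_attr({"schedule_rule": "layout_transform", "src_layout": "NCHW", ...})
func = te.create_prim_func([a, b])
print(func.script())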
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/schedule/cuda/__init__.py
@@ -15,3 +15,5 @@
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule for target key 'cuda'"""

from . import layout_transform
583 changes: 583 additions & 0 deletions python/tvm/meta_schedule/schedule/cuda/layout_transform.py

Large diffs are not rendered by default.
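The new rule itself is not rendered above. As a rough sketch only (the registration name and signature here are assumptions, not taken from this diff), a per-block rule of this kind is typically a packed function that receives a Schedule and the annotated block and returns candidate schedules:

from typing import List
import tvm
from tvm import tir

# Hypothetical registration name; the actual name used by layout_transform.py
# is not visible in this rendered diff.
@tvm.register_func("meta_schedule.cuda.layout_transform", override=True)
def _layout_transform_rule(sch: tir.Schedule, block: tir.schedule.BlockRV) -> List[tir.Schedule]:
    # Placeholder body: the real rule emits tiled/vectorized candidates.
    return [sch]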

2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -94,7 +94,7 @@ def compute_strided_set(attrs, inputs, output_type):
_reg.register_injective_schedule("strided_set")

# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_strategy("layout_transform", strategy.layout_transform_strategy)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1396,3 +1396,14 @@ def dft_strategy_cuda(attrs, inputs, out_type, target):
name="dft.cuda",
)
return strategy


@layout_transform_strategy.register(["cuda", "gpu"])
def layout_transform_strategy_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_layout_transform(topi.layout_transform, schedule_rule="layout_transform"),
schedule_injective,
name="layout_transform.cuda",
)
return strategy
36 changes: 30 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
@@ -21,12 +21,7 @@

from tvm import _ffi, ir, te, topi
from tvm.target import generic_func, override_native_generic_func
from tvm.topi.utils import (
get_const_float,
get_const_int,
get_const_tuple,
get_float_tuple,
)
from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple

from .. import op as _op

@@ -2060,3 +2055,32 @@ def conv2d_backward_weight_strategy(attrs, inputs, out_type, target):
"conv2d_backward_weight is currently only supported with cudnn. "
"Please run Legalize pass to decompose this op into supported ops."
)


@override_native_generic_func("layout_transform_strategy")
def layout_transform_strategy(attrs, inputs, out_type, target):
"""layout transform generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_layout_transform(topi.layout_transform),
# Defined earlier in the file
schedule_injective,
name="layout_transform.generic",
)
return strategy


def wrap_compute_layout_transform(topi_compute, schedule_rule="None"):
"""Wrap layout transform compute"""

def _compute_layout_transform(attrs, inputs, output_type):
return [
topi_compute(
inputs[0],
attrs.src_layout,
attrs.dst_layout,
schedule_rule,
)
]

return _compute_layout_transform
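
Usage sketch (hypothetical target key, not part of this commit): another backend can reuse the wrapper to tag its layout_transform blocks with its own rule name, mirroring the CUDA registration above.

from tvm import topi
from tvm.relay.op import op as _op
from tvm.relay.op.strategy.generic import (
    layout_transform_strategy,
    schedule_injective,
    wrap_compute_layout_transform,
)

@layout_transform_strategy.register(["rocm"])  # "rocm" chosen purely for illustration
def layout_transform_strategy_rocm(attrs, inputs, out_type, target):
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_layout_transform(topi.layout_transform, schedule_rule="layout_transform"),
        schedule_injective,
        name="layout_transform.rocm",
    )
    return strategy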
17 changes: 10 additions & 7 deletions python/tvm/topi/transform.py
@@ -17,13 +17,13 @@
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Injective transformation operators"""
from __future__ import absolute_import as _abs

import tvm
from tvm import te
from tvm import topi
from tvm import te, topi
from tvm.te import hybrid
from . import cpp
from . import tag
from .utils import within_index, make_idx, const_vector

from . import cpp, tag
from .utils import const_vector, make_idx, within_index


def expand_dims(a, axis, num_newaxis=1):
@@ -636,7 +636,7 @@ def tile(a, reps):
return cpp.tile(a, reps)


def layout_transform(array, src_layout, dst_layout):
def layout_transform(array, src_layout, dst_layout, schedule_rule="None"):
"""Transform the layout according to src_layout and dst_layout
Parameters
@@ -649,8 +649,11 @@ def layout_transform(array, src_layout, dst_layout):
dst_layout : str
the destination layout.
schedule_rule : str
the schedule rule to apply, if any
"""
return cpp.layout_transform(array, src_layout, dst_layout)
return cpp.layout_transform(array, src_layout, dst_layout, schedule_rule)


def shape(array, dtype="int32"):
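A minimal usage sketch of the updated Python API (values illustrative; the default schedule_rule stays "None", so existing callers are unaffected):

from tvm import te, topi

x = te.placeholder((1, 8, 16, 16), name="x")
y_default = topi.layout_transform(x, "NCHW", "NCHW4c")  # unchanged default behaviour
y_tuned = topi.layout_transform(x, "NCHW", "NCHW4c", schedule_rule="layout_transform")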
2 changes: 1 addition & 1 deletion src/topi/transform.cc
@@ -87,7 +87,7 @@ TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = layout_transform(args[0], args[1], args[2]);
*rv = layout_transform(args[0], args[1], args[2], args[3]);
});

TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
3 changes: 3 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -20,6 +20,7 @@

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import IRModule
@@ -420,6 +421,7 @@ def main( # type: ignore
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3])
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
T.block_attr({"dst_layout": "NCHW3c", "input_shape": [1, 3, 16, 16], "schedule_rule": "None", "src_layout": "NCHW"})
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
ax0 < T.int64(1) and ax1 * T.int64(3) + ax4 < T.int64(3) and ax2 < T.int64(16) and ax3 < T.int64(16), # type: ignore
placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3],
@@ -440,6 +442,7 @@ def main(placeholder: T.Buffer((T.int64(1), T.int64(2), T.int64(16), T.int64(16)
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)]) # type: ignore
T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
T.block_attr({"dst_layout": "NCHW", "input_shape": [1, 2, 16, 16, 4], "schedule_rule": "None", "src_layout": "NCHW4c"})
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < T.int64(1) and ax1 < T.int64(8) and ax2 < T.int64(16) and ax3 < T.int64(16), placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)], T.float32(0), dtype="float32") # type: ignore

@tvm.script.ir_module