[CUDA][Schedule] Better Layout Transform Schedules #14167

Merged
Changes from all commits (40 commits, all by AndrewZhaoLuo):
1df07f7  initial basis (Mar 1, 2023)
8748bc8  Generated all the tile sizes (Mar 1, 2023)
54d3625  is this all you need? (Mar 2, 2023)
6ccea68  linting (Mar 2, 2023)
ced49f4  forgot to forward arg (Mar 3, 2023)
c6e8739  fix tests (Mar 3, 2023)
f18f933  reduce search space (Mar 3, 2023)
5a51ffa  lint (Mar 3, 2023)
e1ce901  schedule rule documentation (Mar 8, 2023)
5f0b1b0  add a note (Mar 8, 2023)
826a877  fix wording (Mar 8, 2023)
d11e66e  handle implicit reshape case v1 (Mar 13, 2023)
0be1be0  clean up comments (Mar 13, 2023)
f2f5165  address comments (Mar 13, 2023)
1fb271b  testing harness (Mar 14, 2023)
31ca25a  more progress on testing harness (Mar 14, 2023)
585191b  fix case where shape changes in mod (Mar 14, 2023)
9822f4c  inline after schedule genreation to help analysis (Mar 14, 2023)
e152e39  proper autoinlining INTO layout transform block to maintain extants (Mar 14, 2023)
8274d21  clean up (Mar 14, 2023)
1c7aa19  reindex for introducing cache block (Mar 15, 2023)
e704664  reorganize testing (Mar 15, 2023)
67c2db9  more cleanup (Mar 15, 2023)
6e1ea6b  remove forced false (Mar 15, 2023)
8eace30  use the proper dispatcher (Mar 15, 2023)
8eb6f8b  update test, make default schedule rule None (Mar 16, 2023)
5546079  linting (Mar 16, 2023)
a3729c9  fix mypy errors (Mar 16, 2023)
bd57077  clean up (Mar 16, 2023)
7eb21ad  manual test cases (Mar 16, 2023)
3d092fe  manual tests (Mar 16, 2023)
7d4df3f  add comment, fix improper implicit reshape handling (Mar 16, 2023)
37a0e9d  fix (Mar 16, 2023)
0687e57  remove extra comments (Mar 16, 2023)
8945399  more lints (Mar 16, 2023)
ceb7548  refactor (Mar 16, 2023)
7fdcb69  remove extraneous check (Mar 20, 2023)
a5d8f5f  lint again :/ (Mar 21, 2023)
bf60774  remove uneeded newline (Mar 21, 2023)
2c1047a  remove leading spaces (Mar 23, 2023)
10 changes: 9 additions & 1 deletion include/tvm/topi/transform.h
@@ -1592,10 +1592,12 @@ inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& in
* \param dst_layout the destination layout.
* \param name output tensor name.
* \param tag output tensor tag.
* \param schedule_rule name of specialized schedule rule to use.
* \return A tensor with shape in \p dst_layout
*/
inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
const std::string& dst_layout,
const std::string schedule_rule = "None",
const std::string name = "T_layout_trans",
const std::string tag = kInjective) {
Layout src_layout_struct(src_layout);
@@ -1614,6 +1616,12 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,

Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

Map<String, ObjectRef> attrs = {{"schedule_rule", String(schedule_rule)},
// Information about layouts needed for the schedule rule
{"src_layout", String(src_layout)},
{"dst_layout", String(dst_layout)},
{"input_shape", src->shape}};

return compute(
dst_shape,
[&](const Array<Var>& dst_indices) {
@@ -1625,7 +1633,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
}
return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
},
name, tag);
name, tag, attrs);
}

/*! \brief Utility function for auto_scheduler_layout_transform */
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/schedule/cuda/__init__.py
@@ -15,3 +15,5 @@
# specific language governing permissions and limitations
# under the License.
"""Per-block schedule rules in MetaSchedule for target key 'cuda'"""

from . import layout_transform
583 changes: 583 additions & 0 deletions python/tvm/meta_schedule/schedule/cuda/layout_transform.py

Large diffs are not rendered by default.
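
The schedule rule itself lives in this new file, but the diff is collapsed above. As a rough sketch of the registration mechanism such a per-block rule typically plugs into (the decorator key and function body below are assumptions for illustration, not excerpts from the file), MetaSchedule can dispatch on a block's "schedule_rule" annotation by looking up a global packed function named after the target kind and the rule:

import tvm
from tvm import tir

# Hypothetical sketch, not copied from layout_transform.py: a per-block schedule
# rule is registered as a global packed function keyed by target kind and the
# block's "schedule_rule" annotation, and it returns candidate schedules.
@tvm.register_func("meta_schedule.cuda.layout_transform", override=True)
def _cuda_layout_transform_rule(sch: tir.Schedule, block) -> list:
    # A real rule would return tiling variants for the annotated block; returning
    # the unmodified schedule keeps this placeholder well-formed.
    return [sch]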

2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -94,7 +94,7 @@ def compute_strided_set(attrs, inputs, output_type):
_reg.register_injective_schedule("strided_set")

# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_strategy("layout_transform", strategy.layout_transform_strategy)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("auto_scheduler_layout_transform")
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1396,3 +1396,14 @@ def dft_strategy_cuda(attrs, inputs, out_type, target):
name="dft.cuda",
)
return strategy


@layout_transform_strategy.register(["cuda", "gpu"])
def layout_transform_strategy_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_layout_transform(topi.layout_transform, schedule_rule="layout_transform"),
schedule_injective,
name="layout_transform.cuda",
)
return strategy
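
For a quick end-to-end illustration of this strategy being selected (not part of the PR; assumes a CUDA-enabled TVM build), lowering a Relay layout_transform for the "cuda" target exercises the layout_transform.cuda implementation registered above:

import tvm
from tvm import relay

# Illustrative usage only: lower a Relay layout_transform for CUDA so the
# strategy above is picked and the schedule_rule attr is attached downstream.
x = relay.var("x", shape=(1, 8, 16, 16), dtype="float32")
y = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW4c")
mod = tvm.IRModule.from_expr(relay.Function([x], y))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")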
36 changes: 30 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
@@ -21,12 +21,7 @@

from tvm import _ffi, ir, te, topi
from tvm.target import generic_func, override_native_generic_func
from tvm.topi.utils import (
get_const_float,
get_const_int,
get_const_tuple,
get_float_tuple,
)
from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple

from .. import op as _op

@@ -2060,3 +2055,32 @@ def conv2d_backward_weight_strategy(attrs, inputs, out_type, target):
"conv2d_backward_weight is currently only supported with cudnn. "
"Please run Legalize pass to decompose this op into supported ops."
)


@override_native_generic_func("layout_transform_strategy")
def layout_transform_strategy(attrs, inputs, out_type, target):
"""layout transform generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_layout_transform(topi.layout_transform),
# Defined earlier in the file
schedule_injective,
name="layout_transform.generic",
)
return strategy


def wrap_compute_layout_transform(topi_compute, schedule_rule="None"):
"""Wrap layout transform compute"""

def _compute_layout_transform(attrs, inputs, output_type):
return [
topi_compute(
inputs[0],
attrs.src_layout,
attrs.dst_layout,
schedule_rule,
)
]

return _compute_layout_transform
17 changes: 10 additions & 7 deletions python/tvm/topi/transform.py
@@ -17,13 +17,13 @@
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Injective transformation operators"""
from __future__ import absolute_import as _abs

import tvm
from tvm import te
from tvm import topi
from tvm import te, topi
from tvm.te import hybrid
from . import cpp
from . import tag
from .utils import within_index, make_idx, const_vector

from . import cpp, tag
from .utils import const_vector, make_idx, within_index


def expand_dims(a, axis, num_newaxis=1):
@@ -636,7 +636,7 @@ def tile(a, reps):
return cpp.tile(a, reps)


def layout_transform(array, src_layout, dst_layout):
def layout_transform(array, src_layout, dst_layout, schedule_rule="None"):
"""Transform the layout according to src_layout and dst_layout

Parameters
@@ -649,8 +649,11 @@ def layout_transform(array, src_layout, dst_layout, schedule_rule="None"):

dst_layout : str
the destination layout.

schedule_rule : str
the schedule rule to apply if any
"""
return cpp.layout_transform(array, src_layout, dst_layout)
return cpp.layout_transform(array, src_layout, dst_layout, schedule_rule)


def shape(array, dtype="int32"):
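
A minimal sketch of the new parameter in use (illustrative only; the shapes and layouts are arbitrary): the schedule_rule string is forwarded to the C++ compute, where it surfaces alongside src_layout, dst_layout, and input_shape as annotations on the generated T_layout_trans block.

import tvm
from tvm import te, tir, topi

# Illustrative only: build the compute with a schedule_rule and inspect the
# block annotations attached by topi.layout_transform.
x = te.placeholder((1, 4, 8, 8), name="x")
y = topi.layout_transform(x, "NCHW", "NCHW4c", schedule_rule="layout_transform")

sch = tir.Schedule(te.create_prim_func([x, y]))
block = sch.get_block("T_layout_trans")
# Expected keys: schedule_rule, src_layout, dst_layout, input_shape.
print(sch.get(block).annotations)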
2 changes: 1 addition & 1 deletion src/topi/transform.cc
@@ -87,7 +87,7 @@ TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = layout_transform(args[0], args[1], args[2]);
*rv = layout_transform(args[0], args[1], args[2], args[3]);
});

TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
3 changes: 3 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -20,6 +20,7 @@

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import IRModule
@@ -420,6 +421,7 @@ def main( # type: ignore
ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3])
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
T.block_attr({"dst_layout": "NCHW3c", "input_shape": [1, 3, 16, 16], "schedule_rule": "None", "src_layout": "NCHW"})
T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
ax0 < T.int64(1) and ax1 * T.int64(3) + ax4 < T.int64(3) and ax2 < T.int64(16) and ax3 < T.int64(16), # type: ignore
placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3],
@@ -440,6 +442,7 @@ def main(placeholder: T.Buffer((T.int64(1), T.int64(2), T.int64(16), T.int64(16)
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)]) # type: ignore
T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
T.block_attr({"dst_layout": "NCHW", "input_shape": [1, 2, 16, 16, 4], "schedule_rule": "None", "src_layout": "NCHW4c"})
T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < T.int64(1) and ax1 < T.int64(8) and ax2 < T.int64(16) and ax3 < T.int64(16), placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)], T.float32(0), dtype="float32") # type: ignore

@tvm.script.ir_module
Expand Down