Skip to content

Commit

Permalink
[TE] Optimized version of concatenation layer (#11341)
Browse files Browse the repository at this point in the history
* [TE] Optimized version of concatenation layer
     1. Concat implemented using extern_op
     2. New tests added.
     3. Workaround to allow inline extern_op-s with other layers.

* *test fix

* test_any.py fix.

* test_forward.py from tensorflow fix.

* lint fix.

* Fixes after code review.

* New comment added.

* Lint fix.

* Another lint fix.

* Comments added.

* rebase issue fix.

* Restored previous state.

* Update after code review.

* After code review changes.

* lint review.

* Change strategy for cuda to fix tests.

* Rebase to main

* Comments changes after review.

* Some more comments fixes.

* One more error fix in comments.

* restart build
  • Loading branch information
shtinsa authored Jun 1, 2022
1 parent a1d95ec commit e84f163
Show file tree
Hide file tree
Showing 11 changed files with 359 additions and 30 deletions.
7 changes: 6 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@


# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
@_reg.register_compute("concatenate")
def compute_concat(attrs, inputs, output_type):
    """Relay compute for ``concatenate``.

    Forwards ``inputs`` and ``attrs.axis`` to ``topi.concatenate`` and returns
    the result wrapped in a one-element list. ``output_type`` is not used.
    """
    joined = topi.concatenate(inputs, attrs.axis)
    return [joined]


_reg.register_strategy("concatenate", strategy.concatenate_strategy)

# sliding_window
@_reg.register_compute("sliding_window")
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ def schedule_reduce_cuda(attrs, outs, target):
return topi.cuda.schedule_reduce(outs)


@schedule_concatenate.register(["cuda", "gpu"])
def schedule_concatenate_cuda(attrs, outs, target):
    """Schedule concatenate for cuda.

    Calls ``topi.cuda.schedule_injective`` on ``outs`` inside the
    ``with target`` block and returns its result.
    """
    with target:
        injective_schedule = topi.cuda.schedule_injective(outs)
        return injective_schedule
@concatenate_strategy.register(["cuda", "gpu"])
def concatenate_strategy_cuda(attrs, inputs, out_type, target):
    """concatenate cuda strategy

    Registers a single implementation, ``concatenate.cuda``, that pairs
    ``topi.transform.concatenate`` with ``topi.cuda.schedule_injective``.
    """
    cuda_strategy = _op.OpStrategy()
    compute_fn = wrap_compute_concat(topi.transform.concatenate)
    schedule_fn = wrap_topi_schedule(topi.cuda.schedule_injective)
    cuda_strategy.add_implementation(compute_fn, schedule_fn, name="concatenate.cuda")
    return cuda_strategy


@schedule_pool.register(["cuda", "gpu"])
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,15 @@ def _compute_scanop(attrs, inputs, _):
return _compute_scanop


def wrap_compute_concat(topi_compute):
    """Wrap a concatenate topi compute as a Relay compute function.

    The returned callable takes ``(attrs, inputs, out_type)``, ignores the
    third argument, calls ``topi_compute(inputs, attrs.axis)`` and returns
    the result wrapped in a one-element list.
    """

    def _compute_concat(attrs, inputs, _):
        joined = topi_compute(inputs, attrs.axis)
        return [joined]

    return _compute_concat


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
"""cumsum generic strategy"""
Expand All @@ -1793,6 +1802,18 @@ def cumsum_strategy(attrs, inputs, out_type, target):
return strategy


@override_native_generic_func("concat_strategy")
def concatenate_strategy(attrs, inputs, out_type, target):
    """concatenate generic strategy

    Registers a single implementation, ``concatenate``, that pairs
    ``topi.concatenate`` with ``topi.generic.schedule_injective``.
    """
    generic_strategy = _op.OpStrategy()
    compute_fn = wrap_compute_concat(topi.concatenate)
    schedule_fn = wrap_topi_schedule(topi.generic.schedule_injective)
    generic_strategy.add_implementation(compute_fn, schedule_fn, name="concatenate")
    return generic_strategy


@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
Expand Down
40 changes: 32 additions & 8 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

import re
from tvm import topi
from tvm import topi, tir
from tvm.topi.x86.utils import target_has_vnni
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
Expand Down Expand Up @@ -48,13 +48,6 @@ def schedule_reduce_cpu(attrs, outs, target):
return topi.x86.schedule_reduce(outs)


@schedule_concatenate.register("cpu")
def schedule_concatenate_cpu(attrs, outs, target):
    """Schedule concatenate op for x86.

    Calls ``topi.x86.schedule_concatenate`` on ``outs`` inside the
    ``with target`` block and returns its result.
    """
    with target:
        concat_schedule = topi.x86.schedule_concatenate(outs)
        return concat_schedule


@schedule_pool.register("cpu")
def schedule_pool_cpu(attrs, outs, target):
"""schedule pooling ops for x86"""
Expand Down Expand Up @@ -741,3 +734,34 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ
"Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
)
return strategy


@concatenate_strategy.register(["cpu"])
def concatenate_strategy_cpu(attrs, inputs, out_type, target):
    """concatenate x86 strategy

    When every dimension of every input shape is a ``tir.expr.IntImm``, the
    x86-specific ``concatenate.cpu`` implementation is registered first,
    followed by ``concatenate.generic``. Otherwise only
    ``concatenate.generic`` is registered.
    """
    strategy = _op.OpStrategy()
    has_non_const_dim = any(
        not isinstance(dim, tir.expr.IntImm) for tensor in inputs for dim in tensor.shape
    )
    if not has_non_const_dim:
        strategy.add_implementation(
            wrap_compute_concat(topi.x86.concatenate),
            wrap_topi_schedule(topi.x86.schedule_concatenate_cpu),
            name="concatenate.cpu",
        )
    # The generic implementation is registered in both cases.
    strategy.add_implementation(
        wrap_compute_concat(topi.transform.concatenate),
        wrap_topi_schedule(topi.x86.injective.schedule_concatenate),
        name="concatenate.generic",
    )
    return strategy
1 change: 1 addition & 0 deletions python/tvm/topi/x86/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@
from .scatter import *
from .group_conv2d import *
from .math_alter_op import *
from .concat import *
109 changes: 109 additions & 0 deletions python/tvm/topi/x86/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"concatenate related operators"
from typing import Optional
import tvm
from tvm import te
import numpy as np
from ..utils import get_const_int, const_vector


def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
    """Join a sequence of arrays along an existing axis. Optimized for CPU execution.

    The result is produced by a ``te.extern`` op whose body is built with the
    TIR IR builder and copies each input into the output element by element.

    NOTE(review): ``data`` is annotated as a single tensor, but the body
    indexes and iterates it as a sequence of tensors -- confirm the annotation.

    Parameters
    ----------
    data : tuple of tvm.te.Tensor
        The arrays to concatenate. Shape entries are converted with ``int()``,
        so they must be convertible to Python integers.

    axis : int, optional
        The axis along which the arrays will be joined. Default is 0.
        A negative value is offset by the rank of ``data[0]``.

    Returns
    -------
    ret : tvm.te.Tensor
    """

    def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf):
        """Custom concatenation execution.

        Copies each input as one contiguous run: input ``i`` contributes
        ``outers[i]`` elements written at output offset ``cumsum[i]``.
        """
        i_b = tvm.tir.ir_builder.create()
        # data_bufs holds every extern input, including the two helper vectors
        # appended after the data tensors; only the first len(data) are read below.
        data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs]
        out_buf = i_b.buffer_ptr(out_buf)
        outers = i_b.buffer_ptr(in_outers_tensor)
        cumsum = i_b.buffer_ptr(in_cumsum_tensor)
        # Python-level loop: one copy loop is emitted per input tensor.
        for i in range(len(data)):
            with i_b.for_range(0, outers[i], name="j") as j:
                out_buf[cumsum[i] + j] = data_bufs1[i][j]
        return i_b.get()

    def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer):
        """Common case of concatenation execution.

        ``inner`` is the product of the output dimensions before ``axis`` and
        ``outer`` the product of the output dimensions from ``axis`` onwards.
        """
        i_b = tvm.tir.ir_builder.create()
        # data_bufs holds every extern input, including the two helper vectors
        # appended after the data tensors; only the first len(data) are read below.
        data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs]
        out_buf = i_b.buffer_ptr(out_buf)
        outers = i_b.buffer_ptr(in_outers_tensor)
        cumsum = i_b.buffer_ptr(in_cumsum_tensor)
        if inner > 1:
            # Parallelize over the leading index; each iteration fills one
            # output slice of `outer` elements.
            with i_b.for_range(0, inner, name="inn", kind="parallel") as inn:
                # Start of this slice in the output.
                pos = inn * outer
                for i in range(len(data)):
                    # Start of the matching slice in input i.
                    offset = inn * outers[i]
                    with i_b.for_range(0, outers[i], name="j") as j:
                        out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j]
        else:
            # No leading index to parallelize over: parallelize each input's copy loop.
            for i in range(len(data)):
                with i_b.for_range(0, outers[i], name="j", kind="parallel") as j:
                    out_buf[cumsum[i] + j] = data_bufs1[i][j]
        return i_b.get()

    # Normalize a negative axis using the rank of the first input.
    if axis < 0:
        axis += len(data[0].shape)
    concat_axis_sizes = [int(t.shape[axis]) for t in data]
    join_size = int(np.sum(concat_axis_sizes))
    # Per input: number of elements in one slice spanning dimensions axis..end.
    in_outers = [int(np.prod(i.shape[axis:])) for i in data]
    # Running sum of in_outers shifted by one (starts at 0, drops the final total);
    # used as each input's starting offset within an output slice.
    in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]]
    dtype = data[0].dtype
    # Output shape: first input's shape with the axis extent replaced by the joined size.
    out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :]
    # The per-input sizes and offsets are passed to the extern op as extra input tensors.
    # NOTE(review): presumably const_vector materializes a Python list as a
    # constant tensor -- confirm against ..utils.
    in_outers_tensor = const_vector(in_outers)
    in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum")
    right_val = np.prod(out_shape[axis:])
    left_val = np.prod(out_shape[:axis])

    if (
        len(data[0].shape) == 1
        or right_val == 1
        or (left_val == 1 and axis == len(data[0].shape) - 1)
        or (left_val == 1 and right_val == 1)
    ):
        # badly parallelized case: use the serial copy generated by gen_ir_1d
        return te.extern(
            [out_shape],
            list(data) + [in_outers_tensor, in_cumsum_tensor],
            # ins[-2] / ins[-1] are the two helper vectors appended above.
            lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]),
            dtype=dtype,
            name="concatenate_ext",
        )

    inner = get_const_int(int(left_val))
    outer = get_const_int(int(right_val))
    return te.extern(
        [out_shape],
        list(data) + [in_outers_tensor, in_cumsum_tensor],
        # ins[-2] / ins[-1] are the two helper vectors appended above.
        lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer),
        dtype=dtype,
        name="concatenate_ext",
    )
42 changes: 36 additions & 6 deletions python/tvm/topi/x86/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@
# pylint: disable=invalid-name
"""x86 declaration and schedules."""
from tvm import te
from tvm.topi import tag
from tvm.tir import IntImm
from tvm.topi.generic.injective import (
schedule_injective_from_existing as schedule_injective_for_concat,
)
from ..utils import is_empty_shape


def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
Expand Down Expand Up @@ -61,13 +63,11 @@ def schedule_injective_from_existing(sch, out):

def schedule_injective(outs):
"""X86 schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of injective in the format
of an array of tensors.
Returns
-------
sch: Schedule
Expand All @@ -85,13 +85,11 @@ def schedule_injective(outs):

def schedule_concatenate(outs):
"""X86 schedule for concatenate op.
Parameters
----------
outs: Array of Tensor
The computation graph description of injective in the format
of an array of tensors.
Returns
-------
sch: Schedule
Expand Down Expand Up @@ -132,5 +130,37 @@ def vectorize(sch, tensor, vectorize_limit):
return s


def schedule_concatenate_cpu(outs):
    """X86 schedule for concatenate op.

    Creates a schedule for ``outs`` and walks each output op together with its
    producers. Every visited op whose tag is injective is scheduled with
    ``schedule_injective_for_concat``.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """

    if isinstance(outs, te.tensor.Tensor):
        outs = [outs]
    s = te.create_schedule([tensor.op for tensor in outs])
    visited_ops = []

    def _visit(op):
        if tag.is_injective(op.tag):
            schedule_injective_for_concat(s, op.output(0))

        for producer in op.input_tensors:
            producer_op = producer.op
            # Skip producers without inputs and producers already handled.
            if not producer_op.input_tensors or producer_op in visited_ops:
                continue
            _visit(producer_op)
        visited_ops.append(op)

    for out_tensor in outs:
        _visit(out_tensor.op)

    return s


schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
1 change: 0 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ RELAY_REGISTER_OP("concatenate")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
.set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

TVM_REGISTER_NODE_TYPE(StackAttrs);
Expand Down
30 changes: 29 additions & 1 deletion src/te/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,29 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
std::vector<bool> changed(sch->stages.size(), false);
std::vector<Stmt> new_hybrid_body(sch->stages.size());
std::vector<bool> hybrid_changed(sch->stages.size(), false);
// (sshtin): this workaround allows extern ops to be inlined into their consumer.
// None of the inputs of an extern op should be inlined, because inlining may happen
// before TE generation for that particular extern op. That may lead to
// a crash during the lowering or building stages.
// The problem description:
// When operations are fused, inlining the arguments
// prevents creation of a ProducerNode for the extern operation.
// Instead of creating it, the operation argument is supposed to be used as an inlined
// buffer, but extern_op TIR generation can be performed after the inlining procedure, so
// the newly generated TIR does not have a reference to the input data at all.
std::unordered_map<Operation, Operation> ext_ops;
for (size_t i = 0; i < sch->stages.size(); i++) {
Stage stage = sch->stages[i];
auto ext_op = stage->op.as<ExternOpNode>();
if (ext_op) {
auto inps = ext_op->InputTensors();
for (size_t ii = 0; ii < inps.size(); ++ii) {
if (ext_ops.find(inps[ii]->op) == ext_ops.end()) {
ext_ops[inps[ii]->op] = stage->op;
}
}
}
}
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
Expand All @@ -525,8 +548,13 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
if (ext_ops.find(stage->op) != ext_ops.end()) {
// sshtin: The extern op may try to access its input tensors as raw data,
// which can lead to an error in the IR builder.
stage->attach_type = kGroupRoot;
continue;
}
ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";

if (feature_extraction_mode && compute->attrs.count("const_matrix")) {
// Use constant value to replace access of const matrices.
// This produces wrong IR but is good enough for feature extraction purposes.
Expand Down
Loading

0 comments on commit e84f163

Please sign in to comment.