From f2d18e41b726f4c3fd6849773209daa16949e9bf Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Tue, 17 May 2022 18:01:12 +0300 Subject: [PATCH 01/21] [TE] Optimized version of concatenation layer 1. Concat implemented using extern_op 2. New tests added. 3. Workaround to allow inline extern_op-s with other layers. --- python/tvm/relay/op/_transform.py | 7 +- python/tvm/relay/op/strategy/generic.py | 50 ++++++- python/tvm/relay/op/strategy/x86.py | 2 +- python/tvm/topi/x86/__init__.py | 4 + python/tvm/topi/x86/concat.py | 134 ++++++++++++++++++ python/tvm/topi/x86/injective.py | 45 ++++-- src/relay/op/tensor/transform.cc | 2 +- src/te/schedule/schedule_dataflow_rewrite.cc | 24 +++- tests/python/relay/test_op_level1.py | 92 ++++++++++++ .../test_micro_model_library_format.py | 27 ++-- 10 files changed, 365 insertions(+), 22 deletions(-) create mode 100644 python/tvm/topi/x86/concat.py diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0338035329fc..d87ee266f01d 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -68,7 +68,12 @@ # concatenate -_reg.register_schedule("concatenate", strategy.schedule_concatenate) +@_reg.register_compute("concatenate") +def compute_concat(attrs, inputs, output_type): + return [topi.concatenate(inputs, attrs.axis)] + + +_reg.register_strategy("concatenate", strategy.concatenate_strategy) # sliding_window @_reg.register_compute("sliding_window") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index fa62af5f9fed..cfbfd472d8b4 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -19,7 +19,7 @@ import logging import re -from tvm import _ffi, ir, te, topi +from tvm import _ffi, ir, te, topi, tir from tvm.target import generic_func, override_native_generic_func from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple @@ -1781,6 +1781,15 @@ def _compute_scanop(attrs, inputs, _): return _compute_scanop +def wrap_compute_concat(topi_compute): + """Wrap concatenate topi compute""" + + def _compute_concat(attrs, inputs, _): + return [topi_compute(inputs, attrs.axis)] + + return _compute_concat + + @override_native_generic_func("cumsum_strategy") def cumsum_strategy(attrs, inputs, out_type, target): """cumsum generic strategy""" @@ -1793,6 +1802,45 @@ def cumsum_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("concat_strategy") +def concatenate_strategy(attrs, inputs, out_type, target): + """concatenate generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.concatenate), + wrap_topi_schedule(topi.generic.schedule_extern), + name="concatenate", + ) + return strategy + + +@concatenate_strategy.register(["cpu"]) +def concatenate_strategy_cpu(attrs, inputs, out_type, target): + """concatenate x86 strategy""" + strategy = _op.OpStrategy() + use_old_concat = False + for inpt in inputs: + shape = inpt.shape + for i in shape: + if isinstance(i, tir.expr.SizeVar): + if i.name == "any_dim": + use_old_concat = True + break + if use_old_concat: + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + else: + strategy.add_implementation( + wrap_compute_concat(topi.x86.concatenate), + wrap_topi_schedule(topi.x86.schedule_concatenate_cpu), + name="concatenate.cpu", + ) + return strategy + + @override_native_generic_func("cumprod_strategy") def cumprod_strategy(attrs, inputs, out_type, target): """cumprod generic strategy""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 0beb99e4f7db..877a4ef30462 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -52,7 +52,7 @@ def schedule_reduce_cpu(attrs, outs, target): def schedule_concatenate_cpu(attrs, outs, target): """schedule concatenate op for x86""" with target: - return topi.x86.schedule_concatenate(outs) + return topi.transform.schedule_concatenate(outs) @schedule_pool.register("cpu") diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py index 34a5e0362d87..36ac432d2c94 100644 --- a/python/tvm/topi/x86/__init__.py +++ b/python/tvm/topi/x86/__init__.py @@ -42,4 +42,8 @@ from .dense_alter_op import * from .scatter import * from .group_conv2d import * +<<<<<<< HEAD from .math_alter_op import * +======= +from .concat import * +>>>>>>> [TE] Optimized version of concatenation layer diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py new file mode 100644 index 000000000000..5c51c5598074 --- /dev/null +++ b/python/tvm/topi/x86/concat.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name +"concatenate related operators" +from typing import Optional +import tvm +from tvm import te +import numpy as np +from ..utils import get_const_int, const_vector + + +def _concat(a_tuple, axis=0): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + a_tuple : tuple of tvm.te.Tensor + The arrays to concatenate + + axis : int, optional + The axis along which the arrays will be joined. Default is 0. + + Returns + ------- + ret : tvm.te.Tensor + """ + + def gen_ir_1D(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): + ib = tvm.tir.ir_builder.create() + data_bufs1 = [ib.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = ib.buffer_ptr(out_buf) + outers = ib.buffer_ptr(in_outers_tensor) + cumsum = ib.buffer_ptr(in_cumsum_tensor) + for i in range(len(a_tuple)): + with ib.for_range(0, outers[i], name="j") as j: + out_buf[cumsum[i] + j] = data_bufs1[i][j] + return ib.get() + + def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer): + ib = tvm.tir.ir_builder.create() + data_bufs1 = [ib.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = ib.buffer_ptr(out_buf) + outers = ib.buffer_ptr(in_outers_tensor) + cumsum = ib.buffer_ptr(in_cumsum_tensor) + if inner > 1: + with ib.for_range(0, inner, name="inn", kind="parallel") as inn: + pos = inn * outer + for i in range(len(a_tuple)): + offset = inn * outers[i] + with ib.for_range(0, outers[i], name="j") as j: + out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j] + else: + for i in range(len(a_tuple)): + with ib.for_range(0, outers[i], name="j", kind="parallel") as j: + out_buf[cumsum[i] + j] = data_bufs1[i][j] + return ib.get() + + if axis < 0: + axis += len(a_tuple[0].shape) + concat_axis_sizes = [int(t.shape[axis]) for t in a_tuple] + join_size = int(np.sum(concat_axis_sizes)) + in_outers = [int(np.prod(i.shape[axis:])) for i in a_tuple] + in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]] + dtype = a_tuple[0].dtype + out_shape = a_tuple[0].shape[:axis] + [join_size] + a_tuple[0].shape[axis + 1 :] + in_outers_tensor = const_vector(in_outers) + in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") + # check if dimensions tail is (... , axis, 1, ... , 1) + if len(out_shape[axis + 1 :]) == 0: + rightVal = out_shape[axis] + else: + rightVal = np.prod(out_shape[axis + 1 :]) + # check if dimensions tail is (1 , ... , 1, axis, ...) + if len(out_shape[:axis]) == 0: + leftVal = out_shape[axis] + else: + leftVal = np.prod(out_shape[:axis]) + + if ( + len(a_tuple[0].shape) == 1 + or rightVal == 1 + or (leftVal == 1 and axis == len(a_tuple[0].shape) - 1) + or (leftVal == 1 and rightVal == 1) + ): + # badly parallelized case + return te.extern( + [out_shape], + list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], + lambda ins, outs: gen_ir_1D(ins, ins[-2], ins[-1], outs[0]), + dtype=dtype, + name="concatenate_ext", + ) + + inner = get_const_int(int(np.prod(out_shape[:axis]))) + outer = get_const_int(int(np.prod(out_shape[axis:]))) + return te.extern( + [out_shape], + list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], + lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer), + dtype=dtype, + name="concatenate_ext", + ) + + +def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + data : tuple of tvm.te.Tensor + The arrays to concatenate + + axis : int, optional + The axis along which the arrays will be joined. Default is 0. + + Returns + ------- + ret : tvm.te.Tensor + """ + return _concat(data, axis=axis) diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 6492b78d6037..68f778727d54 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -17,20 +17,20 @@ # pylint: disable=invalid-name """x86 declaration and schedules.""" from tvm import te +from tvm.topi import tag from tvm.tir import IntImm +from tvm.topi.generic.injective import schedule_injective_from_existing from ..utils import is_empty_shape -def schedule_injective_from_existing(sch, out): +def schedule_injective_from_existing_ref(sch, out): """Schedule for injective op from existing schedule. - Parameters ---------- sch: Schedule The schedule to update. out: Tensor The tensor representing the injective op. - Returns ------- sch: Schedule @@ -60,14 +60,12 @@ def schedule_injective_from_existing(sch, out): def schedule_injective(outs): - """X86 schedule for injective op. - + """X86 reference schedule for injective op. Parameters ---------- outs: Array of Tensor The computation graph description of injective in the format of an array of tensors. - Returns ------- sch: Schedule @@ -79,19 +77,17 @@ def schedule_injective(outs): te.schedule.AutoInlineInjective(s) if not is_empty_shape(x.shape): - schedule_injective_from_existing(s, x) + schedule_injective_from_existing_ref(s, x) return s def schedule_concatenate(outs): """X86 schedule for concatenate op. - Parameters ---------- outs: Array of Tensor The computation graph description of injective in the format of an array of tensors. - Returns ------- sch: Schedule @@ -132,5 +128,36 @@ def vectorize(sch, tensor, vectorize_limit): return s +def schedule_concatenate_cpu(outs): + """X86 schedule for concatenate op. + Parameters + ---------- + outs: Array of Tensor + The computation graph description of injective in the format + of an array of tensors. + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + if tag.is_injective(op.tag): + schedule_injective_from_existing(s, op.output(0)) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + for out in outs: + traverse(out.op) + + return s + + schedule_elemwise = schedule_injective schedule_broadcast = schedule_injective diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e888eccc2b1c..2dac6f6bdf52 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -346,7 +346,7 @@ RELAY_REGISTER_OP("concatenate") .set_support_level(1) .add_type_rel("Concatenate", ConcatenateRel) .set_attr("FInferCorrectLayout", ConcatenateLayout) - .set_attr("FTVMCompute", ConcatenateCompute) + // .set_attr("FTVMCompute", ConcatenateCompute) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 2b30055c4f42..73c2c4b3fc8a 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -511,6 +511,23 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { std::vector changed(sch->stages.size(), false); std::vector new_hybrid_body(sch->stages.size()); std::vector hybrid_changed(sch->stages.size(), false); + // (sshtin): this workaround allows to inline extern ops. + // All inputs for extern op should not be inlined because inlining happens + // before generation of TE script for particular extern op. That may lead to + // crash during lowering or building stages. + std::unordered_map ext_ops; + for (size_t i = 0; i < sch->stages.size(); i++) { + Stage stage = sch->stages[i]; + auto ext_op = stage->op.as(); + if (ext_op) { + auto inps = ext_op->InputTensors(); + for (size_t ii = 0; ii < inps.size(); ++ii) { + if (ext_ops.find(inps[ii]->op) == ext_ops.end()) { + ext_ops[inps[ii]->op] = stage->op; + } + } + } + } // inline all the ops for (size_t i = sch->stages.size(); i != 0; --i) { Stage stage = sch->stages[i - 1]; @@ -525,8 +542,13 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { for (auto iv : compute->axis) { args.push_back(iv->var); } + if (ext_ops.find(stage->op) != ext_ops.end()) { + // sshtin: The extern op can try to get access to the input tensors as a row data, + // that can lead to error in TE scripts. + stage->attach_type = kGroupRoot; + continue; + } ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output"; - if (feature_extraction_mode && compute->attrs.count("const_matrix")) { // Use constant value to replace access of const matrices. // This produces wrong IR but is good enough for feature extraction purposes. diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 170850809ad5..8cd9337baa80 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -431,6 +431,98 @@ def test_batch_norm(): ) +def do_concat_test(shapes, t_shape, dtype, axis, dev, target): + varsToConcat = [] + inputData = [] + pos = 0 + for s in shapes: + varsToConcat.append(relay.var("x{}".format(pos), shape=s)) + inputData.append(np.random.rand(*s).astype(dtype)) + pos += 1 + t = relay.var("z", shape=t_shape, dtype=dtype) + z = relay.concatenate(varsToConcat, axis=axis) + z = relay.add(z, t) + params = varsToConcat + params.append(t) + func = relay.Function(params, z) + t_data = np.random.uniform(low=-10, high=10, size=t_shape).astype(dtype) + ref_res = np.concatenate((tuple(inputData)), axis=axis) + t_data + mod = tvm.IRModule.from_expr(func) + + executor = relay.create_executor("graph", mod=mod, device=dev, target=target) + op_res1 = executor.evaluate()(*inputData, t_data) + + tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=0.000001) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + *inputData, t_data + ) + tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=0.000001) + + +@tvm.testing.uses_gpu +def test_concatenate1(): + for target, dev in tvm.testing.enabled_targets(): + if target != "llvm": + continue + np.random.seed(471) + maxNumDimensions = 6 + shape = [4, 32, 16, 1, 31, 20, 21, 8, 28, 7] # just randomly selected 10 numbers + for dtype in ["float32"]: + for dimsNum in range(1, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(0, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + shp = tuple(shape[:dimsNum]) + finalSize += shape[(i % len(shape))] + shapes.append( + shp[:normalizedAxis] + + tuple([shape[(i % len(shape))]]) + + shp[normalizedAxis + 1 :] + ) + t_shape = shp[:normalizedAxis] + tuple([finalSize]) + shp[normalizedAxis + 1 :] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + +@tvm.testing.uses_gpu +def test_concatenate2(): + # test to cover cases (1, .. , x, 1, .. , 1) + for target, dev in tvm.testing.enabled_targets(): + if target != "llvm": + continue + np.random.seed(13) + maxNumDimensions = 6 + shape = [8, 3, 25, 33, 12, 29, 5, 11, 29, 11] # just randomly selected 10 numbers + ind = 0 + for dtype in ["float32"]: + for dimsNum in range(2, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(-dimsNum + 1, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + axisVal = [1] * dimsNum + axisVal[axis] = shape[(ind % len(shape))] + ind += 1 + finalSize += axisVal[axis] + shapes.append(tuple(axisVal)) + temp = [1] * dimsNum + temp[axis] = finalSize + t_shape = tuple(temp) + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + def test_batch_norm_fold_const(): axis = 1 dtype = "float32" diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index ad054479fd7b..52bc25277709 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -22,6 +22,7 @@ import numpy import pytest +import platform import tvm import tvm.relay @@ -418,14 +419,24 @@ def test_export_byoc_c_module(): with tf.extractfile("./metadata.json") as f: metadata = json.load(f) main_md = metadata["memory"]["functions"]["main"] - assert main_md == [ - { - "constants_size_bytes": 0, - "device": 1, - "io_size_bytes": 4800, - "workspace_size_bytes": 800, - } - ] + if platform.architecture()[0] == "64bit": + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 3664, + } + ] + else: + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 3648, + } + ] if __name__ == "__main__": From cd1fbd8480369e8bcfd8d0b7525b137686edc05c Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Tue, 17 May 2022 19:28:10 +0300 Subject: [PATCH 02/21] *test fix --- tests/python/unittest/test_micro_model_library_format.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 52bc25277709..d707e6b4646b 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -425,7 +425,7 @@ def test_export_byoc_c_module(): "constants_size_bytes": 0, "device": 1, "io_size_bytes": 4800, - "workspace_size_bytes": 3664, + "workspace_size_bytes": 1264, } ] else: @@ -434,7 +434,7 @@ def test_export_byoc_c_module(): "constants_size_bytes": 0, "device": 1, "io_size_bytes": 4800, - "workspace_size_bytes": 3648, + "workspace_size_bytes": 1248, } ] From 7c37a4bedafd3797543febd94d191df660c52851 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 18 May 2022 07:20:48 +0300 Subject: [PATCH 03/21] test_any.py fix. --- python/tvm/relay/op/strategy/generic.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index cfbfd472d8b4..84378ba5a447 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1822,9 +1822,8 @@ def concatenate_strategy_cpu(attrs, inputs, out_type, target): for inpt in inputs: shape = inpt.shape for i in shape: - if isinstance(i, tir.expr.SizeVar): - if i.name == "any_dim": - use_old_concat = True + if not isinstance(i, tir.expr.IntImm): + use_old_concat = True break if use_old_concat: strategy.add_implementation( From fefb4af8cb5e538f4447da5eadf6c7c96a377167 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 18 May 2022 12:11:15 +0300 Subject: [PATCH 04/21] test_forward.py from tensorflow fix. --- python/tvm/topi/x86/concat.py | 2 +- tests/python/relay/test_op_level1.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index 5c51c5598074..aaac9ebf46c2 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -83,7 +83,7 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) if len(out_shape[axis + 1 :]) == 0: rightVal = out_shape[axis] else: - rightVal = np.prod(out_shape[axis + 1 :]) + rightVal = np.prod(out_shape[axis :]) # check if dimensions tail is (1 , ... , 1, axis, ...) if len(out_shape[:axis]) == 0: leftVal = out_shape[axis] diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 8cd9337baa80..26e9dfd8fc66 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -522,6 +522,18 @@ def test_concatenate2(): t_shape = tuple(temp) do_concat_test(shapes, t_shape, dtype, axis, dev, target) +@tvm.testing.uses_gpu +def test_concatenate3(): + for target, dev in tvm.testing.enabled_targets(): + if target != "llvm": + continue + np.random.seed(477) + for dtype in ["float32"]: + axis = -2 + ending = 1 + shapes = [[3,2,1,ending], [3,2,1,ending]] + t_shape = [3,2,2,ending] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) def test_batch_norm_fold_const(): axis = 1 From ae6400275abc92208b0b9d219a0fb08008fd5e25 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 18 May 2022 12:41:23 +0300 Subject: [PATCH 05/21] lint fix. --- python/tvm/topi/x86/concat.py | 2 +- tests/python/relay/test_op_level1.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index aaac9ebf46c2..59582b31d11e 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -83,7 +83,7 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) if len(out_shape[axis + 1 :]) == 0: rightVal = out_shape[axis] else: - rightVal = np.prod(out_shape[axis :]) + rightVal = np.prod(out_shape[axis:]) # check if dimensions tail is (1 , ... , 1, axis, ...) if len(out_shape[:axis]) == 0: leftVal = out_shape[axis] diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 26e9dfd8fc66..0a5c7153f5f5 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -522,6 +522,7 @@ def test_concatenate2(): t_shape = tuple(temp) do_concat_test(shapes, t_shape, dtype, axis, dev, target) + @tvm.testing.uses_gpu def test_concatenate3(): for target, dev in tvm.testing.enabled_targets(): @@ -531,10 +532,11 @@ def test_concatenate3(): for dtype in ["float32"]: axis = -2 ending = 1 - shapes = [[3,2,1,ending], [3,2,1,ending]] - t_shape = [3,2,2,ending] + shapes = [[3, 2, 1, ending], [3, 2, 1, ending]] + t_shape = [3, 2, 2, ending] do_concat_test(shapes, t_shape, dtype, axis, dev, target) + def test_batch_norm_fold_const(): axis = 1 dtype = "float32" From cab5fbb5edcfc608b67b5461e8daeea4fcb53ae0 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Thu, 19 May 2022 10:54:12 +0300 Subject: [PATCH 06/21] Fixes after code review. --- python/tvm/relay/op/strategy/generic.py | 26 ----- python/tvm/relay/op/strategy/x86.py | 31 ++++++ python/tvm/topi/x86/concat.py | 17 +-- src/relay/op/tensor/transform.cc | 1 - tests/python/relay/test_op_level1.py | 137 +++++++++++------------- 5 files changed, 99 insertions(+), 113 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 84378ba5a447..dacc85580111 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1814,32 +1814,6 @@ def concatenate_strategy(attrs, inputs, out_type, target): return strategy -@concatenate_strategy.register(["cpu"]) -def concatenate_strategy_cpu(attrs, inputs, out_type, target): - """concatenate x86 strategy""" - strategy = _op.OpStrategy() - use_old_concat = False - for inpt in inputs: - shape = inpt.shape - for i in shape: - if not isinstance(i, tir.expr.IntImm): - use_old_concat = True - break - if use_old_concat: - strategy.add_implementation( - wrap_compute_concat(topi.transform.concatenate), - wrap_topi_schedule(topi.x86.injective.schedule_concatenate), - name="concatenate.generic", - ) - else: - strategy.add_implementation( - wrap_compute_concat(topi.x86.concatenate), - wrap_topi_schedule(topi.x86.schedule_concatenate_cpu), - name="concatenate.cpu", - ) - return strategy - - @override_native_generic_func("cumprod_strategy") def cumprod_strategy(attrs, inputs, out_type, target): """cumprod generic strategy""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 877a4ef30462..6748b54e99cc 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -741,3 +741,34 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) ) return strategy + + +@concatenate_strategy.register(["cpu"]) +def concatenate_strategy_cpu(attrs, inputs, out_type, target): + """concatenate x86 strategy""" + strategy = _op.OpStrategy() + use_only_old_concat = False + for inpt in inputs: + shape = inpt.shape + for i in shape: + if not isinstance(i, tir.expr.IntImm): + use_only_old_concat = True + break + if use_only_old_concat: + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + else: + strategy.add_implementation( + wrap_compute_concat(topi.x86.concatenate), + wrap_topi_schedule(topi.x86.schedule_concatenate_cpu), + name="concatenate.cpu", + ) + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + return strategy diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index 59582b31d11e..e965acdcfe54 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name "concatenate related operators" from typing import Optional import tvm @@ -79,16 +78,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) out_shape = a_tuple[0].shape[:axis] + [join_size] + a_tuple[0].shape[axis + 1 :] in_outers_tensor = const_vector(in_outers) in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") - # check if dimensions tail is (... , axis, 1, ... , 1) - if len(out_shape[axis + 1 :]) == 0: - rightVal = out_shape[axis] - else: - rightVal = np.prod(out_shape[axis:]) - # check if dimensions tail is (1 , ... , 1, axis, ...) - if len(out_shape[:axis]) == 0: - leftVal = out_shape[axis] - else: - leftVal = np.prod(out_shape[:axis]) + rightVal = np.prod(out_shape[axis:]) + leftVal = np.prod(out_shape[:axis]) if ( len(a_tuple[0].shape) == 1 @@ -105,8 +96,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) name="concatenate_ext", ) - inner = get_const_int(int(np.prod(out_shape[:axis]))) - outer = get_const_int(int(np.prod(out_shape[axis:]))) + inner = get_const_int(int(leftVal)) + outer = get_const_int(int(rightVal)) return te.extern( [out_shape], list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 2dac6f6bdf52..57bf9f36def9 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -346,7 +346,6 @@ RELAY_REGISTER_OP("concatenate") .set_support_level(1) .add_type_rel("Concatenate", ConcatenateRel) .set_attr("FInferCorrectLayout", ConcatenateLayout) - // .set_attr("FTVMCompute", ConcatenateCompute) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 0a5c7153f5f5..f4afc9e90562 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -459,82 +459,73 @@ def do_concat_test(shapes, t_shape, dtype, axis, dev, target): tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=0.000001) -@tvm.testing.uses_gpu -def test_concatenate1(): - for target, dev in tvm.testing.enabled_targets(): - if target != "llvm": - continue - np.random.seed(471) - maxNumDimensions = 6 - shape = [4, 32, 16, 1, 31, 20, 21, 8, 28, 7] # just randomly selected 10 numbers - for dtype in ["float32"]: - for dimsNum in range(1, maxNumDimensions): - np.random.shuffle(shape) - for axis in range(0, dimsNum): # range should be (-dimsNum + 1, dimsNum) - numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] - shapes = [] - # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative - normalizedAxis = axis - if axis < 0: - normalizedAxis += dimsNum - finalSize = 0 - for i in range(0, numToConcat): - shp = tuple(shape[:dimsNum]) - finalSize += shape[(i % len(shape))] - shapes.append( - shp[:normalizedAxis] - + tuple([shape[(i % len(shape))]]) - + shp[normalizedAxis + 1 :] - ) - t_shape = shp[:normalizedAxis] + tuple([finalSize]) + shp[normalizedAxis + 1 :] - do_concat_test(shapes, t_shape, dtype, axis, dev, target) +@tvm.testing.parametrize_targets("llvm") +def test_concatenate1(target, dev): + np.random.seed(471) + maxNumDimensions = 6 + shape = [4, 32, 16, 1, 31, 20, 21, 8, 28, 7] # just randomly selected 10 numbers + for dtype in ["float32"]: + for dimsNum in range(1, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(0, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + shp = tuple(shape[:dimsNum]) + finalSize += shape[(i % len(shape))] + shapes.append( + shp[:normalizedAxis] + + tuple([shape[(i % len(shape))]]) + + shp[normalizedAxis + 1 :] + ) + t_shape = shp[:normalizedAxis] + tuple([finalSize]) + shp[normalizedAxis + 1 :] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) -@tvm.testing.uses_gpu -def test_concatenate2(): +@tvm.testing.parametrize_targets("llvm") +def test_concatenate2(target, dev): # test to cover cases (1, .. , x, 1, .. , 1) - for target, dev in tvm.testing.enabled_targets(): - if target != "llvm": - continue - np.random.seed(13) - maxNumDimensions = 6 - shape = [8, 3, 25, 33, 12, 29, 5, 11, 29, 11] # just randomly selected 10 numbers - ind = 0 - for dtype in ["float32"]: - for dimsNum in range(2, maxNumDimensions): - np.random.shuffle(shape) - for axis in range(-dimsNum + 1, dimsNum): # range should be (-dimsNum + 1, dimsNum) - numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] - shapes = [] - # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative - normalizedAxis = axis - if axis < 0: - normalizedAxis += dimsNum - finalSize = 0 - for i in range(0, numToConcat): - axisVal = [1] * dimsNum - axisVal[axis] = shape[(ind % len(shape))] - ind += 1 - finalSize += axisVal[axis] - shapes.append(tuple(axisVal)) - temp = [1] * dimsNum - temp[axis] = finalSize - t_shape = tuple(temp) - do_concat_test(shapes, t_shape, dtype, axis, dev, target) - - -@tvm.testing.uses_gpu -def test_concatenate3(): - for target, dev in tvm.testing.enabled_targets(): - if target != "llvm": - continue - np.random.seed(477) - for dtype in ["float32"]: - axis = -2 - ending = 1 - shapes = [[3, 2, 1, ending], [3, 2, 1, ending]] - t_shape = [3, 2, 2, ending] - do_concat_test(shapes, t_shape, dtype, axis, dev, target) + np.random.seed(13) + maxNumDimensions = 6 + shape = [8, 3, 25, 33, 12, 29, 5, 11, 29, 11] # just randomly selected 10 numbers + ind = 0 + for dtype in ["float32"]: + for dimsNum in range(2, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(-dimsNum + 1, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below to normalize axes index. For some reasons tvm notifies about error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + axisVal = [1] * dimsNum + axisVal[axis] = shape[(ind % len(shape))] + ind += 1 + finalSize += axisVal[axis] + shapes.append(tuple(axisVal)) + temp = [1] * dimsNum + temp[axis] = finalSize + t_shape = tuple(temp) + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate3(target, dev): + np.random.seed(477) + for dtype in ["float32"]: + axis = -2 + ending = 1 + shapes = [[3, 2, 1, ending], [3, 2, 1, ending]] + t_shape = [3, 2, 2, ending] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) def test_batch_norm_fold_const(): From 1a01771d8de035e1f2ea8300bc5360e8edf3771f Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Thu, 19 May 2022 14:54:49 +0300 Subject: [PATCH 07/21] New comment added. --- src/te/schedule/schedule_dataflow_rewrite.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 73c2c4b3fc8a..c6f530ddfcc9 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -515,6 +515,12 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { // All inputs for extern op should not be inlined because inlining happens // before generation of TE script for particular extern op. That may lead to // crash during lowering or building stages. + // The problem description: + // In case of operations fuzing arguments inlining + // prevents creation of ProducerNode for extern operation. + // Instead of the creation it supposed to use operation argument as inlined buffer + // but extern_op TIR generation can be peformed after inlining procedure so + // newly generated TIR does not have reference to input data at all. std::unordered_map ext_ops; for (size_t i = 0; i < sch->stages.size(); i++) { Stage stage = sch->stages[i]; From e000d2753c7e673fb8e78f2c87723b13d56b89b5 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Thu, 19 May 2022 15:28:28 +0300 Subject: [PATCH 08/21] Lint fix. --- python/tvm/topi/x86/concat.py | 50 +++++++++++++++++------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index e965acdcfe54..9f44759c581a 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -38,35 +38,35 @@ def _concat(a_tuple, axis=0): ret : tvm.te.Tensor """ - def gen_ir_1D(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): - ib = tvm.tir.ir_builder.create() - data_bufs1 = [ib.buffer_ptr(data_buf) for data_buf in data_bufs] - out_buf = ib.buffer_ptr(out_buf) - outers = ib.buffer_ptr(in_outers_tensor) - cumsum = ib.buffer_ptr(in_cumsum_tensor) + def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): + i_b = tvm.tir.ir_builder.create() + data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = i_b.buffer_ptr(out_buf) + outers = i_b.buffer_ptr(in_outers_tensor) + cumsum = i_b.buffer_ptr(in_cumsum_tensor) for i in range(len(a_tuple)): - with ib.for_range(0, outers[i], name="j") as j: + with i_b.for_range(0, outers[i], name="j") as j: out_buf[cumsum[i] + j] = data_bufs1[i][j] - return ib.get() + return i_b.get() def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer): - ib = tvm.tir.ir_builder.create() - data_bufs1 = [ib.buffer_ptr(data_buf) for data_buf in data_bufs] - out_buf = ib.buffer_ptr(out_buf) - outers = ib.buffer_ptr(in_outers_tensor) - cumsum = ib.buffer_ptr(in_cumsum_tensor) + i_b = tvm.tir.ir_builder.create() + data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = i_b.buffer_ptr(out_buf) + outers = i_b.buffer_ptr(in_outers_tensor) + cumsum = i_b.buffer_ptr(in_cumsum_tensor) if inner > 1: - with ib.for_range(0, inner, name="inn", kind="parallel") as inn: + with i_b.for_range(0, inner, name="inn", kind="parallel") as inn: pos = inn * outer for i in range(len(a_tuple)): offset = inn * outers[i] - with ib.for_range(0, outers[i], name="j") as j: + with i_b.for_range(0, outers[i], name="j") as j: out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j] else: for i in range(len(a_tuple)): - with ib.for_range(0, outers[i], name="j", kind="parallel") as j: + with i_b.for_range(0, outers[i], name="j", kind="parallel") as j: out_buf[cumsum[i] + j] = data_bufs1[i][j] - return ib.get() + return i_b.get() if axis < 0: axis += len(a_tuple[0].shape) @@ -78,26 +78,26 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) out_shape = a_tuple[0].shape[:axis] + [join_size] + a_tuple[0].shape[axis + 1 :] in_outers_tensor = const_vector(in_outers) in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") - rightVal = np.prod(out_shape[axis:]) - leftVal = np.prod(out_shape[:axis]) + right_val = np.prod(out_shape[axis:]) + left_val = np.prod(out_shape[:axis]) if ( len(a_tuple[0].shape) == 1 - or rightVal == 1 - or (leftVal == 1 and axis == len(a_tuple[0].shape) - 1) - or (leftVal == 1 and rightVal == 1) + or right_val == 1 + or (left_val == 1 and axis == len(a_tuple[0].shape) - 1) + or (left_val == 1 and right_val == 1) ): # badly parallelized case return te.extern( [out_shape], list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], - lambda ins, outs: gen_ir_1D(ins, ins[-2], ins[-1], outs[0]), + lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]), dtype=dtype, name="concatenate_ext", ) - inner = get_const_int(int(leftVal)) - outer = get_const_int(int(rightVal)) + inner = get_const_int(int(left_val)) + outer = get_const_int(int(right_val)) return te.extern( [out_shape], list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], From a350af1d0855ac868cb9e1cdc280f2a4e3d70921 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Thu, 19 May 2022 15:41:08 +0300 Subject: [PATCH 09/21] Another lint fix. --- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/relay/op/strategy/x86.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index dacc85580111..d637ceccd25c 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -19,7 +19,7 @@ import logging import re -from tvm import _ffi, ir, te, topi, tir +from tvm import _ffi, ir, te, topi from tvm.target import generic_func, override_native_generic_func from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 6748b54e99cc..95c712a2c2a9 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -19,7 +19,7 @@ import logging import re -from tvm import topi +from tvm import topi, tir from tvm.topi.x86.utils import target_has_vnni from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition From b0d742de1393c0df97fd921502b85dd429b9a4d9 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Sat, 21 May 2022 09:31:48 +0300 Subject: [PATCH 10/21] Comments added. --- python/tvm/topi/x86/concat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index 9f44759c581a..e981e694bd28 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -39,6 +39,7 @@ def _concat(a_tuple, axis=0): """ def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): + """Custom conactenation execution.""" i_b = tvm.tir.ir_builder.create() data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] out_buf = i_b.buffer_ptr(out_buf) @@ -50,6 +51,7 @@ def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): return i_b.get() def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer): + """Common case of conactenation execution.""" i_b = tvm.tir.ir_builder.create() data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] out_buf = i_b.buffer_ptr(out_buf) @@ -108,7 +110,7 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): - """Join a sequence of arrays along an existing axis. + """Join a sequence of arrays along an existing axis. Optimized for CPU exeution. Parameters ---------- From bfbcb86fc8af6ef6d5dd8c5490d8645d0463fd48 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Tue, 24 May 2022 15:34:01 +0300 Subject: [PATCH 11/21] rebase issue fix. --- python/tvm/topi/x86/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py index 36ac432d2c94..d075090f01ea 100644 --- a/python/tvm/topi/x86/__init__.py +++ b/python/tvm/topi/x86/__init__.py @@ -42,8 +42,5 @@ from .dense_alter_op import * from .scatter import * from .group_conv2d import * -<<<<<<< HEAD from .math_alter_op import * -======= from .concat import * ->>>>>>> [TE] Optimized version of concatenation layer From 14e8b7000025d8fff7162b6ece1663b75fd00a94 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 25 May 2022 11:31:49 +0300 Subject: [PATCH 12/21] Restored previous state. --- python/tvm/topi/x86/concat.py | 48 +++++++++++------------------------ 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index e981e694bd28..5cb3cd3f57d5 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -22,12 +22,12 @@ from ..utils import get_const_int, const_vector -def _concat(a_tuple, axis=0): - """Join a sequence of arrays along an existing axis. +def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): + """Join a sequence of arrays along an existing axis. Optimized for CPU exeution. Parameters ---------- - a_tuple : tuple of tvm.te.Tensor + data : tuple of tvm.te.Tensor The arrays to concatenate axis : int, optional @@ -45,7 +45,7 @@ def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): out_buf = i_b.buffer_ptr(out_buf) outers = i_b.buffer_ptr(in_outers_tensor) cumsum = i_b.buffer_ptr(in_cumsum_tensor) - for i in range(len(a_tuple)): + for i in range(len(data)): with i_b.for_range(0, outers[i], name="j") as j: out_buf[cumsum[i] + j] = data_bufs1[i][j] return i_b.get() @@ -60,39 +60,39 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) if inner > 1: with i_b.for_range(0, inner, name="inn", kind="parallel") as inn: pos = inn * outer - for i in range(len(a_tuple)): + for i in range(len(data)): offset = inn * outers[i] with i_b.for_range(0, outers[i], name="j") as j: out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j] else: - for i in range(len(a_tuple)): + for i in range(len(data)): with i_b.for_range(0, outers[i], name="j", kind="parallel") as j: out_buf[cumsum[i] + j] = data_bufs1[i][j] return i_b.get() if axis < 0: - axis += len(a_tuple[0].shape) - concat_axis_sizes = [int(t.shape[axis]) for t in a_tuple] + axis += len(data[0].shape) + concat_axis_sizes = [int(t.shape[axis]) for t in data] join_size = int(np.sum(concat_axis_sizes)) - in_outers = [int(np.prod(i.shape[axis:])) for i in a_tuple] + in_outers = [int(np.prod(i.shape[axis:])) for i in data] in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]] - dtype = a_tuple[0].dtype - out_shape = a_tuple[0].shape[:axis] + [join_size] + a_tuple[0].shape[axis + 1 :] + dtype = data[0].dtype + out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :] in_outers_tensor = const_vector(in_outers) in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") right_val = np.prod(out_shape[axis:]) left_val = np.prod(out_shape[:axis]) if ( - len(a_tuple[0].shape) == 1 + len(data[0].shape) == 1 or right_val == 1 - or (left_val == 1 and axis == len(a_tuple[0].shape) - 1) + or (left_val == 1 and axis == len(data[0].shape) - 1) or (left_val == 1 and right_val == 1) ): # badly parallelized case return te.extern( [out_shape], - list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], + list(data) + [in_outers_tensor, in_cumsum_tensor], lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]), dtype=dtype, name="concatenate_ext", @@ -102,26 +102,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) outer = get_const_int(int(right_val)) return te.extern( [out_shape], - list(a_tuple) + [in_outers_tensor, in_cumsum_tensor], + list(data) + [in_outers_tensor, in_cumsum_tensor], lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer), dtype=dtype, name="concatenate_ext", ) - - -def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): - """Join a sequence of arrays along an existing axis. Optimized for CPU exeution. - - Parameters - ---------- - data : tuple of tvm.te.Tensor - The arrays to concatenate - - axis : int, optional - The axis along which the arrays will be joined. Default is 0. - - Returns - ------- - ret : tvm.te.Tensor - """ - return _concat(data, axis=axis) From 3ec0d76f0d5e61519aa100b56263e27d95393ab6 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 25 May 2022 19:05:40 +0300 Subject: [PATCH 13/21] Update after code review. --- python/tvm/relay/op/strategy/x86.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 95c712a2c2a9..59a57fd233f5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -48,13 +48,6 @@ def schedule_reduce_cpu(attrs, outs, target): return topi.x86.schedule_reduce(outs) -@schedule_concatenate.register("cpu") -def schedule_concatenate_cpu(attrs, outs, target): - """schedule concatenate op for x86""" - with target: - return topi.transform.schedule_concatenate(outs) - - @schedule_pool.register("cpu") def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" From 835e8a1c0ce031162b8384ca95cf15888cc9abe5 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Mon, 30 May 2022 12:03:32 +0300 Subject: [PATCH 14/21] After code review changes. --- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/topi/x86/injective.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index d637ceccd25c..2bb009dbc8f7 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1808,7 +1808,7 @@ def concatenate_strategy(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_concat(topi.concatenate), - wrap_topi_schedule(topi.generic.schedule_extern), + wrap_topi_schedule(topi.generic.schedule_injective), name="concatenate", ) return strategy diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 68f778727d54..79f13c1f7361 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -19,11 +19,12 @@ from tvm import te from tvm.topi import tag from tvm.tir import IntImm -from tvm.topi.generic.injective import schedule_injective_from_existing +from tvm.topi.generic.injective \ + import schedule_injective_from_existing as schedule_injective_for_concat from ..utils import is_empty_shape -def schedule_injective_from_existing_ref(sch, out): +def schedule_injective_from_existing(sch, out): """Schedule for injective op from existing schedule. Parameters ---------- @@ -77,7 +78,7 @@ def schedule_injective(outs): te.schedule.AutoInlineInjective(s) if not is_empty_shape(x.shape): - schedule_injective_from_existing_ref(s, x) + schedule_injective_from_existing(s, x) return s @@ -147,7 +148,8 @@ def schedule_concatenate_cpu(outs): def traverse(op): if tag.is_injective(op.tag): - schedule_injective_from_existing(s, op.output(0)) + schedule_injective_for_concat(s, op.output(0)) + for tensor in op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) From 2199e439268fdf514dc1a30cced9b8ce8da3248d Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Mon, 30 May 2022 12:21:13 +0300 Subject: [PATCH 15/21] lint review. --- python/tvm/topi/x86/injective.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 79f13c1f7361..b80bedcd89a6 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -19,8 +19,9 @@ from tvm import te from tvm.topi import tag from tvm.tir import IntImm -from tvm.topi.generic.injective \ - import schedule_injective_from_existing as schedule_injective_for_concat +from tvm.topi.generic.injective import ( + schedule_injective_from_existing as schedule_injective_for_concat, +) from ..utils import is_empty_shape From d474d164ab27bf92e3ff09c84757638fa503530b Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Mon, 30 May 2022 14:45:45 +0300 Subject: [PATCH 16/21] Change strategy for cuda to fix tests. --- python/tvm/relay/op/strategy/cuda.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 59971d4e206f..4a7cff5f3f33 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -42,11 +42,15 @@ def schedule_reduce_cuda(attrs, outs, target): return topi.cuda.schedule_reduce(outs) -@schedule_concatenate.register(["cuda", "gpu"]) -def schedule_concatenate_cuda(attrs, outs, target): - """schedule concatenate for cuda""" - with target: - return topi.cuda.schedule_injective(outs) +@concatenate_strategy.register(["cuda", "gpu"]) +def concatenate_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.cuda.schedule_injective), + name="concatenate.cuda", + ) + return strategy @schedule_pool.register(["cuda", "gpu"]) From 37250d3568ef5167828ec31fafdc65f4e546136e Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Mon, 30 May 2022 17:16:47 +0300 Subject: [PATCH 17/21] Rebase to main From a2c968260de189c4d52d1c8d36cb14c98fe6d4a3 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Tue, 31 May 2022 15:59:03 +0300 Subject: [PATCH 18/21] Comments changes after review. --- python/tvm/topi/x86/injective.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index b80bedcd89a6..78893397ba31 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -62,7 +62,7 @@ def schedule_injective_from_existing(sch, out): def schedule_injective(outs): - """X86 reference schedule for injective op. + """X86 schedule for injective op. Parameters ---------- outs: Array of Tensor @@ -135,7 +135,7 @@ def schedule_concatenate_cpu(outs): Parameters ---------- outs: Array of Tensor - The computation graph description of injective in the format + The computation graph description in the format of an array of tensors. Returns ------- From dd8d1db46e03c3702729a31b46fbf90c42475234 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 1 Jun 2022 10:29:21 +0300 Subject: [PATCH 19/21] Some more comments fixes. --- src/te/schedule/schedule_dataflow_rewrite.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c6f530ddfcc9..bd3bd2d828cc 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -511,14 +511,14 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { std::vector changed(sch->stages.size(), false); std::vector new_hybrid_body(sch->stages.size()); std::vector hybrid_changed(sch->stages.size(), false); - // (sshtin): this workaround allows to inline extern ops. - // All inputs for extern op should not be inlined because inlining happens - // before generation of TE script for particular extern op. That may lead to + // (sshtin): this workaround allows to inline extern ops into their consumer. + // All inputs for extern op should not be inlined because inlining may happen + // before TE generation for particular extern op. That may lead to // crash during lowering or building stages. // The problem description: - // In case of operations fuzing arguments inlining + // In case of operations fusing, arguments inlining // prevents creation of ProducerNode for extern operation. - // Instead of the creation it supposed to use operation argument as inlined buffer + // Instead of the creation it is supposed to use operation argument as inlined buffer // but extern_op TIR generation can be peformed after inlining procedure so // newly generated TIR does not have reference to input data at all. std::unordered_map ext_ops; @@ -550,7 +550,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { } if (ext_ops.find(stage->op) != ext_ops.end()) { // sshtin: The extern op can try to get access to the input tensors as a row data, - // that can lead to error in TE scripts. + // that can lead to error in IR builder. stage->attach_type = kGroupRoot; continue; } From 213c3c64f89ecee8ba50b0287451b8b15c604355 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 1 Jun 2022 10:54:00 +0300 Subject: [PATCH 20/21] One more error fix in comments. --- src/te/schedule/schedule_dataflow_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index bd3bd2d828cc..a8363fd084cd 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -549,7 +549,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { args.push_back(iv->var); } if (ext_ops.find(stage->op) != ext_ops.end()) { - // sshtin: The extern op can try to get access to the input tensors as a row data, + // sshtin: The extern op can try to get access to the input tensors as a raw data, // that can lead to error in IR builder. stage->attach_type = kGroupRoot; continue; From ef94d6fcaa8e9661ba4b3ad4933ad366d2ba1cf4 Mon Sep 17 00:00:00 2001 From: Sergey Shtin Date: Wed, 1 Jun 2022 14:16:36 +0300 Subject: [PATCH 21/21] restart build