diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0338035329fc..d87ee266f01d 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -68,7 +68,12 @@ # concatenate -_reg.register_schedule("concatenate", strategy.schedule_concatenate) +@_reg.register_compute("concatenate") +def compute_concat(attrs, inputs, output_type): + return [topi.concatenate(inputs, attrs.axis)] + + +_reg.register_strategy("concatenate", strategy.concatenate_strategy) # sliding_window @_reg.register_compute("sliding_window") diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 59971d4e206f..4a7cff5f3f33 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -42,11 +42,15 @@ def schedule_reduce_cuda(attrs, outs, target): return topi.cuda.schedule_reduce(outs) -@schedule_concatenate.register(["cuda", "gpu"]) -def schedule_concatenate_cuda(attrs, outs, target): - """schedule concatenate for cuda""" - with target: - return topi.cuda.schedule_injective(outs) +@concatenate_strategy.register(["cuda", "gpu"]) +def concatenate_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.cuda.schedule_injective), + name="concatenate.cuda", + ) + return strategy @schedule_pool.register(["cuda", "gpu"]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index fa62af5f9fed..2bb009dbc8f7 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1781,6 +1781,15 @@ def _compute_scanop(attrs, inputs, _): return _compute_scanop +def wrap_compute_concat(topi_compute): + """Wrap concatenate topi compute""" + + def _compute_concat(attrs, inputs, _): + return [topi_compute(inputs, attrs.axis)] + + return _compute_concat + + @override_native_generic_func("cumsum_strategy") def cumsum_strategy(attrs, inputs, out_type, target): """cumsum generic strategy""" @@ -1793,6 +1802,18 @@ def cumsum_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("concat_strategy") +def concatenate_strategy(attrs, inputs, out_type, target): + """concatenate generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_concat(topi.concatenate), + wrap_topi_schedule(topi.generic.schedule_injective), + name="concatenate", + ) + return strategy + + @override_native_generic_func("cumprod_strategy") def cumprod_strategy(attrs, inputs, out_type, target): """cumprod generic strategy""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 0beb99e4f7db..59a57fd233f5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -19,7 +19,7 @@ import logging import re -from tvm import topi +from tvm import topi, tir from tvm.topi.x86.utils import target_has_vnni from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition @@ -48,13 +48,6 @@ def schedule_reduce_cpu(attrs, outs, target): return topi.x86.schedule_reduce(outs) -@schedule_concatenate.register("cpu") -def schedule_concatenate_cpu(attrs, outs, target): - """schedule concatenate op for x86""" - with target: - return topi.x86.schedule_concatenate(outs) - - @schedule_pool.register("cpu") def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" @@ -741,3 
+734,34 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) ) return strategy + + +@concatenate_strategy.register(["cpu"]) +def concatenate_strategy_cpu(attrs, inputs, out_type, target): + """concatenate x86 strategy""" + strategy = _op.OpStrategy() + use_only_old_concat = False + for inpt in inputs: + shape = inpt.shape + for i in shape: + if not isinstance(i, tir.expr.IntImm): + use_only_old_concat = True + break + if use_only_old_concat: + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + else: + strategy.add_implementation( + wrap_compute_concat(topi.x86.concatenate), + wrap_topi_schedule(topi.x86.schedule_concatenate_cpu), + name="concatenate.cpu", + ) + strategy.add_implementation( + wrap_compute_concat(topi.transform.concatenate), + wrap_topi_schedule(topi.x86.injective.schedule_concatenate), + name="concatenate.generic", + ) + return strategy diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py index 34a5e0362d87..d075090f01ea 100644 --- a/python/tvm/topi/x86/__init__.py +++ b/python/tvm/topi/x86/__init__.py @@ -43,3 +43,4 @@ from .scatter import * from .group_conv2d import * from .math_alter_op import * +from .concat import * diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py new file mode 100644 index 000000000000..5cb3cd3f57d5 --- /dev/null +++ b/python/tvm/topi/x86/concat.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"concatenate related operators" +from typing import Optional +import tvm +from tvm import te +import numpy as np +from ..utils import get_const_int, const_vector + + +def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): + """Join a sequence of arrays along an existing axis. Optimized for CPU execution. + + Parameters + ---------- + data : tuple of tvm.te.Tensor + The arrays to concatenate + + axis : int, optional + The axis along which the arrays will be joined. Default is 0.
+ + Returns + ------- + ret : tvm.te.Tensor + """ + + def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): + """Custom concatenation execution.""" + i_b = tvm.tir.ir_builder.create() + data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = i_b.buffer_ptr(out_buf) + outers = i_b.buffer_ptr(in_outers_tensor) + cumsum = i_b.buffer_ptr(in_cumsum_tensor) + for i in range(len(data)): + with i_b.for_range(0, outers[i], name="j") as j: + out_buf[cumsum[i] + j] = data_bufs1[i][j] + return i_b.get() + + def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer): + """Common case of concatenation execution.""" + i_b = tvm.tir.ir_builder.create() + data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] + out_buf = i_b.buffer_ptr(out_buf) + outers = i_b.buffer_ptr(in_outers_tensor) + cumsum = i_b.buffer_ptr(in_cumsum_tensor) + if inner > 1: + with i_b.for_range(0, inner, name="inn", kind="parallel") as inn: + pos = inn * outer + for i in range(len(data)): + offset = inn * outers[i] + with i_b.for_range(0, outers[i], name="j") as j: + out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j] + else: + for i in range(len(data)): + with i_b.for_range(0, outers[i], name="j", kind="parallel") as j: + out_buf[cumsum[i] + j] = data_bufs1[i][j] + return i_b.get() + + if axis < 0: + axis += len(data[0].shape) + concat_axis_sizes = [int(t.shape[axis]) for t in data] + join_size = int(np.sum(concat_axis_sizes)) + in_outers = [int(np.prod(i.shape[axis:])) for i in data] + in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]] + dtype = data[0].dtype + out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :] + in_outers_tensor = const_vector(in_outers) + in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") + right_val = np.prod(out_shape[axis:]) + left_val = np.prod(out_shape[:axis]) + + if ( + len(data[0].shape) == 1 + or right_val == 1 + or (left_val == 1 and axis == len(data[0].shape) - 1) + or (left_val == 1 and right_val == 1) + ): + # badly parallelized case + return te.extern( + [out_shape], + list(data) + [in_outers_tensor, in_cumsum_tensor], + lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]), + dtype=dtype, + name="concatenate_ext", + ) + + inner = get_const_int(int(left_val)) + outer = get_const_int(int(right_val)) + return te.extern( + [out_shape], + list(data) + [in_outers_tensor, in_cumsum_tensor], + lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer), + dtype=dtype, + name="concatenate_ext", + ) diff --git a/python/tvm/topi/x86/injective.py b/python/tvm/topi/x86/injective.py index 6492b78d6037..78893397ba31 100644 --- a/python/tvm/topi/x86/injective.py +++ b/python/tvm/topi/x86/injective.py @@ -17,20 +17,22 @@ # pylint: disable=invalid-name """x86 declaration and schedules.""" from tvm import te +from tvm.topi import tag from tvm.tir import IntImm +from tvm.topi.generic.injective import ( + schedule_injective_from_existing as schedule_injective_for_concat, +) from ..utils import is_empty_shape def schedule_injective_from_existing(sch, out): """Schedule for injective op from existing schedule. - Parameters ---------- sch: Schedule The schedule to update. out: Tensor The tensor representing the injective op. - Returns ------- sch: Schedule @@ -61,13 +63,11 @@ def schedule_injective_from_existing(sch, out): def schedule_injective(outs): """X86 schedule for injective op.
- Parameters ---------- outs: Array of Tensor The computation graph description of injective in the format of an array of tensors. - Returns ------- sch: Schedule @@ -85,13 +85,11 @@ def schedule_injective(outs): def schedule_concatenate(outs): """X86 schedule for concatenate op. - Parameters ---------- outs: Array of Tensor The computation graph description of injective in the format of an array of tensors. - Returns ------- sch: Schedule @@ -132,5 +130,37 @@ def vectorize(sch, tensor, vectorize_limit): return s + + +def schedule_concatenate_cpu(outs): + """X86 schedule for concatenate op. + Parameters + ---------- + outs: Array of Tensor + The computation graph description in the format + of an array of tensors. + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + if tag.is_injective(op.tag): + schedule_injective_for_concat(s, op.output(0)) + + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + for out in outs: + traverse(out.op) + + return s + + schedule_elemwise = schedule_injective schedule_broadcast = schedule_injective diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e888eccc2b1c..57bf9f36def9 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -346,7 +346,6 @@ RELAY_REGISTER_OP("concatenate") .set_support_level(1) .add_type_rel("Concatenate", ConcatenateRel) .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout) - .set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute) .set_attr<TOpPattern>("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 2b30055c4f42..a8363fd084cd 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -511,6 +511,29 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { std::vector<bool> changed(sch->stages.size(), false); std::vector<Stmt> new_hybrid_body(sch->stages.size()); std::vector<bool> hybrid_changed(sch->stages.size(), false); + // (sshtin): this workaround allows extern ops to be inlined into their consumers. + // Inputs of an extern op must not be inlined, because inlining may happen + // before TE generation for that extern op, which may lead to + // a crash during the lowering or building stages. + // The problem description: + // when operations are fused, inlining the arguments + // prevents the creation of a ProducerNode for the extern operation. + // The operation argument is then expected to be used as an inlined buffer, + // but TIR generation for the extern op can be performed after the inlining procedure, so + // the newly generated TIR has no reference to the input data at all.
+ std::unordered_map<Operation, Operation> ext_ops; + for (size_t i = 0; i < sch->stages.size(); i++) { + Stage stage = sch->stages[i]; + auto ext_op = stage->op.as<ExternOpNode>(); + if (ext_op) { + auto inps = ext_op->InputTensors(); + for (size_t ii = 0; ii < inps.size(); ++ii) { + if (ext_ops.find(inps[ii]->op) == ext_ops.end()) { + ext_ops[inps[ii]->op] = stage->op; + } + } + } + } // inline all the ops for (size_t i = sch->stages.size(); i != 0; --i) { Stage stage = sch->stages[i - 1]; @@ -525,8 +548,13 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { for (auto iv : compute->axis) { args.push_back(iv->var); } + if (ext_ops.find(stage->op) != ext_ops.end()) { + // sshtin: The extern op may try to access the input tensors as raw data, + // which can lead to an error in the IR builder. + stage->attach_type = kGroupRoot; + continue; + } ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output"; - if (feature_extraction_mode && compute->attrs.count("const_matrix")) { // Use constant value to replace access of const matrices. // This produces wrong IR but is good enough for feature extraction purposes. diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 170850809ad5..f4afc9e90562 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -431,6 +431,103 @@ def test_batch_norm(): ) + +def do_concat_test(shapes, t_shape, dtype, axis, dev, target): + varsToConcat = [] + inputData = [] + pos = 0 + for s in shapes: + varsToConcat.append(relay.var("x{}".format(pos), shape=s)) + inputData.append(np.random.rand(*s).astype(dtype)) + pos += 1 + t = relay.var("z", shape=t_shape, dtype=dtype) + z = relay.concatenate(varsToConcat, axis=axis) + z = relay.add(z, t) + params = varsToConcat + params.append(t) + func = relay.Function(params, z) + t_data = np.random.uniform(low=-10, high=10, size=t_shape).astype(dtype) + ref_res = np.concatenate((tuple(inputData)), axis=axis) + t_data + mod = tvm.IRModule.from_expr(func) + + executor = relay.create_executor("graph", mod=mod, device=dev, target=target) + op_res1 = executor.evaluate()(*inputData, t_data) + + tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=0.000001) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + *inputData, t_data + ) + tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=0.000001) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate1(target, dev): + np.random.seed(471) + maxNumDimensions = 6 + shape = [4, 32, 16, 1, 31, 20, 21, 8, 28, 7] # just randomly selected 10 numbers + for dtype in ["float32"]: + for dimsNum in range(1, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(0, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below normalizes the axis index;
for some reason TVM reports an error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + shp = tuple(shape[:dimsNum]) + finalSize += shape[(i % len(shape))] + shapes.append( + shp[:normalizedAxis] + + tuple([shape[(i % len(shape))]]) + + shp[normalizedAxis + 1 :] + ) + t_shape = shp[:normalizedAxis] + tuple([finalSize]) + shp[normalizedAxis + 1 :] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate2(target, dev): + # test to cover cases (1, .. , x, 1, .. , 1) + np.random.seed(13) + maxNumDimensions = 6 + shape = [8, 3, 25, 33, 12, 29, 5, 11, 29, 11] # just randomly selected 10 numbers + ind = 0 + for dtype in ["float32"]: + for dimsNum in range(2, maxNumDimensions): + np.random.shuffle(shape) + for axis in range(-dimsNum + 1, dimsNum): # range should be (-dimsNum + 1, dimsNum) + numToConcat = np.random.uniform(low=2, high=10, size=(1)).astype("int64")[0] + shapes = [] + # the code below normalizes the axis index; for some reason TVM reports an error if the axis is negative + normalizedAxis = axis + if axis < 0: + normalizedAxis += dimsNum + finalSize = 0 + for i in range(0, numToConcat): + axisVal = [1] * dimsNum + axisVal[axis] = shape[(ind % len(shape))] + ind += 1 + finalSize += axisVal[axis] + shapes.append(tuple(axisVal)) + temp = [1] * dimsNum + temp[axis] = finalSize + t_shape = tuple(temp) + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + +@tvm.testing.parametrize_targets("llvm") +def test_concatenate3(target, dev): + np.random.seed(477) + for dtype in ["float32"]: + axis = -2 + ending = 1 + shapes = [[3, 2, 1, ending], [3, 2, 1, ending]] + t_shape = [3, 2, 2, ending] + do_concat_test(shapes, t_shape, dtype, axis, dev, target) + + def test_batch_norm_fold_const(): axis = 1 dtype = "float32" diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index ad054479fd7b..d707e6b4646b 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -22,6 +22,7 @@ import numpy import pytest +import platform import tvm import tvm.relay @@ -418,14 +419,24 @@ def test_export_byoc_c_module(): with tf.extractfile("./metadata.json") as f: metadata = json.load(f) main_md = metadata["memory"]["functions"]["main"] - assert main_md == [ - { - "constants_size_bytes": 0, - "device": 1, - "io_size_bytes": 4800, - "workspace_size_bytes": 800, - } - ] + if platform.architecture()[0] == "64bit": + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 1264, + } + ] + else: + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 1248, + } + ] if __name__ == "__main__":
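Reviewer note (not part of the patch): the snippet below is a minimal sketch of how the new CPU path can be exercised by hand. With fully static input shapes, concatenate_strategy_cpu should pick the concatenate.cpu implementation (topi.x86.concatenate plus schedule_concatenate_cpu) on the llvm target; dynamic shapes fall back to the generic injective path. It only reuses APIs that already appear in the tests above; the shapes and values are hypothetical.

import numpy as np
import tvm
from tvm import relay

# Two static-shape inputs concatenated along axis 0.
x = relay.var("x", shape=(4, 16), dtype="float32")
y = relay.var("y", shape=(8, 16), dtype="float32")
func = relay.Function([x, y], relay.concatenate([x, y], axis=0))
mod = tvm.IRModule.from_expr(func)

x_np = np.random.rand(4, 16).astype("float32")
y_np = np.random.rand(8, 16).astype("float32")

# Graph executor on llvm; the result should match numpy's concatenate.
res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(x_np, y_np)
np.testing.assert_allclose(res.numpy(), np.concatenate((x_np, y_np), axis=0), rtol=1e-6)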