diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 694001f612e7..77cba5fa2ff1 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1003,6 +1003,25 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
+/*! \brief Attributes for dense_pack operator */
+struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
+  IndexExpr units;
+  DataType out_dtype;
+  tvm::String weight_layout;
+
+  TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") {
+    TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
+
+    // use 0 bits to indicate none.
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+    TVM_ATTR_FIELD(weight_layout)
+        .set_default("NK")
+        .describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible.");
+  }
+};
+
 /*! \brief Attributes for batch matmul operator. */
 struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
   DataType out_dtype;
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 96cef8bc3588..a9ccc5aa2d24 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1259,9 +1259,9 @@ def dense_shape_func(attrs, inputs, _):
 @script
 def _dense_pack_shape_func(data_shape, weight_shape):
     out = output_tensor((data_shape.shape[0],), "int64")
-    for i in const_range(out.shape[0] - 1):
-        out[i] = data_shape[i]
-    out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
+    assert data_shape.shape[0] == 2, "Input data must be 2D"
+    out[0] = data_shape[0]
+    out[1] = weight_shape[0] * weight_shape[2]
     return out
 
 
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 64b397a4d4f9..e882bcf7e271 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1548,9 +1548,9 @@ def dense(data, weight, units=None, out_dtype=""):
     return _make.dense(data, weight, units, out_dtype)
 
 
-def contrib_dense_pack(data, weight, units=None, out_dtype=""):
+def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""):
     """Dense operator.
-    Applies a linear transformation
+    Applies a linear transformation with packed weight
 
     .. math::
 
@@ -1560,25 +1560,27 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""):
     ----------
     data : tvm.relay.Expr
         The input data to the operator,
-        of shape `(d_1, d_2, ..., d_n, units_in)`.
+        of shape `(batch, units_in)`.
 
     weight : tvm.relay.Expr
         The transformed weight expressions, 3-D matrix,
         of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
 
+    weight_layout: str
+        The layout of weight, such as "NK" or "NK8n".
+
     units : int, optional
         Number of hidden units of the dense transformation.
 
     out_dtype : str, optional
-        Specifies the output data type for mixed precision dense,
-        of shape `(d_1, d_2, ..., d_n, units)`.
+        Specifies the output data type for mixed precision dense.
 
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.contrib_dense_pack(data, weight, units, out_dtype)
+    return _make.contrib_dense_pack(data, weight, weight_layout, units, out_dtype)
 
 
 def fifo_buffer(data, buffer, axis):
diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py
index 5684dbc4eaeb..b839363bde43 100644
--- a/python/tvm/topi/x86/dense_alter_op.py
+++ b/python/tvm/topi/x86/dense_alter_op.py
@@ -63,6 +63,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
         relay.op.get("nn.dense"), attrs, tinfos, out_type, target
     )
     workload = autotvm.task.get_workload(outs)
+
     if workload:
         cfg = dispatch_ctx.query(target, workload)
         topi_impl = workload[0]
@@ -86,7 +87,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
             topi_impl,
         )
         dispatch_ctx.update(target, new_workload, cfg)
-        weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout)
-        return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype)
+        return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)
 
     return None
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index a96f167df2bb..a05e460dc680 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -206,6 +206,13 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
   return Call(op, {data, weight}, Attrs(attrs), {});
 }
 
+InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<tvm::relay::Type>& old_in_types) {
+  return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs);
+}
+
 TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);
 
 RELAY_REGISTER_OP("nn.dense")
@@ -221,35 +228,75 @@ RELAY_REGISTER_OP("nn.dense")
     .add_argument("data", "nD Tensor", "Input data.")
     .add_argument("weight", "2D Tensor", "Weight matrix.")
     .set_support_level(1)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
     .add_type_rel("Dense", MatmulRel<DenseAttrs>);
 // ------------------- relay.nn.dense
 
 // ------------------- relay.nn.contrib_dense_pack
+TVM_REGISTER_NODE_TYPE(DensePackAttrs);
+
 // Positional relay function to create dense_pack operator used by frontend FFI.
-Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
-  auto attrs = make_object<DenseAttrs>();
+Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
+                   DataType out_dtype) {
+  auto attrs = make_object<DensePackAttrs>();
   attrs->units = units;
   attrs->out_dtype = out_dtype;
+  attrs->weight_layout = std::move(weight_layout);
   static const Op& op = Op::Get("nn.contrib_dense_pack");
   return Call(op, {data, weight}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack);
 
+bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr || weight == nullptr) return false;
+
+  const DensePackAttrs* param = attrs.as<DensePackAttrs>();
+  ICHECK(param != nullptr);
+
+  ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
+  ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed";
+
+  Array<IndexExpr> oshape = data->shape;
+  oshape.Set(1, weight->shape[0] * weight->shape[2]);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
+                                                     const Array<Layout>& new_in_layouts,
+                                                     const Array<Layout>& old_in_layouts,
+                                                     const Array<tvm::relay::Type>& old_in_types) {
+  auto params = attrs.as<DensePackAttrs>();
+  ICHECK(params);
+  return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
+}
+
 RELAY_REGISTER_OP("nn.contrib_dense_pack")
     .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
 
-- **data**: `(x1, x2, ..., xn, input_dim)`
+- **data**: `(batch, input_dim)`
 - **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
-- **out**: `(x1, x2, ..., xn, units)`.
+- **out**: `(batch, units)`.
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<DensePackAttrs>()
     .set_num_inputs(2)
-    .add_argument("data", "nD Tensor", "Input data.")
+    .add_argument("data", "2D Tensor", "Input data.")
     .add_argument("weight", "3D Tensor", "Packed weight matrix.")
     .set_support_level(10)
-    .add_type_rel("DensePack", DensePackRel<DenseAttrs>);
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DensePackInferCorrectLayout)
+    .add_type_rel("DensePack", DensePackRel);
 // ------------------- relay.nn.contrib_dense_pack
 
 // relay.leaky_relu
@@ -307,7 +354,6 @@ bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
-template <typename T>
 InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
@@ -343,7 +389,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
     .add_argument("alpha", "Tensor", "Input channelwise alpha.")
     .set_support_level(3)
     .add_type_rel("PRelu", PReluRel)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout)
     .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
                                              const Type& out_type) {
       const auto* param = attrs.as<PReluAttrs>();
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 3dc63b31a205..6bc21473af18 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -116,29 +116,6 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
-template <typename AttrType>
-bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                  const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr || weight == nullptr) return false;
-
-  const AttrType* param = attrs.as<AttrType>();
-  ICHECK(param != nullptr);
-
-  Array<IndexExpr> oshape = data->shape;
-  oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  // assign output type
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
-  return true;
-}
-
 template <typename AttrType>
 bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py
index e6620a4bc1cb..6bdff0fdb857 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_dense.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_dense.py
@@ -150,11 +150,16 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False):
 
     if has_bias:
         bias_dtype = "int32" if dtype == "uint8" else "float32"
+        bias_shape = (
+            [1, weight_shape[0]]
+            if dtype == "float32" and weight_shape[0] != 1
+            else [weight_shape[0]]
+        )
         inputs.append(
             {
                 "op": "const",
                 "name": "",
-                "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]},
+                "attrs": {"shape": [[bias_shape]], "dtype": [[bias_dtype]]},
             }
         )
 
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 3cee84f6a4dd..1607b7df4a1e 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1315,7 +1315,9 @@ def expected():
         weight = relay.var("weight", shape=(48, 64))
         target_layout = "NK16n"
         weight_transform = relay.layout_transform(weight, "NK", target_layout)
-        y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32")
+        y = relay.nn.contrib_dense_pack(
+            x, weight_transform, target_layout, units=None, out_dtype="float32"
+        )
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
@@ -1353,6 +1355,49 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
     assert before.body.attrs.layout == "NCHW"
 
 
+def test_alter_op_dense_packed_data():
+    def before():
+        x = relay.var("x", shape=(1, 32, 8, 8))
+        weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
+        conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0])
+        squeeze = relay.squeeze(pool, axis=[2, 3])
+        dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32)))
+        return relay.Function(analysis.free_vars(dense), dense)
+
+    def expected():
+        x = relay.var("x", shape=(1, 32, 8, 8))
+        conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
+        dense_weight = relay.var("dense_weight", shape=(16, 32))
+        conv = relay.nn.contrib_conv2d_nchwc(
+            relay.layout_transform(x, "NCHW", "NCHW8c"),
+            relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"),
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW8c",
+            kernel_layout="OIHW8i8o",
+            out_layout="NCHW8c",
+        )
+        pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c")
+        squeeze = relay.squeeze(pool, axis=[2, 3])
+        dense = relay.nn.contrib_dense_pack(
+            relay.layout_transform(squeeze, "NC8c", "NC"),
+            relay.layout_transform(dense_weight, "NK", "NK16n"),
+            "NK16n",
+            out_dtype="float32",
+        )
+        return relay.Function(analysis.free_vars(dense), dense)
+
+    with tvm.target.Target("llvm"):
+        with TempOpAttr(
+            "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
+        ):
+            a = run_opt_pass(before(), transform.AlterOpLayout())
+            b = run_opt_pass(expected(), transform.InferType())
+            assert tvm.ir.structural_equal(a, b)
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -1377,3 +1422,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
     test_alter_op_dense()
     test_alter_layout_strided_slice_axes_nhwc()
     test_not_inplace_modify()
+    test_alter_op_dense_packed_data()