From d2528b78c0f734f04ebc09b9a8e86ff93f32cafe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 11 Aug 2021 16:50:58 +0900 Subject: [PATCH] introduce DensePackAttrs to avoid breaking dense op --- include/tvm/relay/attrs/nn.h | 17 ++++++++++++++++- src/relay/op/nn/nn.cc | 8 +++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 0f9e01ebb5fd0..77cba5fa2ff18 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -992,11 +992,26 @@ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite DataType out_dtype; - tvm::String weight_layout; TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") { TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + +/*! \brief Attributes for dense_pack operator */ +struct DensePackAttrs : public tvm::AttrsNode { + IndexExpr units; + DataType out_dtype; + tvm::String weight_layout; + + TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") { + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index c7c9b3b9e293d..a05e460dc6809 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -233,10 +233,12 @@ RELAY_REGISTER_OP("nn.dense") // ------------------- relay.nn.dense // ------------------- relay.nn.contrib_dense_pack +TVM_REGISTER_NODE_TYPE(DensePackAttrs); + // Positional relay function to create dense_pack operator used by frontend FFI. Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units, DataType out_dtype) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; attrs->weight_layout = std::move(weight_layout); @@ -253,7 +255,7 @@ bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* weight = types[1].as(); if (data == nullptr || weight == nullptr) return false; - const DenseAttrs* param = attrs.as(); + const DensePackAttrs* param = attrs.as(); ICHECK(param != nullptr); ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported"; @@ -275,7 +277,7 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - auto params = attrs.as(); + auto params = attrs.as(); ICHECK(params); return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs); }