Skip to content

Commit

Permalink
introduce DensePackAttrs to avoid breaking dense op
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Aug 11, 2021
1 parent d676de7 commit d2528b7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
17 changes: 16 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -992,11 +992,26 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
DataType out_dtype;
tvm::String weight_layout;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes for dense_pack operator */
struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
IndexExpr units;
DataType out_dtype;
tvm::String weight_layout;

TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
Expand Down
8 changes: 5 additions & 3 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ RELAY_REGISTER_OP("nn.dense")
// ------------------- relay.nn.dense

// ------------------- relay.nn.contrib_dense_pack
TVM_REGISTER_NODE_TYPE(DensePackAttrs);

// Positional relay function to create dense_pack operator used by frontend FFI.
Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
DataType out_dtype) {
auto attrs = make_object<DenseAttrs>();
auto attrs = make_object<DensePackAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
attrs->weight_layout = std::move(weight_layout);
Expand All @@ -253,7 +255,7 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;

const DenseAttrs* param = attrs.as<DenseAttrs>();
const DensePackAttrs* param = attrs.as<DensePackAttrs>();
ICHECK(param != nullptr);

ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
Expand All @@ -275,7 +277,7 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
auto params = attrs.as<DenseAttrs>();
auto params = attrs.as<DensePackAttrs>();
ICHECK(params);
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
}
Expand Down

0 comments on commit d2528b7

Please sign in to comment.