[Relay] Dense alter layout fixed for packed input (apache#8669)
* clean up typerel

* add layout transform when input is 3D

* add test

* update doc to clarify that only 2D input data is supported

* add weight_layout attribute in dense

* remove explicit layout transform from dense_alter_op.py

* Add DensePackInferCorrectLayout to insert layout transform

* relax type rel

* revert type rel relax and add check on dim

* introduce DensePackAttrs to avoid breaking dense op

* try fixing arm compute lib test

* Update tests/python/contrib/test_arm_compute_lib/test_dense.py

Co-authored-by: lhutton1 <35535092+lhutton1@users.noreply.github.com>

* formatting

Co-authored-by: lhutton1 <35535092+lhutton1@users.noreply.github.com>
2 people authored and ylc committed Jan 13, 2022
1 parent 7402c89 commit be74cbb
Showing 8 changed files with 139 additions and 44 deletions.
19 changes: 19 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1003,6 +1003,25 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for dense_pack operator */
struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
IndexExpr units;
DataType out_dtype;
tvm::String weight_layout;

TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(weight_layout)
.set_default("NK")
.describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible.");
}
};
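
The layout string follows TVM's usual convention: upper-case primal axes (N = output units, K = reduction) plus a lower-case sub-axis carrying its packing factor. A minimal sketch of what "NK8n" does to an unpacked NK weight, using TVM's layout utilities (shapes illustrative, not from this patch):

import tvm

# "NK8n" splits the N axis by a factor of 8:
# an NK weight of shape (16, 32) packs to (16 // 8, 32, 8).
bl = tvm.tir.bijective_layout("NK", "NK8n")
print(bl.forward_shape((16, 32)))  # [2, 32, 8]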

/*! \brief Attributes for batch matmul operator. */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
DataType out_dtype;
6 changes: 3 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
@@ -1259,9 +1259,9 @@ def dense_shape_func(attrs, inputs, _):
@script
def _dense_pack_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
assert data_shape.shape[0] == 2, "Input data must be 2D"
out[0] = data_shape[0]
out[1] = weight_shape[0] * weight_shape[2]

return out
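
A quick check of the shape rule above, assuming a weight pre-packed as NK16n (numbers illustrative):

# data: (batch, units_in); weight: (units // 16, units_in, 16)
data_shape = (8, 64)
weight_shape = (3, 64, 16)
out_shape = (data_shape[0], weight_shape[0] * weight_shape[2])
assert out_shape == (8, 48)  # units = 3 * 16 = 48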

14 changes: 8 additions & 6 deletions python/tvm/relay/op/nn/nn.py
@@ -1548,9 +1548,9 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


def contrib_dense_pack(data, weight, units=None, out_dtype=""):
def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
Applies a linear transformation with packed weight
.. math::
@@ -1560,25 +1560,27 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""):
----------
data : tvm.relay.Expr
The input data to the operator,
of shape `(d_1, d_2, ..., d_n, units_in)`.
of shape `(batch, units_in)`.
weight : tvm.relay.Expr
The transformed weight expressions, 3-D matrix,
of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
weight_layout: str
The layout of weight, such as "NK" or "NK8n".
units : int, optional
Number of hidden units of the dense transformation.
out_dtype : str, optional
Specifies the output data type for mixed precision dense,
of shape `(d_1, d_2, ..., d_n, units)`.
Specifies the output data type for mixed precision dense.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_dense_pack(data, weight, units, out_dtype)
return _make.contrib_dense_pack(data, weight, weight_layout, units, out_dtype)
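
A hedged end-to-end sketch of the updated signature: the weight is passed in already packed, and the layout string describes that packing (names and shapes illustrative):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 32), dtype="float32")
# NK8n-packed weight for units=16, units_in=32: (16 // 8, 32, 8)
weight = relay.var("weight", shape=(2, 32, 8), dtype="float32")
y = relay.nn.contrib_dense_pack(data, weight, "NK8n", out_dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], y))
print(relay.transform.InferType()(mod))  # output type: Tensor[(1, 16), float32]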


def fifo_buffer(data, buffer, axis):
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -63,6 +63,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
relay.op.get("nn.dense"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)

if workload:
cfg = dispatch_ctx.query(target, workload)
topi_impl = workload[0]
@@ -86,7 +87,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
topi_impl,
)
dispatch_ctx.update(target, new_workload, cfg)
weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout)
return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype)
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)

return None
62 changes: 54 additions & 8 deletions src/relay/op/nn/nn.cc
@@ -206,6 +206,13 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
return Call(op, {data, weight}, Attrs(attrs), {});
}

InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs);
}
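
This pins nn.dense to plain NC data and an NK weight, so when a producer leaves its result in a packed layout (e.g. NC8c after an NCHW8c conv followed by squeeze), AlterOpLayout inserts a layout_transform back to NC instead of feeding packed data into dense; test_alter_op_dense_packed_data below exercises exactly this path.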

TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);

RELAY_REGISTER_OP("nn.dense")
@@ -221,35 +228,75 @@ RELAY_REGISTER_OP("nn.dense")
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
.set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
.add_type_rel("Dense", MatmulRel<DenseAttrs>);
// ------------------- relay.nn.dense

// ------------------- relay.nn.contrib_dense_pack
TVM_REGISTER_NODE_TYPE(DensePackAttrs);

// Positional relay function to create dense_pack operator used by frontend FFI.
Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
auto attrs = make_object<DenseAttrs>();
Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
DataType out_dtype) {
auto attrs = make_object<DensePackAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
attrs->weight_layout = std::move(weight_layout);
static const Op& op = Op::Get("nn.contrib_dense_pack");
return Call(op, {data, weight}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack);

bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;

const DensePackAttrs* param = attrs.as<DensePackAttrs>();
ICHECK(param != nullptr);

ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed";

Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set(1, weight->shape[0] * weight->shape[2]);

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
auto params = attrs.as<DensePackAttrs>();
ICHECK(params);
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
}
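
With the desired weight layout advertised here, AlterOpLayout now materializes the NK → NK16n transform that dense_alter_op.py previously built by hand. A minimal driver mirroring the existing test_alter_op_dense (target and shapes as in that test; the comment states what the test expects):

import tvm
from tvm import relay, topi
from tvm.relay import transform
from tvm.relay.testing.temp_op_attr import TempOpAttr

x = relay.var("x", shape=(16, 64))
w = relay.var("w", shape=(48, 64))
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))
with tvm.target.Target("llvm"):
    with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout):
        mod = transform.AlterOpLayout()(transform.InferType()(mod))
print(mod)  # expect layout_transform(%w, "NK", "NK16n") feeding nn.contrib_dense_pack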

RELAY_REGISTER_OP("nn.contrib_dense_pack")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **data**: `(batch, input_dim)`
- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
- **out**: `(x1, x2, ..., xn, units)`.
- **out**: `(batch, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<DenseAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("data", "2D Tensor", "Input data.")
.add_argument("weight", "3D Tensor", "Packed weight matrix.")
.set_support_level(10)
.add_type_rel("DensePack", DensePackRel<DenseAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DensePackInferCorrectLayout)
.add_type_rel("DensePack", DensePackRel);
// ------------------- relay.nn.contrib_dense_pack

// relay.leaky_relu
@@ -307,7 +354,6 @@ bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

template <typename T>
InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
@@ -343,7 +389,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<PReluAttrs>();
23 changes: 0 additions & 23 deletions src/relay/op/nn/nn.h
@@ -116,29 +116,6 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

template <typename AttrType>
bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;

const AttrType* param = attrs.as<AttrType>();
ICHECK(param != nullptr);

Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

template <typename AttrType>
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
7 changes: 6 additions & 1 deletion tests/python/contrib/test_arm_compute_lib/test_dense.py
@@ -150,11 +150,16 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False):

if has_bias:
bias_dtype = "int32" if dtype == "uint8" else "float32"
bias_shape = (
[1, weight_shape[0]]
if dtype == "float32" and weight_shape[0] != 1
else [weight_shape[0]]
)
inputs.append(
{
"op": "const",
"name": "",
"attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]},
"attrs": {"shape": [[bias_shape]], "dtype": [[bias_dtype]]},
}
)

48 changes: 47 additions & 1 deletion tests/python/relay/test_pass_alter_op_layout.py
@@ -1315,7 +1315,9 @@ def expected():
weight = relay.var("weight", shape=(48, 64))
target_layout = "NK16n"
weight_transform = relay.layout_transform(weight, "NK", target_layout)
y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32")
y = relay.nn.contrib_dense_pack(
x, weight_transform, target_layout, units=None, out_dtype="float32"
)
y = relay.Function(analysis.free_vars(y), y)
return y

@@ -1353,6 +1355,49 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
assert before.body.attrs.layout == "NCHW"


def test_alter_op_dense_packed_data():
def before():
x = relay.var("x", shape=(1, 32, 8, 8))
weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0])
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32)))
return relay.Function(analysis.free_vars(dense), dense)

def expected():
x = relay.var("x", shape=(1, 32, 8, 8))
conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
dense_weight = relay.var("dense_weight", shape=(16, 32))
conv = relay.nn.contrib_conv2d_nchwc(
relay.layout_transform(x, "NCHW", "NCHW8c"),
relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"),
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW8c",
kernel_layout="OIHW8i8o",
out_layout="NCHW8c",
)
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c")
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.contrib_dense_pack(
relay.layout_transform(squeeze, "NC8c", "NC"),
relay.layout_transform(dense_weight, "NK", "NK16n"),
"NK16n",
out_dtype="float32",
)
return relay.Function(analysis.free_vars(dense), dense)

with tvm.target.Target("llvm"):
with TempOpAttr(
"nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
):
a = run_opt_pass(before(), transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(a, b)


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
@@ -1377,3 +1422,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
test_alter_op_dense()
test_alter_layout_strided_slice_axes_nhwc()
test_not_inplace_modify()
test_alter_op_dense_packed_data()
