
[Relay] Dense alter layout fixed for packed input #8669

Merged · 13 commits · Aug 12, 2021
19 changes: 19 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1003,6 +1003,25 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for dense_pack operator */
struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
IndexExpr units;
DataType out_dtype;
tvm::String weight_layout;

TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(weight_layout)
.set_default("NK")
.describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible.");
}
};
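For reference: a layout string such as "NK8n" means the 2-D NK weight is tiled by 8 along N, with the tile as the innermost axis, giving an (N // 8, K, 8) tensor. A minimal NumPy sketch of that packing (illustrative only, not part of this diff):

import numpy as np

def pack_nk_to_nk8n(weight, tile=8):
    """Pack an (N, K) weight into (N // tile, K, tile), i.e. layout NK8n."""
    n, k = weight.shape
    assert n % tile == 0, "N must be divisible by the tile size"
    # Split N into (N // tile, tile) and move the tile to the innermost axis.
    return weight.reshape(n // tile, tile, k).transpose(0, 2, 1)

w = np.arange(48 * 64, dtype="float32").reshape(48, 64)  # layout NK
assert pack_nk_to_nk8n(w).shape == (6, 64, 8)            # layout NK8n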

/*! \brief Attributes for batch matmul operator. */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
DataType out_dtype;
6 changes: 3 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
@@ -1259,9 +1259,9 @@ def dense_shape_func(attrs, inputs, _):
@script
def _dense_pack_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
-for i in const_range(out.shape[0] - 1):
-    out[i] = data_shape[i]
-out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
+assert data_shape.shape[0] == 2, "Input data must be 2D"
+out[0] = data_shape[0]
+out[1] = weight_shape[0] * weight_shape[2]

return out
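A quick worked example of the new shape rule (hypothetical shapes): for 2-D data (batch, units_in) and weight packed as (units // tile, units_in, tile), the output is (batch, units).

data_shape = (1, 32)
weight_shape = (2, 32, 16)  # 32 output units packed with tile 16
out_shape = (data_shape[0], weight_shape[0] * weight_shape[2])
assert out_shape == (1, 32)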

14 changes: 8 additions & 6 deletions python/tvm/relay/op/nn/nn.py
@@ -1548,9 +1548,9 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


-def contrib_dense_pack(data, weight, units=None, out_dtype=""):
+def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""):
"""Dense operator.
-Applies a linear transformation
+Applies a linear transformation with packed weight

.. math::

@@ -1560,25 +1560,27 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""):
----------
data : tvm.relay.Expr
The input data to the operator,
-of shape `(d_1, d_2, ..., d_n, units_in)`.
+of shape `(batch, units_in)`.

weight : tvm.relay.Expr
The transformed weight expressions, 3-D matrix,
of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.

weight_layout: str
The layout of weight, such as "NK" or "NK8n".

units : int, optional
Number of hidden units of the dense transformation.

out_dtype : str, optional
-Specifies the output data type for mixed precision dense,
-of shape `(d_1, d_2, ..., d_n, units)`.
+Specifies the output data type for mixed precision dense.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
-return _make.contrib_dense_pack(data, weight, units, out_dtype)
+return _make.contrib_dense_pack(data, weight, weight_layout, units, out_dtype)
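A minimal usage sketch of the updated API (hypothetical shapes, tile of 8 assumed):

from tvm import relay

data = relay.var("data", shape=(1, 64))
weight = relay.var("weight", shape=(48, 64))
# Pack the NK weight to NK8n, then tell the op how to read the 3-D weight.
packed = relay.layout_transform(weight, "NK", "NK8n")  # shape (6, 64, 8)
out = relay.nn.contrib_dense_pack(data, packed, "NK8n", units=None, out_dtype="float32")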


def fifo_buffer(data, buffer, axis):
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -39,6 +39,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
relay.op.get("nn.dense"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)

if workload:
cfg = dispatch_ctx.query(target, workload)
topi_impl = workload[0]
@@ -62,7 +63,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
topi_impl,
)
dispatch_ctx.update(target, new_workload, cfg)
-weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout)
-return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype)
+return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)
@masahi (Member, Author) commented on Aug 11, 2021:
cc @comaniac @yzhliu Please note the change above. layout_transform ops on the inputs are inserted after this callback completes, and weight_layout is passed so that it is propagated to InferCorrectLayout later.


return None
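To make the comment above concrete, a rough sketch of the new division of labor (hypothetical 16-wide tile; hand-written illustration, not pass output):

# Before: the callback built the weight transform itself.
#     packed = relay.layout_transform(inputs[1], "NK", "NK16n")
#     return relay.nn.contrib_dense_pack(inputs[0], packed, None, out_dtype)
# After: the callback returns the op on the raw 2-D weight,
#     return relay.nn.contrib_dense_pack(inputs[0], inputs[1], "NK16n", None, out_dtype)
# and AlterOpLayout inserts the transform itself, because
# DensePackInferCorrectLayout reports {"NC", "NK16n"} as the desired
# input layouts:
#     %w = layout_transform(%weight, src_layout="NK", dst_layout="NK16n")
#     nn.contrib_dense_pack(%data, %w, weight_layout="NK16n", ...)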
62 changes: 54 additions & 8 deletions src/relay/op/nn/nn.cc
@@ -206,6 +206,13 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
return Call(op, {data, weight}, Attrs(attrs), {});
}

InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs);
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);

RELAY_REGISTER_OP("nn.dense")
@@ -221,35 +228,75 @@ RELAY_REGISTER_OP("nn.dense")
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
.set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
.add_type_rel("Dense", MatmulRel<DenseAttrs>);
// ------------------- relay.nn.dense

// ------------------- relay.nn.contrib_dense_pack
TVM_REGISTER_NODE_TYPE(DensePackAttrs);

// Positional relay function to create dense_pack operator used by frontend FFI.
-Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
-auto attrs = make_object<DenseAttrs>();
+Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
+DataType out_dtype) {
+auto attrs = make_object<DensePackAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
attrs->weight_layout = std::move(weight_layout);
static const Op& op = Op::Get("nn.contrib_dense_pack");
return Call(op, {data, weight}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack);

bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;

const DensePackAttrs* param = attrs.as<DensePackAttrs>();
ICHECK(param != nullptr);

ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed";

Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set(1, weight->shape[0] * weight->shape[2]);

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
auto params = attrs.as<DensePackAttrs>();
ICHECK(params);
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
}

RELAY_REGISTER_OP("nn.contrib_dense_pack")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.

-- **data**: `(x1, x2, ..., xn, input_dim)`
+- **data**: `(batch, input_dim)`
 - **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
-- **out**: `(x1, x2, ..., xn, units)`.
+- **out**: `(batch, units)`.

)code" TVM_ADD_FILELINE)
.set_attrs_type<DenseAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("data", "2D Tensor", "Input data.")
.add_argument("weight", "3D Tensor", "Packed weight matrix.")
.set_support_level(10)
.add_type_rel("DensePack", DensePackRel<DenseAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DensePackInferCorrectLayout)
.add_type_rel("DensePack", DensePackRel);
// ------------------- relay.nn.contrib_dense_pack
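For reference, the semantics that DensePackRel encodes can be written as a short NumPy sketch (assuming the NK8n-style packing described earlier; an illustration, not TVM code):

import numpy as np

def dense_pack_ref(data, packed_w):
    # out[b, no * tile + ni] = sum_k data[b, k] * packed_w[no, k, ni]
    n_outer, k, tile = packed_w.shape
    w = packed_w.transpose(0, 2, 1).reshape(n_outer * tile, k)  # back to NK
    return data @ w.T  # Y = X W^T, shape (batch, n_outer * tile)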

// relay.leaky_relu
@@ -307,7 +354,6 @@ bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

-template <typename T>
InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
@@ -343,7 +389,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<PReluAttrs>();
23 changes: 0 additions & 23 deletions src/relay/op/nn/nn.h
@@ -116,29 +116,6 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

template <typename AttrType>
bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;

const AttrType* param = attrs.as<AttrType>();
ICHECK(param != nullptr);

Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

template <typename AttrType>
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
7 changes: 6 additions & 1 deletion tests/python/contrib/test_arm_compute_lib/test_dense.py
@@ -150,11 +150,16 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False):

if has_bias:
bias_dtype = "int32" if dtype == "uint8" else "float32"
bias_shape = (
[1, weight_shape[0]]
if dtype == "float32" and weight_shape[0] != 1
else [weight_shape[0]]
)
inputs.append(
{
"op": "const",
"name": "",
"attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]},
"attrs": {"shape": [[bias_shape]], "dtype": [[bias_dtype]]},
}
)

48 changes: 47 additions & 1 deletion tests/python/relay/test_pass_alter_op_layout.py
@@ -1315,7 +1315,9 @@ def expected():
weight = relay.var("weight", shape=(48, 64))
target_layout = "NK16n"
weight_transform = relay.layout_transform(weight, "NK", target_layout)
-y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32")
+y = relay.nn.contrib_dense_pack(
+    x, weight_transform, target_layout, units=None, out_dtype="float32"
+)
y = relay.Function(analysis.free_vars(y), y)
return y

@@ -1353,6 +1355,49 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
assert before.body.attrs.layout == "NCHW"


def test_alter_op_dense_packed_data():
def before():
x = relay.var("x", shape=(1, 32, 8, 8))
weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0])
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32)))
return relay.Function(analysis.free_vars(dense), dense)

def expected():
x = relay.var("x", shape=(1, 32, 8, 8))
conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
dense_weight = relay.var("dense_weight", shape=(16, 32))
conv = relay.nn.contrib_conv2d_nchwc(
relay.layout_transform(x, "NCHW", "NCHW8c"),
relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"),
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW8c",
kernel_layout="OIHW8i8o",
out_layout="NCHW8c",
)
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c")
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.contrib_dense_pack(
relay.layout_transform(squeeze, "NC8c", "NC"),
relay.layout_transform(dense_weight, "NK", "NK16n"),
"NK16n",
out_dtype="float32",
)
return relay.Function(analysis.free_vars(dense), dense)

with tvm.target.Target("llvm"):
with TempOpAttr(
"nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
):
a = run_opt_pass(before(), transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(a, b)


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
@@ -1377,3 +1422,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
test_alter_op_dense()
test_alter_layout_strided_slice_axes_nhwc()
test_not_inplace_modify()
test_alter_op_dense_packed_data()