[Unity][Op] Group normalization
This PR introduces the high-level group normalization operator.

Prior to this PR, group normalization operations in frontend models were
translated into a long series of primitive operations, which made it
inconvenient to optimize group norm as a single op.

With the TOPI implementation of group norm introduced by apache#14193, we can
now legalize the high-level group norm op to that implementation and optimize
it with cross-thread reduction or rfactor via MetaSchedule.
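
For illustration, a minimal sketch of the intended workflow: construct the high-level op in TVMScript and legalize it to the TOPI kernel (module layout and shapes are illustrative; MetaSchedule tuning of the resulting PrimFunc is not shown).

# Sketch only: assumes the R.nn.group_norm script binding and
# relax.transform.LegalizeOps used in this PR; shapes and group count are arbitrary.
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class GroupNormModule:
    @R.function
    def main(
        x: R.Tensor((1, 32, 64, 64), "float32"),
        gamma: R.Tensor((32,), "float32"),
        beta: R.Tensor((32,), "float32"),
    ) -> R.Tensor((1, 32, 64, 64), "float32"):
        with R.dataflow():
            gv = R.nn.group_norm(x, gamma, beta, num_groups=8, channel_axis=1, axes=[2, 3])
            R.output(gv)
        return gv


# LegalizeOps rewrites the high-level call into a call_tir of the TOPI group_norm
# kernel, which can then be scheduled (e.g. cross-thread reduction / rfactor).
lowered = relax.transform.LegalizeOps()(GroupNormModule)
lowered.show()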

Full implementation credit goes to Bohan.

Co-authored-by: Bohan Hou <spectrometerh@gmail.com>
MasterJH5574 and spectrometerHBH committed Mar 4, 2023
1 parent c8f48d0 commit 74377e6
Showing 11 changed files with 638 additions and 57 deletions.
21 changes: 21 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
}
}; // struct LayerNormAttrs

/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
int num_groups;
int channel_axis;
Array<Integer> axes;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") {
TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into.");
TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
TVM_ATTR_FIELD(axes).describe(
"The axes along which the normalization is applied (excluding the channel axis).");
TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero.");
TVM_ATTR_FIELD(center).describe(
"Indicating if the beta offset will be added to the normalized tensor.");
TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
}
}; // struct GroupNormAttrs

/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
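The attributes node registered above is reflected to Python, so passes and pattern matchers can read these fields directly from a call node. A small sketch, assuming only the APIs added in this PR (shapes are arbitrary):

# Sketch only: build a relax.nn.group_norm call and inspect its GroupNormAttrs.
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 32, 64, 64), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((32,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((32,), "float32"))

call = relax.op.nn.group_norm(x, gamma, beta, num_groups=8, channel_axis=1, axes=[2, 3])
attrs = call.attrs
print(attrs.num_groups, attrs.channel_axis, [int(a) for a in attrs.axes], attrs.epsilon)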
54 changes: 20 additions & 34 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
)

def _group_norm(self, node: fx.node.Node) -> relax.Var:
# torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
# affine=True, device=None, dtype=None)
import torch # type: ignore

x = self.env[node.args[0]]
module = self.named_modules[node.target]
num_groups = module.num_groups
num_channels = module.num_channels
eps = module.eps
affine = module.affine

shape = self.shape_of(x)
assert len(shape) == 4
N, C, H, W = shape[0], shape[1], shape[2], shape[3]
assert C == num_channels
assert C % num_groups == 0
grouped_x = self.block_builder.emit(
relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
)
mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True))
sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True))
var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value)
var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))

if affine:
weight = self.params[module.weight]
bias = self.params[module.bias]
weight_reshape = self.block_builder.emit(
relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1))
)
bias_reshape = self.block_builder.emit(
relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
if module.affine:
gamma = self.params[module.weight]
beta = self.params[module.bias]
else:
gamma = relax.const(torch.ones(module.num_channels), x.checked_type)
beta = relax.const(torch.zeros(module.num_channels), x.checked_type)

dim = len(self.shape_of(x))
return self.block_builder.emit(
relax.op.nn.group_norm(
x,
gamma,
beta,
num_groups=module.num_groups,
channel_axis=1,
axes=list(range(2, dim)),
epsilon=module.eps,
)
norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape))
norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape))
return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
)

def _embedding(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
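With this change, translating a torch.nn.GroupNorm module goes through the single high-level op instead of the reshape/mean/subtract chain. A hedged end-to-end sketch (the module definition and shapes are illustrative):

# Sketch only: trace a GroupNorm module with torch.fx and translate it via from_fx.
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx


class GroupNormModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = torch.nn.GroupNorm(num_groups=3, num_channels=3, eps=1e-5, affine=True)

    def forward(self, x):
        return self.gn(x)


model = GroupNormModel()
graph_model = fx.symbolic_trace(model)
mod = from_fx(graph_model, [((1, 3, 10, 10), "float32")])
mod.show()  # main should now contain a single R.nn.group_norm call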
58 changes: 58 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -527,6 +527,64 @@ def layer_norm(
return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore


def group_norm(
data: Expr,
gamma: Expr,
beta: Expr,
num_groups: int,
channel_axis: int,
axes: Union[int, List[int]],
epsilon: float = 1e-5,
center: bool = True,
scale: bool = True,
) -> Expr:
r"""
Group normalization (Yuxin Wu et al., 2018).
Applies group normalization to the n-dimensional input array.
This operator takes an n-dimensional input array, first separates the input array
into groups along the channel axis, and then applies layer normalization to each group.

Parameters
----------
data : relax.Expr
Input to which group_norm will be applied.
gamma : relax.Expr
The gamma scale factor.
beta : relax.Expr
The beta offset factor.
num_groups : int
Number of groups to separate the channels into.
channel_axis : int
The index of the channel axis in the input data.
axes : Union[int, List[int]]
The axes along which the normalization is applied (excluding the channel axis).
epsilon : float
Small float added to variance to avoid dividing by zero.
center : bool
Indicating if the beta offset will be added to the normalized tensor.
scale : bool
Indicating if the gamma scale will be multiplied.

Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axes, int):
axes = [axes]
return _ffi_api.group_norm( # type: ignore
data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale
)


def dropout(data: Expr, rate: float = 0.5) -> Expr:
"""Applies the dropout operation to the input tensor.
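As a ground-truth reference for the computation described in the docstring above, a NumPy sketch of group normalization for the common NCHW case (channel_axis=1, axes=[2, 3]); this helper is not part of the PR:

# Sketch only: NumPy reference implementation, e.g. for numerical tests of the op.
import numpy as np


def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    n, c, h, w = x.shape
    assert c % num_groups == 0
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # scale/center correspond to multiplying by gamma and adding beta.
    return normed * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)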
14 changes: 14 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.group_norm")
def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.nn.group_norm,
call.args[0],
call.args[1],
call.args[2],
call.attrs.num_groups,
call.attrs.channel_axis,
call.attrs.axes,
call.attrs.epsilon,
)


@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
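The call_te above forwards the attributes positionally to topi.nn.group_norm from apache#14193. The generated TIR can also be inspected directly; a sketch assuming the same positional signature as the call_te arguments:

# Sketch only: instantiate the TOPI compute and print the resulting PrimFunc.
import tvm
from tvm import te, topi

data = te.placeholder((1, 32, 64, 64), name="data", dtype="float32")
gamma = te.placeholder((32,), name="gamma", dtype="float32")
beta = te.placeholder((32,), name="beta", dtype="float32")

out = topi.nn.group_norm(data, gamma, beta, 8, 1, [2, 3], 1e-5)
print(te.create_prim_func([data, gamma, beta, out]).script())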
83 changes: 83 additions & 0 deletions src/relax/op/nn/nn.cc
@@ -233,6 +233,89 @@ TVM_REGISTER_OP("relax.nn.layer_norm")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);

/* relax.nn.group_norm */
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);

Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale) {
ObjectPtr<GroupNormAttrs> attrs = make_object<GroupNormAttrs>();
attrs->num_groups = num_groups;
attrs->channel_axis = channel_axis;
attrs->axes = std::move(axes);
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;

static const Op& op = Op::Get("relax.nn.group_norm");
return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm);

StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) {
Op op = Downcast<Op>(call->op);
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<GroupNormAttrs>();

TensorStructInfo data_sinfo = input_sinfo[0];
int channel_axis = -1;
if (!data_sinfo->IsUnknownNdim()) {
channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis);
std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
// channel_axis must not be in axes.
if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op
<< " expects that channel_axis must not be in axes, but got channel_axis: "
<< channel_axis << ", axes: " << attrs->axes);
}
}
if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that data must be float, but got " << data_sinfo->dtype);
}
arith::Analyzer* analyzer = ctx->GetAnalyzer();
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape != nullptr && channel_axis != -1 &&
analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that the size of channel_axis must be divisible by "
<< attrs->num_groups << ", but got " << data_shape->values[channel_axis]);
}
for (int i = 1; i < static_cast<int>(op->arguments.size()); ++i) {
if (input_sinfo[i]->dtype != data_sinfo->dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that all inputs must have the same dtype, but got "
<< input_sinfo[i]->dtype << " and " << data_sinfo->dtype);
} else if (input_sinfo[i]->ndim != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that all inputs must have ndim=1, but got "
<< input_sinfo[i]->ndim);
} else if (channel_axis != -1) {
const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>();
if (shape != nullptr && data_shape != nullptr) {
PrimExpr channel_size = data_shape->values[channel_axis];
PrimExpr input_size = shape->values[0];
if (analyzer->CanProve(channel_size != input_size)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that the size of input " << i
<< " must be equal to the size of channel_axis, but got " << input_size
<< " and " << channel_size);
}
}
}
}
return data_sinfo;
}

TVM_REGISTER_OP("relax.nn.group_norm")
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm);

/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);

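The struct-info checks above can be exercised from Python through the BlockBuilder. For example, a channel count that is not divisible by num_groups should be rejected; the error type below is an assumption based on how the other norm-op tests check diagnostics:

# Sketch only: normalization triggers InferStructInfoGroupNorm, which reports the error.
import pytest
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 30, 8, 8), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((30,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((30,), "float32"))

bb = relax.BlockBuilder()
with pytest.raises(tvm.TVMError):
    # 30 channels cannot be split into 4 groups.
    bb.normalize(
        relax.op.nn.group_norm(x, gamma, beta, num_groups=4, channel_axis=1, axes=[2, 3])
    )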
4 changes: 4 additions & 0 deletions src/relax/op/nn/nn.h
@@ -68,6 +68,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_
Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
bool scale);

/*! \brief Compute group normalization. */
Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale);

/*!
* \brief Applies the dropout operation to the input tensor.
* \param data The input data to the operator.
4 changes: 2 additions & 2 deletions tests/python/relax/test_ast_printer.py
@@ -362,7 +362,7 @@ def f(
y: R.Tensor(("m",), "float32"),
r: R.Tensor(dtype="int64"),
) -> R.Object:
m = T.var("int64")
m = T.int64()
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
w: R.Tensor = R.multiply(z, z)
q: R.Tensor(ndim=2) = R.add(w, w)
@@ -431,7 +431,7 @@ def test_call_tir():
# also from test_parser
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")):
m, n = T.var("int64"), T.var("int64")
m, n = T.int64(), T.int64()
gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
return gv0

32 changes: 11 additions & 21 deletions tests/python/relax/test_frontend_from_fx.py
@@ -708,29 +708,19 @@ def main(
w1: R.Tensor((3,), dtype="float32"),
w2: R.Tensor((3,), dtype="float32"),
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape(
input_1, (1, 3, 1, 10, 10)
)
lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean(
lv, axis=[2, 3, 4], keepdims=True
)
lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1)
lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2)
lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum(
lv3, axis=[2, 3, 4], keepdims=True
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm(
input_1,
w1,
w2,
num_groups=3,
channel_axis=1,
axes=[2, 3],
epsilon=1.0000000000000001e-05,
center=True,
scale=True,
)
lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0))
lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05))
lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6)
lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7)
lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1))
lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1))
lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9)
lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10)
lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10))
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

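Beyond the structural comparison above, the translated module can be checked numerically against PyTorch. A hedged sketch that reuses model and mod from the earlier from_fx sketch and assumes the translator embeds the GroupNorm weights as constants, so main takes only the input tensor:

# Sketch only: legalize, build with the Relax VM, and compare against torch.nn.GroupNorm.
import numpy as np
import torch
import tvm
from tvm import relax

x_np = np.random.randn(1, 3, 10, 10).astype("float32")
expected = model(torch.from_numpy(x_np)).detach().numpy()

ex = relax.build(relax.transform.LegalizeOps()(mod), target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
got = vm["main"](tvm.nd.array(x_np)).numpy()
np.testing.assert_allclose(got, expected, rtol=1e-5, atol=1e-5)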