[Unity][Op] Group normalization (#14194)
* [TOPI] Group normalization

As more and more ML models contain the group normalization computation, we
find it beneficial to introduce this op at the TOPI level. Doing so lets us
optimize the group normalization operation as a whole in a more convenient
way.

This PR introduces the group normalization op to TOPI. The group norm
operation was introduced in https://arxiv.org/abs/1803.08494. The
implementation uses tuple reduction, the same approach as the layer norm
implementation, so the corresponding generated TIR function can be optimized
by cross-thread reduction or rfactor through MetaSchedule.

Prior to this PR, the group normalization operations in frontend models were
translated to a series of operations, which makes it inconvenient to optimize
the group norm op as a whole.

With the TOPI implementation of group norm introduced in #14193, we can now
use it to legalize the high-level group norm op and optimize it using
cross-thread reduction or rfactor via MetaSchedule.
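
For reference, a minimal NumPy sketch of the computation described above (illustrative only, not part of this commit; the helper name is hypothetical):

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (C,)
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    out = ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    return out * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)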


Co-authored-by: Bohan Hou <spectrometerh@gmail.com>
MasterJH5574 and spectrometerHBH authored Mar 4, 2023
1 parent 81bf988 commit 12f7cab
Showing 19 changed files with 999 additions and 58 deletions.
21 changes: 21 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
}
}; // struct LayerNormAttrs

/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
int num_groups;
int channel_axis;
Array<Integer> axes;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") {
TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into.");
TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
TVM_ATTR_FIELD(axes).describe(
"The axes that along which the normalization is applied (excluding the channel axis).");
TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
TVM_ATTR_FIELD(center).describe(
"Indicating if the beta offset will be added to the normalized tensor.");
TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
}
}; // struct GroupNormAttrs

/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
151 changes: 151 additions & 0 deletions include/tvm/topi/nn/group_norm.h
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief group normalization op constructions
* \file nn/group_norm.h
*/
#ifndef TVM_TOPI_NN_GROUP_NORM_H_
#define TVM_TOPI_NN_GROUP_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>
#include <vector>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
int num_groups, int channel_axis, const Array<Integer>& axes,
double epsilon, std::string name = "T_group_norm",
std::string tag = kInjective) {
// reshape data C -> G, C/G
int ndim = data->shape.size();
channel_axis = GetRealAxis(ndim, {channel_axis})[0];

auto shape = data->shape;
auto group_size = floordiv(shape[channel_axis], num_groups);
auto new_shape = Array<PrimExpr>();
for (int i = 0; i < ndim; ++i) {
if (i == channel_axis) {
new_shape.push_back(num_groups);
new_shape.push_back(group_size);
} else {
new_shape.push_back(shape[i]);
}
}
auto data_reshaped = reshape(data, new_shape);
// reshape gamma and beta, C -> G, C/G
Tensor gamma_reshaped;
if (gamma.defined()) {
gamma_reshaped = reshape(gamma, {num_groups, group_size});
}
Tensor beta_reshaped;
if (beta.defined()) {
beta_reshaped = reshape(beta, {num_groups, group_size});
}

// get the new axes to normalize after reshape
std::vector<int> new_axes{channel_axis + 1};
for (auto axis : axes) {
int new_axis = GetRealAxis(ndim, {axis})[0];
if (new_axis < channel_axis) {
new_axes.push_back(new_axis);
} else if (new_axis > channel_axis) {
new_axes.push_back(new_axis + 1);
} else {
ICHECK(false) << "axes can not contain channel axis";
}
}
std::sort(new_axes.begin(), new_axes.end());

// sum x and x^2
ndim = data_reshaped->shape.size();
auto reduce_axes = MakeReduceAxes(new_axes, data_reshaped);
auto target_shape =
MakeReduceTargetShape(new_axes, data_reshaped, /*keepdims=*/false, /*atleast1d=*/true);
auto func = MakeTupleSumReducer();

auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](const Array<Var>& indices) {
Array<PrimExpr> eval_range;
int arg_counter = 0;
int red_counter = 0;

for (int i = 0; i < ndim; ++i) {
if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
// new_axes contains i
eval_range.push_back(reduce_axes[red_counter]);
red_counter++;
} else {
eval_range.push_back(indices[arg_counter]);
arg_counter++;
}
}
auto square = [](const PrimExpr& x) { return x * x; };
return func({data_reshaped(eval_range), square(data_reshaped(eval_range))}, reduce_axes,
nullptr);
};

auto temp_x_x2 =
tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);

auto temp_x = temp_x_x2[0];
auto temp_x2 = temp_x_x2[1];
auto reduce_extent = make_const(data->dtype, 1);
for (auto axis : new_axes) {
reduce_extent *= data_reshaped->shape[axis];
}
auto group_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices, gamma_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
reduce_indices.push_back(indices[i]);
} else {
non_reduce_indices.push_back(indices[i]);
}
}
gamma_indices = {indices[channel_axis], indices[channel_axis + 1]};
auto mean = temp_x(non_reduce_indices) / reduce_extent;
auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
auto group_norm =
(data_reshaped(indices) - mean) * tvm::rsqrt(var + make_const(data->dtype, epsilon));
if (gamma.defined()) {
group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices));
}
if (beta.defined()) {
group_norm = topi::add(group_norm, beta_reshaped(gamma_indices));
}
return group_norm;
};
auto group_norm_out = tvm::te::compute(data_reshaped->shape, group_norm_func, name, tag);
auto group_norm_out_reshaped = reshape(group_norm_out, shape);
return group_norm_out_reshaped;
}

} // namespace nn
} // namespace topi
} // namespace tvm

#endif // TVM_TOPI_NN_GROUP_NORM_H_
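
The compute above accumulates x and x*x in a single tuple reduction and then recovers mean and variance as E[x] and E[x^2] - E[x]^2. A small NumPy sketch of this identity (illustrative, not part of the header):

import numpy as np

x = np.random.rand(4, 8)
sum_x, sum_x2 = x.sum(axis=1), (x * x).sum(axis=1)  # one pass over the reduced axis
extent = x.shape[1]
mean = sum_x / extent
var = sum_x2 / extent - mean * mean  # E[x^2] - E[x]^2
np.testing.assert_allclose(var, x.var(axis=1))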
54 changes: 20 additions & 34 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
)

def _group_norm(self, node: fx.node.Node) -> relax.Var:
-# torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
-# affine=True, device=None, dtype=None)
+import torch # type: ignore
+
x = self.env[node.args[0]]
module = self.named_modules[node.target]
-num_groups = module.num_groups
-num_channels = module.num_channels
-eps = module.eps
-affine = module.affine
-
-shape = self.shape_of(x)
-assert len(shape) == 4
-N, C, H, W = shape[0], shape[1], shape[2], shape[3]
-assert C == num_channels
-assert C % num_groups == 0
-grouped_x = self.block_builder.emit(
-relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
-)
-mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True))
-sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
-square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
-sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True))
-var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value)
-var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
-std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
-norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))
-
-if affine:
-weight = self.params[module.weight]
-bias = self.params[module.bias]
-weight_reshape = self.block_builder.emit(
-relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1))
-)
-bias_reshape = self.block_builder.emit(
-relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
-)
+if module.affine:
+gamma = self.params[module.weight]
+beta = self.params[module.bias]
+else:
+gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type)
+beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type)
+
+dim = len(self.shape_of(x))
+return self.block_builder.emit(
+relax.op.nn.group_norm(
+x,
+gamma,
+beta,
+num_groups=module.num_groups,
+channel_axis=1,
+axes=list(range(2, dim)),
+epsilon=module.eps,
+)
-norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape))
-norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape))
-return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
+)
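
With this change, a torch.nn.GroupNorm submodule traced through torch.fx is mapped to a single relax.op.nn.group_norm call instead of the reshape/mean/subtract sequence emitted before. A usage sketch (assuming the from_fx entry point in tvm.relax.frontend.torch; shapes are illustrative):

import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

model = torch.nn.Sequential(torch.nn.GroupNorm(num_groups=8, num_channels=32))
# Trace with torch.fx and import into Relax; the GroupNorm call_module node
# is handled by the _group_norm translator above.
mod = from_fx(fx.symbolic_trace(model), [((1, 32, 28, 28), "float32")])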

def _embedding(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
58 changes: 58 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -527,6 +527,64 @@ def layer_norm(
return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore


def group_norm(
data: Expr,
gamma: Expr,
beta: Expr,
num_groups: int,
channel_axis: int,
axes: Union[int, List[int]],
epsilon: float = 1e-5,
center: bool = True,
scale: bool = True,
) -> Expr:
r"""
Group normalization (Yuxin Wu et al., 2018).
Applies group normalization to the n-dimensional input array. This operator
first separates the input array into groups along the channel axis and then
applies layer normalization to each group.
Parameters
----------
data : relax.Expr
Input to which group_norm will be applied.
gamma : relax.Expr
The gamma scale factor.
beta : relax.Expr
The beta offset factor.
num_groups : int
Number of groups to separate the channels into.
channel_axis : int
The index of the channel axis in the input data.
axes : Union[int, List[int]]
The axes along which the normalization is applied (excluding the channel axis).
epsilon : float
Small float added to variance to avoid dividing by zero.
center : bool
Indicating if the beta offset will be added to the normalized tensor.
scale : bool
Indicating if the gamma scale will be multiplied.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axes, int):
axes = [axes]
return _ffi_api.group_norm( # type: ignore
data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale
)
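
A usage sketch of the new operator (assuming a build of the Unity/Relax branch; shapes and variable names are illustrative):

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 32, 28, 28), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((32,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((32,), "float32"))
# Split the 32 channels (axis 1) into 8 groups and normalize over the spatial axes.
y = relax.op.nn.group_norm(
    x, gamma, beta, num_groups=8, channel_axis=1, axes=[2, 3], epsilon=1e-5
)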


def dropout(data: Expr, rate: float = 0.5) -> Expr:
"""Applies the dropout operation to the input tensor.
14 changes: 14 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.group_norm")
def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.nn.group_norm,
call.args[0],
call.args[1],
call.args[2],
call.attrs.num_groups,
call.attrs.channel_axis,
call.attrs.axes,
call.attrs.epsilon,
)
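
With this registration, the high-level op is lowered to the TOPI implementation by the standard legalization pass. A sketch (assuming an IRModule named mod that contains relax.nn.group_norm calls):

from tvm import relax

# Lowers relax.nn.group_norm (and other registered ops) to TIR via call_te.
mod = relax.transform.LegalizeOps()(mod)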


@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -39,6 +39,7 @@
from .qnn import *
from .upsampling import *
from .layer_norm import layer_norm
from .group_norm import group_norm
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
52 changes: 52 additions & 0 deletions python/tvm/topi/nn/group_norm.py
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Layer normalization operator."""
from .. import cpp


def group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon=1e-5):
"""Group normalization operator.
Parameters
----------
data : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
gamma: tvm.te.Tensor
1-D with shape (r_0) where r_0 == d_{channel_axis}
beta: tvm.te.Tensor
Optional, 1-D with shape (r_0) where r_0 == d_{channel_axis}
num_groups : int
The number of groups
channel_axis : int
The channel axis
axes : list of int
The axes along which the normalization is applied, excluding the channel axis
epsilon : float
The epsilon value to avoid division by zero.
Returns
-------
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon)
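
A TE-level usage sketch of the wrapper above (shapes are illustrative, not part of this commit):

from tvm import te, topi

data = te.placeholder((2, 32, 8, 8), name="data", dtype="float32")
gamma = te.placeholder((32,), name="gamma", dtype="float32")
beta = te.placeholder((32,), name="beta", dtype="float32")
# 4 groups over channel axis 1; normalize over the spatial axes 2 and 3.
out = topi.nn.group_norm(data, gamma, beta, 4, 1, [2, 3], 1e-5)
# out has the same shape as data and can be scheduled and built like any TE compute.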
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
@@ -44,6 +44,7 @@
from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python
from .roi_pool_python import roi_pool_nchw_python
from .layer_norm_python import layer_norm_python
from .group_norm_python import group_norm_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python