[TOPI] Batch Norm Training Mode (#14190)
Prior to this PR, TOPI batch_norm supported only inference.

This PR adds a training: bool flag and a momentum: float argument to support training mode (moving_mean and moving_var are updated and returned), which aligns with torch.nn.functional.batch_norm.
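
For orientation, the training-mode update can be summarized in a few lines of NumPy (a simplified sketch with scale/center omitted; the helper name batch_norm_train is hypothetical, not part of this PR):

import numpy as np

def batch_norm_train(x, moving_mean, moving_var, axis=1, epsilon=1e-5, momentum=0.1):
    # Normalize with the batch statistics rather than the running ones.
    reduce_axes = tuple(ax for ax in range(x.ndim) if ax != axis)
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    data_mean = x.mean(axis=reduce_axes)
    data_var = x.var(axis=reduce_axes)
    out = (x - data_mean.reshape(shape)) / np.sqrt(data_var.reshape(shape) + epsilon)
    # Blend the batch statistics into the running ones with the given momentum,
    # as torch.nn.functional.batch_norm does.
    new_mean = (1 - momentum) * moving_mean + momentum * data_mean
    new_var = (1 - momentum) * moving_var + momentum * data_var
    return out, new_mean, new_var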
SiriusNEO authored Mar 4, 2023
1 parent 736ceca commit 22c47ee
Showing 3 changed files with 104 additions and 21 deletions.
46 changes: 42 additions & 4 deletions python/tvm/topi/nn/batch_norm.py
@@ -16,6 +16,7 @@
 # under the License.
 """Batch normalization."""
 import typing
+from functools import reduce

 from tvm import te
 from tvm import topi
@@ -31,6 +32,8 @@ def batch_norm(
     epsilon: typing.Optional[float] = None,
     center: typing.Optional[bool] = None,
     scale: typing.Optional[bool] = None,
+    training: typing.Optional[bool] = None,
+    momentum: typing.Optional[float] = None,
 ) -> typing.List[te.Tensor]:
     """Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -69,6 +72,13 @@
         If True, scale normalized tensor by gamma. If False, gamma
         is ignored.

+    training : bool, optional, default=False
+        Whether it is in training mode. If True, update
+        moving_mean and moving_var.
+
+    momentum : float, optional, default=0.1
+        The value used for the moving_mean and moving_var update.
+
     Returns
     -------
     output : list of tvm.te.Tensor
@@ -92,19 +102,47 @@
     if scale is None:
         scale = True

+    if training is None:
+        training = False
+
+    if momentum is None:
+        momentum = 0.1
+
     shape = [1] * len(data.shape)
     shape[axis] = data.shape[axis]

-    moving_mean_rs = topi.reshape(moving_mean, shape)
-    moving_var_rs = topi.reshape(moving_var, shape)
-
-    out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
+    if training:
+        reduce_axes = list(range(len(data.shape)))
+        reduce_axes.remove(axis)
+        shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
+        data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+        data_mean_rs = topi.reshape(data_mean, shape)
+        data_var = (
+            topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
+        )
+        data_var_rs = topi.reshape(data_var, shape)
+        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+    else:
+        moving_mean_rs = topi.reshape(moving_mean, shape)
+        moving_var_rs = topi.reshape(moving_var, shape)
+        out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)

     if scale:
         out = out * topi.reshape(gamma, shape)
     if center:
         out = out + topi.reshape(beta, shape)

+    if training:
+        assert 0 <= momentum <= 1, "the valid momentum range is [0, 1]."
+        data_var = (
+            topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
+        )
+        return [
+            out,
+            (1 - momentum) * moving_mean + momentum * data_mean,
+            (1 - momentum) * moving_var + momentum * data_var,
+        ]
+
     # Moving mean and var aren't updated during test. To avoid
     # placeholder reuse, we multiply by 1 and return them.
     return [out, moving_mean * 1, moving_var * 1]
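
For reference, a minimal sketch of how the new flags are exercised at the TE level (assuming a 3-channel NCHW input normalized over axis=1; adapted from the updated test below):

from tvm import te, topi

x = te.placeholder((2, 3, 4, 4), name="x", dtype="float32")
gamma = te.placeholder((3,), name="gamma", dtype="float32")
beta = te.placeholder((3,), name="beta", dtype="float32")
moving_mean = te.placeholder((3,), name="moving_mean", dtype="float32")
moving_var = te.placeholder((3,), name="moving_var", dtype="float32")

# With training=True, the op returns the normalized output together with
# the updated moving statistics.
out, new_mean, new_var = topi.nn.batch_norm(
    x, gamma, beta, moving_mean, moving_var,
    axis=1, epsilon=1e-5, center=True, scale=True,
    training=True, momentum=0.1,
)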
33 changes: 29 additions & 4 deletions python/tvm/topi/testing/batch_norm.py
@@ -28,6 +28,8 @@ def batch_norm(
     epsilon: float,
     center: bool,
     scale: bool,
+    training: bool,
+    momentum: float,
 ):
     """Batch Normalization operator implemented in Numpy.
@@ -62,6 +64,13 @@
         If True, scale normalized tensor by gamma. If False, gamma
         is ignored.

+    training : bool
+        Whether it is in training mode. If True, update
+        moving_mean and moving_var.
+
+    momentum : float
+        The value used for the moving_mean and moving_var update.
+
     Returns
     -------
     output : np.ndarray
@@ -76,14 +85,30 @@
     shape = [1] * len(x.shape)
     shape[axis] = x.shape[axis]

-    moving_mean_rs = moving_mean.reshape(shape)
-    moving_var_rs = moving_var.reshape(shape)
-
-    out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)
+    if training:
+        reduce_axes = list(range(len(x.shape)))
+        reduce_axes.remove(axis)
+        reduce_axes = tuple(reduce_axes)
+        data_mean = np.mean(x, axis=reduce_axes)
+        data_var = np.var(x, axis=reduce_axes)
+        data_mean_rs = np.reshape(data_mean, shape)
+        data_var_rs = np.reshape(data_var, shape)
+        out = (x - data_mean_rs) / np.sqrt(data_var_rs + epsilon)
+    else:
+        moving_mean_rs = moving_mean.reshape(shape)
+        moving_var_rs = moving_var.reshape(shape)
+        out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)

     if scale:
         out = out * gamma.reshape(shape)
     if center:
         out = out + beta.reshape(shape)

+    if training:
+        return [
+            out,
+            (1 - momentum) * moving_mean + momentum * data_mean,
+            (1 - momentum) * moving_var + momentum * data_var,
+        ]
+
     return [out, moving_mean, moving_var]
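
A quick sanity check of the reference implementation above (a hypothetical snippet, not part of the test suite): with momentum=0.1 the returned moving mean should move 10% of the way from the old running mean toward the batch mean.

import numpy as np
import tvm.topi.testing

x = np.random.random((2, 3, 4)).astype("float32")
gamma = np.ones(3, "float32")
beta = np.zeros(3, "float32")
mean = np.zeros(3, "float32")
var = np.ones(3, "float32")

# Arguments: x, gamma, beta, moving_mean, moving_var, axis, epsilon,
# center, scale, training, momentum.
out, new_mean, new_var = tvm.topi.testing.batch_norm(
    x, gamma, beta, mean, var, 1, 1e-5, True, True, True, 0.1
)
np.testing.assert_allclose(new_mean, 0.9 * mean + 0.1 * x.mean(axis=(0, 2)), rtol=1e-5)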
46 changes: 33 additions & 13 deletions tests/python/topi/python/test_topi_batch_norm.py
@@ -33,27 +33,37 @@


 @pytest.mark.parametrize(
-    "shape, axis, epsilon, center, scale",
+    "shape, axis, epsilon, center, scale, training, momentum",
     [
-        ((1,), 0, 0.1, True, True),
-        ((2, 3), 0, 0.1, True, True),
-        ((1, 2, 4), 0, 0.1, True, True),
-        ((1, 2, 3, 4), 0, 0.001, False, False),
-        ((2, 3, 4, 1), 1, 0.01, False, True),
-        ((3, 4, 1, 2), 2, 0.1, True, False),
-        ((4, 1, 2, 3), 3, 1.0, True, True),
-        ((1, 2, 4, 4, 5), 0, 0.1, True, True),
+        ((1,), 0, 0.1, True, True, False, 0.1),
+        ((2, 3), 0, 0.1, True, True, False, 0.1),
+        ((1, 2, 4), 0, 0.1, True, True, False, 0.1),
+        ((1, 2, 3, 4), 0, 0.001, False, False, False, 0.1),
+        ((2, 3, 4, 1), 1, 0.01, False, True, False, 0.1),
+        ((3, 4, 1, 2), 2, 0.1, True, False, True, 0.1),
+        ((4, 1, 2, 3), 3, 1.0, True, True, True, 0.2),
+        ((1, 2, 4, 4, 5), 0, 0.1, True, True, True, 0.3),
     ],
 )
-def test_batch_norm(shape, axis, epsilon, center, scale):
+def test_batch_norm(shape, axis, epsilon, center, scale, training, momentum):
     x_np = np.random.random(shape).astype("float32")
     gamma_np = np.random.random(shape[axis]).astype("float32")
     beta_np = np.random.random(shape[axis]).astype("float32")
     moving_mean_np = np.random.random(shape[axis]).astype("float32")
     moving_var_np = np.random.random(shape[axis]).astype("float32")

     out_x_np, out_moving_mean_np, out_moving_var_np = tvm.topi.testing.batch_norm(
-        x_np, gamma_np, beta_np, moving_mean_np, moving_var_np, axis, epsilon, center, scale
+        x_np,
+        gamma_np,
+        beta_np,
+        moving_mean_np,
+        moving_var_np,
+        axis,
+        epsilon,
+        center,
+        scale,
+        training,
+        momentum,
     )

     x_te = te.placeholder(shape, name="x", dtype="float32")
@@ -65,7 +75,17 @@ def test_batch_norm(shape, axis, epsilon, center, scale):
     with tvm.target.Target(_DEVICE):
         fcompute, fschedule = tvm.topi.testing.dispatch(_DEVICE, _BATCH_NORM_IMPLEMENT)
         out_x, out_moving_mean, out_moving_var = fcompute(
-            x_te, gamma_te, beta_te, moving_mean_te, moving_var_te, axis, epsilon, center, scale
+            x_te,
+            gamma_te,
+            beta_te,
+            moving_mean_te,
+            moving_var_te,
+            axis,
+            epsilon,
+            center,
+            scale,
+            training,
+            momentum,
         )
         s = fschedule([out_x, out_moving_mean, out_moving_var])

@@ -113,4 +133,4 @@ def test_batch_norm(shape, axis, epsilon, center, scale):


 if __name__ == "__main__":
-    test_batch_norm()
+    tvm.testing.main()
