This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

batch_norm - running mean and running var don't get updated when scale and center are false #18673

Closed
sammieghabra opened this issue Jul 8, 2020 · 2 comments


sammieghabra commented Jul 8, 2020

Description

running mean and running var don't get updated when scale and center are false in batch_norm

Error Message

There is no error message, but the parameters running_mean and running_var don't get updated when scale and center are False.
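
For context, BatchNorm's running statistics are normally maintained as an exponential moving average of the per-batch statistics, so after a training-mode forward pass they should drift away from their initial values. A rough sketch of the expected bookkeeping with the default momentum of 0.9 (ema_update is a hypothetical helper for illustration, not an MXNet API):

def ema_update(running_stat, batch_stat, momentum=0.9):
    # exponential moving average of the per-batch statistics
    return momentum * running_stat + (1.0 - momentum) * batch_stat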

To Reproduce

from mxnet import gluon
from mxnet.gluon import HybridBlock, Block
from mxnet import initializer
from mxnet.symbol import Variable, BlockGrad
from mxnet.initializer import Constant

import numpy as np

class ShiftScaleLayer(HybridBlock):
    def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, center=False, scale=False,
                 use_global_stats=False, beta_initializer='zeros', gamma_initializer='ones',
                 running_mean_initializer='zeros', running_variance_initializer='ones',
                 in_channels=0, **kwargs):
        super(ShiftScaleLayer, self).__init__(**kwargs)
        self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
                        'fix_gamma': not scale, 'use_global_stats': use_global_stats}
        if in_channels != 0:
            self.in_channels = in_channels

        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(in_channels,), init=gamma_initializer,
                                     allow_deferred_init=True,
                                     differentiable=scale)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(in_channels,), init=beta_initializer,
                                    allow_deferred_init=True,
                                    differentiable=center)
        self.running_mean = self.params.get('running_mean', grad_req='null',
                                            shape=(in_channels,),
                                            init=running_mean_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
        self.running_var = self.params.get('running_var', grad_req='null',
                                           shape=(in_channels,),
                                           init=running_variance_initializer,
                                           allow_deferred_init=True,
                                           differentiable=False)

    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        return F.BatchNorm(x, gamma, beta, running_mean, running_var,
                          name='fwd', **self._kwargs)

def print_params(title, net):
    """
    Helper function to print out the current state of the network's parameters
    """
    print(title)
    hybridlayer_params = {k: v for k, v in net.collect_params().items() }

    for key, value in hybridlayer_params.items():
        print('{} = {}\n'.format(key, value.data()))

from mxnet.gluon import nn
from mxnet.gluon.nn import Dense
from mxnet import nd

net = gluon.nn.HybridSequential()                             # Define a Neural Network as a sequence of hybrid blocks
with net.name_scope():                                        # Used to disambiguate saving and loading net parameters
    net.add(ShiftScaleLayer())
    net.add(Dense(10))

net.initialize(initializer.Xavier(magnitude=2.24))                # Initialize parameters of all layers
net.hybridize()

input = nd.array([[[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]])
label = nd.array([[[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]])

mse_loss = gluon.loss.L2Loss()                                # Mean squared error between output and label
trainer = gluon.Trainer(net.collect_params(),                 # Init trainer with Stochastic Gradient Descent (sgd) optimization method and parameters for it
                        'sgd',
                        {'learning_rate': 0.1, 'momentum': 0.9 })

from mxnet import autograd

with autograd.record():                                       # Autograd records computations done on NDArrays inside "with" block
    output = net(input)                                       # Run forward propagation

    print_params("=========== Parameters after forward pass ===========\n", net)
    loss = mse_loss(output, label)
    print(output)

loss.backward()                                               # Backward computes gradients and stores them in each parameter's .grad field
trainer.step(input.shape[0])                                  # Trainer updates the parameters of every block using the .grad field and the optimization method (sgd in this example)
                                                              # We provide the batch size, which is used as a divisor in the cost function formula
print_params("=========== Parameters after backward pass ===========\n", net)

print(net(input))
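
A more focused check of just the running statistics (a small sketch appended after the script above; it reuses the net, input, and autograd import defined there) would be:

before_mean = net[0].running_mean.data().copy()               # ShiftScaleLayer is the first block in the sequence
before_var = net[0].running_var.data().copy()
with autograd.record():                                       # another training-mode forward pass
    net(input)
print(before_mean, net[0].running_mean.data())                # with the reported bug these stay identical when scale and center are False
print(before_var, net[0].running_var.data())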

Steps to reproduce


  1. Run the Python script above.
  2. Observe that the ShiftScaleLayer's running mean and running var are not updated after the backward pass when scale and center are false.

What have you tried to solve it?

N/A

Environment

MXNet 1.6


wkcn (Member) commented Jul 9, 2020

This is a bug, and it has been fixed in PRs #18500 #18517 #18518.
Could you please try a newer version of MXNet, such as MXNet 1.7 or MXNet 2.0, from https://dist.mxnet.io/python?
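
On a fixed build, one quick sanity check with the built-in Gluon BatchNorm (a minimal sketch, not from the original thread; it assumes MXNet 1.7 or newer) might look like:

from mxnet import autograd, nd
from mxnet.gluon import nn

bn = nn.BatchNorm(scale=False, center=False, in_channels=2)   # center and scale disabled, as in the report
bn.initialize()

x = nd.random.normal(loc=3, scale=2, shape=(4, 2))
before = bn.running_mean.data().copy()                        # initial zeros
with autograd.record():                                       # training mode, so the running statistics should update
    bn(x)
print(before)
print(bn.running_mean.data())                                 # should have moved towards the batch mean on a fixed build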

@ChaiBapchya (Contributor)

Closing since the issue seems to be fixed and there has been no activity from the issue creator.

@sammieghabra Feel free to reopen if the issue persists with later versions of MXNet.
