[MXNET-614] Adding Synchronized Batch Normalization (apache#11502)
* sync batch norm

* global rank and barrier

* lint

* cpplint

* pylint

* doc

* add ref

* customized barrier

* cpplint

* get rid of pthread

* address comments

* warning

* pylint

* gpu unittest

* gpu 0

* mv to cpu test

* Revert "mv to cpu test"

This reverts commit 24543c9.

* ndev = 2

* debugging

* sum prod

* lint

* contrib, ngpu

* code style

* code style

* forward backward

* test

* cpu test

* fix deconstruction

* doc indent

* doc

* doc

* address comments

* typo

* asnumpy
zhanghang1989 authored and eric-haibin-lin committed Jul 14, 2018
1 parent 7992c2f commit 6f520a0
Showing 6 changed files with 908 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api/python/gluon/contrib.md
@@ -36,6 +36,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     HybridConcurrent
     Identity
     SparseEmbedding
+    SyncBatchNorm
 ```
 
 ### Recurrent neural network
84 changes: 81 additions & 3 deletions python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,11 +18,13 @@
 # coding: utf-8
 # pylint: disable= arguments-differ
 """Custom neural network layers in model_zoo."""
-__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding']
+__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
+           'SyncBatchNorm']
 
-from .... import nd
+import warnings
+from .... import nd, test_utils
 from ...block import HybridBlock, Block
-from ...nn import Sequential, HybridSequential
+from ...nn import Sequential, HybridSequential, BatchNorm
 
 class Concurrent(Sequential):
     """Lays `Block`s concurrently.
@@ -157,3 +159,79 @@ def __repr__(self):
         s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
         return s.format(block_name=self.__class__.__name__,
                         **self._kwargs)
+
+class SyncBatchNorm(BatchNorm):
+    """Cross-GPU Synchronized Batch normalization (SyncBN)
+
+    The standard BN [1]_ implementation only normalizes the data within each device.
+    SyncBN normalizes the input within the whole mini-batch.
+    We follow the sync-once implementation described in the paper [2]_.
+
+    Parameters
+    ----------
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+    num_devices : int, default number of visible GPUs
+        Number of devices across which the batch statistics are synchronized.
+    momentum: float, default 0.9
+        Momentum for the moving average.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+        When the next layer is linear (also e.g. `nn.relu`),
+        this can be disabled since the scaling
+        will be done by the next layer.
+    use_global_stats: bool, default False
+        If True, use global moving statistics instead of local batch statistics,
+        which turns batch-norm into a scale-and-shift operator.
+        If False, use local batch-norm.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    running_mean_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the moving mean.
+    running_variance_initializer: str or `Initializer`, default 'ones'
+        Initializer for the moving variance.
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    Reference:
+        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
+        deep network training by reducing internal covariate shift." *ICML 2015*
+        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
+        Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
+    """
+    def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
+                 center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
+                 gamma_initializer='ones', running_mean_initializer='zeros',
+                 running_variance_initializer='ones', **kwargs):
+        super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats,
+                                            beta_initializer, gamma_initializer,
+                                            running_mean_initializer, running_variance_initializer,
+                                            in_channels, **kwargs)
+        num_devices = self._get_num_devices() if num_devices is None else num_devices
+        self._kwargs = {'eps': epsilon, 'momentum': momentum,
+                        'fix_gamma': not scale, 'use_global_stats': use_global_stats,
+                        'ndev': num_devices, 'key': self.prefix}
+
+    def _get_num_devices(self):
+        warnings.warn("Caution using SyncBatchNorm: "
+                      "if not using all the GPUs, please manually set num_devices",
+                      UserWarning)
+        num_devices = len(test_utils.list_gpus())
+        num_devices = num_devices if num_devices > 0 else 1
+        return num_devices
+
+    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
+        return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
+                                       name='fwd', **self._kwargs)
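For intuition about what the new operator synchronizes: during training each device reduces its slice of the batch to a count, a sum, and a sum of squares, and the global mean and variance follow from the summed reductions. Below is a minimal NumPy sketch of that arithmetic, an illustration of the idea only; the function name and the single-channel simplification are ours, not the operator's actual implementation:

```python
import numpy as np

def global_bn_stats(per_device_batches):
    """Recover mini-batch mean/variance from per-device partial sums.

    Each device reduces its slice to (count, sum, sum of squares);
    the global statistics follow from the summed reductions.
    Simplified to a single feature channel for clarity.
    """
    n = sum(b.size for b in per_device_batches)
    s = sum(b.sum() for b in per_device_batches)           # reduced sums
    sq = sum((b ** 2).sum() for b in per_device_batches)   # reduced sums of squares
    mean = s / n
    var = sq / n - mean ** 2                               # E[x^2] - (E[x])^2
    return mean, var

# Two "devices", each holding half of an 8-sample batch
halves = np.split(np.random.randn(8, 16), 2)
mean, var = global_bn_stats(halves)
```

Sharing sums and sums of squares is what makes a single synchronization round per forward pass sufficient; that is the "sync-once" scheme the docstring cites from [2]_.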
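And a minimal usage sketch against the API added above, assuming two visible GPUs; the context list, layer sizes, and input shape are illustrative choices, not part of the commit:

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import SyncBatchNorm

# num_devices should match the number of contexts the batch is split across;
# here we assume two GPUs.
ctx_list = [mx.gpu(0), mx.gpu(1)]
net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Conv2D(channels=32, kernel_size=3, padding=1),
            SyncBatchNorm(in_channels=32, num_devices=len(ctx_list)),
            nn.Activation('relu'))
net.initialize(ctx=ctx_list)

x = mx.nd.random.uniform(shape=(8, 3, 16, 16))             # one global mini-batch
parts = gluon.utils.split_and_load(x, ctx_list=ctx_list)   # per-device slices
with autograd.record():
    # BN statistics are computed over the whole 8-sample batch,
    # not independently over each 4-sample slice
    ys = [net(part) for part in parts]
```

Normalizing each slice with the full-batch statistics is what matters when per-device batches are small, e.g. in semantic segmentation, the application that motivated [2]_.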