Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-246] operators for Synchronized BatchNorm #10303

Closed
wants to merge 2 commits into from

Conversation

zhanghang1989
Copy link
Contributor

@zhanghang1989 zhanghang1989 commented Mar 28, 2018

Description

Backend operators for Synchronized Batch Norm. Design idea as in http://hangzh.com/SynchronizeBN/

Special thanks to Haibin @eric-haibin-lin

  1. AllReduce
  2. Decouple Batch Norm
  3. SumSquare

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • AllReduce, tests, (and when applicable, API doc)
  • Decouple Batch Norm, tests, (and when applicable, API doc)
  • SumSquare, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@zhanghang1989 zhanghang1989 changed the title operators for Synchronized BatchNorm [MXNET-246] operators for Synchronized BatchNorm Mar 28, 2018
Copy link
Contributor

@piiswrong piiswrong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A single CrossDeviceBatchNorm would be much more efficient

@@ -164,7 +164,12 @@ inline void SetShapeType(const Context& ctx,
NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
if (outputs[i]->is_none()) {
if (storage_type == kDefaultStorage) {
*outputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]);
if (outputs.size() == inputs.size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

???

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for AllReduce operator, the outputs should be in the same device as the corresponding input.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too hacky

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I aggree. I am not familiar with very backend code. Could you give suggestions?

@zhanghang1989
Copy link
Contributor Author

Only all-reduce operation need cross-device communication. The other operations should be calculated individually at different device. That's why I separate those operations to different "operators".

@piiswrong
Copy link
Contributor

why not combine them into the same operator? It would save memory and be much more efficient

@eric-haibin-lin eric-haibin-lin self-assigned this Mar 28, 2018
@chinakook
Copy link
Contributor

There are someone in TuSimple shared idea of this. It may help you.
https://zhuanlan.zhihu.com/p/27069202

@zhanghang1989
Copy link
Contributor Author

Thanks @chinakook Actually, we already implemented synchronized batch normalization internally. Just thinking about releasing a proper API for the user.

@zhanghang1989
Copy link
Contributor Author

MXNet Gluon Cross-GPU Batch Norm is Implemented here https://github.com/zhanghang1989/MXNet-Gluon-SyncBN . Feel free to try and leave comments.

@eric-haibin-lin eric-haibin-lin removed their assignment May 21, 2018
@zhanghang1989
Copy link
Contributor Author

Closing this in favor of the new PR #11502

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants