Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add deferred compute support #17530

Merged
merged 14 commits into from
Mar 23, 2020
Merged

Add deferred compute support #17530

merged 14 commits into from
Mar 23, 2020

Conversation

leezu
Copy link
Contributor

@leezu leezu commented Feb 5, 2020

Description

Implements #16376

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Add imperative deferred compute for Gluon 2

Comments

Copy link
Contributor

@reminisce reminisce left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Just a few comments.

include/mxnet/imperative.h Outdated Show resolved Hide resolved
src/imperative/imperative.cc Outdated Show resolved Hide resolved
python/mxnet/_deferred_compute.py Outdated Show resolved Hide resolved
src/imperative/imperative.cc Outdated Show resolved Hide resolved
src/c_api/c_api_ndarray.cc Outdated Show resolved Hide resolved
src/ndarray/ndarray.cc Outdated Show resolved Hide resolved
Copy link
Contributor

@samskalicky samskalicky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this contribution @leezu!

@leezu leezu force-pushed the deferredcompute branch 2 times, most recently from 96c6213 to d15f901 Compare February 15, 2020 19:37
@leezu leezu force-pushed the deferredcompute branch 9 times, most recently from 7e23173 to e64e03c Compare February 20, 2020 20:58
@leezu leezu force-pushed the deferredcompute branch 6 times, most recently from 14d9ddb to 9cd0420 Compare March 18, 2020 06:09
@eric-haibin-lin
Copy link
Member

LGTM

leezu added 13 commits March 23, 2020 01:39
Require users to call dc.set_variable(array, symbol) for every input array used
in deferred compute. Remove input and input_names arguments from dc.get_symbol.

Thereby prevent users from inadvertently using arrays as inputs without
specifying them in dc.get_symbol. Such use previously yielded "unspecified
inputs" error at time of dc.get_symbol call, making it hard for users to find
out where there code is wrong. Now, fail fast and throw the error as soon as an
"unsupported" array is used.

Note that below examples use the private dc.context and dc.set_variable APIs.
Users will not interact with it. It is used internally in HybridBlock.

Example of pitfall prior to this commit:

  a = mx.np.zeros((10, 10))
  with dc.context():
      # Creating an array from list can't be recorded. Must be specified as input.
      b = mx.np.array([1,2,3])
      c = a[b]
  dc.get_symbol(inputs=a, outputs=c)  # Throws "unspecified input" error.

"Correct" usage prior to this commit:

  a = mx.np.zeros((10, 10))
  with dc.context():
      # Creating an array from list can't be recorded. Must be specified as input.
      b = mx.np.array([1,2,3])
      c = a[b]
  dc.get_symbol(inputs=[a, b], outputs=c)

Following this commit:

  a = mx.np.zeros((10, 10))
  dc.set_variable(a, mx.sym.var('a').as_np_ndarray())
  with dc.context():
      b = mx.np.array([1,2,3])
      c = a[b]  # Throws: b is not associated with a variable or deferred computed
@leezu leezu merged commit 83b5170 into apache:master Mar 23, 2020
@leezu leezu deleted the deferredcompute branch March 23, 2020 18:21
anirudh2290 added a commit to anirudh2290/mxnet that referenced this pull request Mar 27, 2020
* 'master' of https://github.com/apache/incubator-mxnet: (192 commits)
  * impl - FFI for np einsum (apache#17869)
  [Numpy] FFI for diag/diagonal/diag_indices_from (apache#17789)
  [Numpy] Kron operator (apache#17323)
  cmake: Set DMLC_LOG_FATAL_THROW only for building mxnet and not for tvm (apache#17878)
  Add simplified HybridBlock.forward without F (apache#17530)
  Use FP32 copy of weights for norm (multitensor LAMB optimizer) (apache#17700)
  Use multi-tensor sumSQ in clip_global_norm (apache#17652)
  [Numpy] Add op fmax, fmin, fmod (apache#17567)
  Adding sparse support to MXTensor for custom operators (apache#17569)
  Update 3rdparty/mkldnn to v1.2.2 (apache#17313)
  Dynamic subgraph compile support (apache#17623)
  Refactor cpp-package CMakeLists.txt & add missing inference/imagenet_inference (apache#17835)
  staticbuild: Fix potential user-assisted execution of arbitrary code  (apache#17860)
  * FFI for np.argmax and np.argmin (apache#17843)
  ffi for roll/rot90 (apache#17861)
  Skip test_multi_worker_dataloader_release_pool on OS X (apache#17797)
  add ffi for full_like, binary (apache#17811)
  HybridBlock.export() to return created filenames (apache#17758)
  Fix SoftReLU fused operator numerical stability (apache#17849)
  CI: Test clang10 cpu & gpu builds with -WError (apache#17830)
  ...
MoisesHer pushed a commit to MoisesHer/incubator-mxnet that referenced this pull request Apr 10, 2020
Users can now implement HybridBlock.forward instead of HybridBlock.hybrid_forward.
HybridBlock.forward has the same signature as Block.forward. For example:

  class MyBlock(mx.gluon.HybridBlock):
      def __init__(self, *, prefix=None, params=None):
          super().__init__(prefix, params)
          with self.name_scope():
              self.dense = mx.gluon.nn.Dense(units=10)
              self.weight = self.params.get('weight', allow_deferred_init=True)
      def infer_shape(self, x):
          self.weight.shape = (x.shape[1], )
      def forward(self, x):
          return self.dense(x) + self.weight.data(x.context) 

Hybridization of HybridBlock.forward is based on a deferred computation mode in
the MXNet backend, which enables recording computation via tracing in the
mxnet.nd and mxnet.np interfaces. The recorded computation can be exported to a
symbolic representation and is used for optimized execution with the CachedOp.

As tracing is based on the imperative APIs, users can access shape information
of the arrays. As x.shape for some array x is a python tuple, any use of that
shape will be a constant in the recorded graph and may limit the recorded graph
to be used with inputs of the same shape only.

As part of the change from hybrid_forward to forward, we also disable support
for parameter shape inference in the MXNet backend in the case of deferred
parameter initialization. Shape inference in the backend was limited and did by
it's very nature not support dynamic shape operators. Instead, users should now
always implement HybridBlock.infer_shape to set the parameter shapes if the
parameter shape was not set during HybridBlock.__init__. See the example above.

An example of the internal deferred compute APIs is:

  a = mx.np.arange(10)
  dc.set_variable(a, mx.sym.var('a').as_np_ndarray())
  with dc.context():
      b = a ** 2
  symbol = dc.get_symbol(b)
anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this pull request May 29, 2020
Users can now implement HybridBlock.forward instead of HybridBlock.hybrid_forward.
HybridBlock.forward has the same signature as Block.forward. For example:

  class MyBlock(mx.gluon.HybridBlock):
      def __init__(self, *, prefix=None, params=None):
          super().__init__(prefix, params)
          with self.name_scope():
              self.dense = mx.gluon.nn.Dense(units=10)
              self.weight = self.params.get('weight', allow_deferred_init=True)
      def infer_shape(self, x):
          self.weight.shape = (x.shape[1], )
      def forward(self, x):
          return self.dense(x) + self.weight.data(x.context) 

Hybridization of HybridBlock.forward is based on a deferred computation mode in
the MXNet backend, which enables recording computation via tracing in the
mxnet.nd and mxnet.np interfaces. The recorded computation can be exported to a
symbolic representation and is used for optimized execution with the CachedOp.

As tracing is based on the imperative APIs, users can access shape information
of the arrays. As x.shape for some array x is a python tuple, any use of that
shape will be a constant in the recorded graph and may limit the recorded graph
to be used with inputs of the same shape only.

As part of the change from hybrid_forward to forward, we also disable support
for parameter shape inference in the MXNet backend in the case of deferred
parameter initialization. Shape inference in the backend was limited and did by
it's very nature not support dynamic shape operators. Instead, users should now
always implement HybridBlock.infer_shape to set the parameter shapes if the
parameter shape was not set during HybridBlock.__init__. See the example above.

An example of the internal deferred compute APIs is:

  a = mx.np.arange(10)
  dc.set_variable(a, mx.sym.var('a').as_np_ndarray())
  with dc.context():
      b = a ** 2
  symbol = dc.get_symbol(b)
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants