This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[RFC] Deferred compute in imperative interface to unify imperative and symbolic interface #16376

Closed
leezu opened this issue Oct 5, 2019 · 9 comments
Labels
RFC Post requesting for comments

Comments

@leezu
Contributor

leezu commented Oct 5, 2019

A new deferred computation (DC) argument to the imperative MXNet APIs is
proposed. If enabled, memory allocation and computation are deferred as long as
possible. Users can export the computational graph recorded during deferred
computation, which enables hybridization support.

Arrays for which DC is enabled are called lazy. Other arrays are called
normal. In-place operations on lazy arrays are unsupported.

Storage allocation and computation for lazy arrays are deferred until their
results are required, either by conversion to numpy or by use as input to an
operator creating a normal array. Accessing attributes such as shape can also
trigger computation if the attribute cannot be inferred.
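To make the lazy/normal distinction concrete, here is a minimal pure-Python sketch. `LazyArray`, `constant`, `is_computed` and `asnumpy` are hypothetical names chosen for illustration, not MXNet's implementation: each lazy array only records the operator and inputs that produce it, and computation happens when the value is converted to NumPy. In-place updates are rejected, matching the restriction above.

```python
import numpy as np


class LazyArray:
    """Toy lazy array (hypothetical, not the MXNet API).

    Each instance records the operator and inputs that produce it; nothing
    is allocated or computed until asnumpy() is called.
    """

    def __init__(self, op=None, inputs=(), value=None):
        self.op = op                # callable producing the value; None for leaves
        self.inputs = list(inputs)  # upstream LazyArray nodes
        self._value = value         # filled in once the array is materialized

    @classmethod
    def constant(cls, data):
        return cls(value=np.asarray(data, dtype=float))

    @property
    def is_computed(self):
        return self._value is not None

    def __add__(self, other):
        return LazyArray(np.add, [self, other])

    def __mul__(self, other):
        return LazyArray(np.multiply, [self, other])

    def __iadd__(self, other):
        # In-place updates would overwrite recorded history, so reject them.
        raise NotImplementedError("in-place operations on lazy arrays are unsupported")

    def asnumpy(self):
        # Conversion to NumPy is what forces allocation and computation.
        if self._value is None:
            self._value = self.op(*(x.asnumpy() for x in self.inputs))
        return self._value


a = LazyArray.constant([1.0, 2.0])
b = LazyArray.constant([3.0, 4.0])
c = (a + b) * b            # only records two nodes
print(c.is_computed)       # False
print(c.asnumpy())         # [12. 24.]
```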

Update: The proposed implementation in #17530 differs slightly from the API previously described in this RFC. Thus I deleted the API docs in this RFC. Please refer to the PR. For example, a global state is used to enable / disable deferred compute, instead of introducing a new invocation API MXImperativeDeferredInvokeEx.

FAQ

What about Autograd, NDArray.autograd_entry_ and AGInfo?
Autograd inside deferred computation (DC) mode can be supported.

Relation of Autograd and DC: While autograd’s RecordOp provides a recording
functionality similar to deferred computation, the autograd graph is not
the same as a computational graph: NDArray::Detach() detaches a node
from the autograd graph by deleting NDArray.entry_, but the NodeEntry is
still required to reconstruct the computational history of how this NDArray
came to be.
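The distinction can be illustrated with a small pure-Python sketch, assuming each array carries two separate links: one into the autograd graph and one into the deferred-compute history. The names `Node`, `Array`, `dc_node`, `autograd_entry` and `history` are hypothetical; `detach()` only mimics the idea of `NDArray::Detach()` described above.

```python
class Node:
    """A recorded operation in the deferred-compute graph (hypothetical)."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)


class Array:
    def __init__(self, node, autograd_entry=None):
        self.dc_node = node                   # computational history, kept for export
        self.autograd_entry = autograd_entry  # where gradients flow back to, if recorded

    def detach(self):
        # Cut the array out of the autograd graph, but keep the
        # deferred-compute node describing how it was produced.
        return Array(self.dc_node, autograd_entry=None)


def history(arr):
    """Ops that produced `arr`, inputs listed before the ops consuming them."""
    ops, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        ops.append(node.op)

    visit(arr.dc_node)
    return ops


x = Array(Node("var"))
y_node = Node("relu", [x.dc_node])
y = Array(y_node, autograd_entry=y_node)
z = y.detach()
print(z.autograd_entry)   # None: gradients no longer flow through z
print(history(z))         # ['var', 'relu']: the graph can still be exported
```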

Are reqs like kInPlace supported?
No. For now only kWriteTo is supported in DC mode.

The plan is to replace in-place operations with kWriteTo operations that write to
a new (lazy) array. The framework should be smart enough to decide when memory
can be reused; users shouldn't need to specify that they want an in-place
operation.
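One way a framework could recover in-place execution without users asking for it is a liveness check over the recorded graph: an op may write into the buffer of an intermediate input that no later op reads. The sketch below is a generic illustration of that idea under simplifying assumptions (all arrays have the same size, every op is elementwise); `plan_memory` is a hypothetical helper, not MXNet's memory planner.

```python
def plan_memory(ops, outputs):
    """Assign a buffer id to every array in a recorded graph.

    ops: list of (out_name, [input_names]) in execution order; every op
    writes a fresh array (kWriteTo). An op reuses the buffer of an
    intermediate input whose last use is this op, so memory is shared
    without the user writing any in-place operation. Graph inputs and
    requested outputs are never overwritten. Assumes equally sized arrays.
    """
    produced = {out for out, _ in ops}
    last_use = {}
    for i, (_, ins) in enumerate(ops):
        for name in ins:
            last_use[name] = i

    buffers, next_id = {}, 0
    for i, (out, ins) in enumerate(ops):
        for name in ins:
            if name not in buffers:          # graph input: gets its own buffer
                buffers[name] = next_id
                next_id += 1
        dead = [n for n in ins
                if n in produced and n not in outputs and last_use[n] == i]
        if dead:
            buffers[out] = buffers[dead[0]]  # safe: no later op reads it
        else:
            buffers[out] = next_id
            next_id += 1
    return buffers


# y = x + 1; z = y * 2; w = z + x   (written without any in-place op)
ops = [("y", ["x"]), ("z", ["y"]), ("w", ["z", "x"])]
print(plan_memory(ops, outputs={"w"}))   # {'x': 0, 'y': 1, 'z': 1, 'w': 1}
```

If `y` were also requested as an output, it would keep its own buffer and `z` would get a new one.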

How is the context attribute handled, in particular context changes?

Cross-device copy must be represented as an operator (CrossDeviceCopyOp), which
requires special handling in the graph executor.
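A minimal pure-Python sketch of what this means for the recorded graph, assuming every node carries a device tag. `Recorder`, `var`, `copy_to` and `apply` are hypothetical names, and the device strings only imitate MXNet's `cpu(0)`/`gpu(0)` notation. The point is that a context change becomes an ordinary `CrossDeviceCopy` node rather than a side effect, so an executor can see where data moves.

```python
class Recorder:
    """Toy deferred-compute recorder; each node is (op, input ids, device)."""

    def __init__(self):
        self.nodes = []

    def _add(self, op, inputs, device):
        self.nodes.append((op, tuple(inputs), device))
        return len(self.nodes) - 1          # node id

    def var(self, device):
        return self._add("var", [], device)

    def device_of(self, node_id):
        return self.nodes[node_id][2]

    def copy_to(self, node_id, device):
        # The context change is recorded as an explicit operator node.
        return self._add("CrossDeviceCopy", [node_id], device)

    def apply(self, op, *inputs):
        # Regular operators require all inputs on a single device.
        devices = {self.device_of(i) for i in inputs}
        if len(devices) != 1:
            raise ValueError(f"{op}: inputs on {sorted(devices)}; insert a copy first")
        return self._add(op, inputs, devices.pop())


rec = Recorder()
a = rec.var("cpu(0)")
b = rec.var("gpu(0)")
c = rec.apply("add", rec.copy_to(a, "gpu(0)"), b)
print([op for op, _, _ in rec.nodes])   # ['var', 'var', 'CrossDeviceCopy', 'add']
print(rec.device_of(c))                 # gpu(0)
```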

How is incomplete shape information handled?
The shape property triggers computation if the shape is accessed and cannot be inferred completely.
Users can access static_shape if they want to avoid triggering computation.
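A minimal pure-Python sketch of the two attributes, restricted to 1-D arrays for brevity. `LazyArray`, `static_shape`, `computed`, `absolute` and `positives` are hypothetical stand-ins: `None` in `static_shape` marks a dimension that only a dynamic-shape op can determine at run time, and only in that case does `shape` force computation.

```python
class LazyArray:
    """Toy 1-D lazy array (hypothetical, not the MXNet API).

    static_shape holds what is known without running anything; None marks a
    dimension that only becomes known after execution (dynamic-shape ops).
    """

    def __init__(self, static_shape, compute):
        self.static_shape = static_shape
        self._compute = compute
        self._value = None
        self.computed = False

    def value(self):
        if not self.computed:
            self._value = self._compute()
            self.computed = True
        return self._value

    @property
    def shape(self):
        if None not in self.static_shape:
            return self.static_shape    # fully inferred: no computation
        return (len(self.value()),)     # otherwise compute to find out


def absolute(a):
    # Elementwise op: output shape equals input shape, inferable statically.
    return LazyArray(a.static_shape, lambda: [abs(v) for v in a.value()])


def positives(a):
    # Dynamic-shape op: output length depends on the data.
    return LazyArray((None,), lambda: [v for v in a.value() if v > 0])


x = LazyArray((4,), lambda: [1, -2, 3, -4])
print(absolute(x).shape, x.computed)   # (4,) False
p = positives(x)
print(p.static_shape, p.computed)      # (None,) False
print(p.shape, p.computed)             # (2,) True
```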

Python (Gluon)

Based on DC, hybridization in Gluon is simplified:

Instead of implementing def hybrid_forward(self, F, x, ...) in HybridBlock,
users can opt to implement def forward(self, x, ...) in HybridBlock.

Hybridization based on DC works by the HybridBlock performing the following
steps (unless it is called by a parent block that is being hybridized):

  • keeping a reference to the input arrays and a reference to the parameter
    arrays to pass them to MXNDArrayGetDeferredComputeSymbol;
  • enabling deferred compute mode;
  • running forward;
  • exporting to a Symbol, creating a CachedOp from it and running the CachedOp.

An (internal) global context variable tracks whether hybridization is ongoing. If
it is False and a Block that is to be hybridized is called, the global context
variable is set to True and the Block goes through all 4 steps outlined above;
finally the context variable is set back to False after the export to Symbol
step is finished.

Usage example

from mxnet.gluon import nn

class Net(nn.HybridBlock):
    def forward(self, x):  # implement forward instead of hybrid_forward(self, F, x)
        ...

Hybridizing gluon.Blocks?

DC could be used to support hybridizing Block if all logic can be traced. A
separate effort may add logic to detect these cases and add hybridization
support based on DC. For now we rely on the user to signify hybridization support
by subclassing HybridBlock.

Parameter Shape Inference

For a HybridBlock that makes use of DC for hybridization, we ask users to
implement HybridBlock.infer_shape to infer the parameters' shapes given the
inputs.

Currently, if HybridBlock.infer_shape is not implemented, backward shape
inference is used to infer the shape of parameters. However, backward shape
inference is not supported in all cases (cf. #14253,
#14983 (comment)),
and relying on it for parameter shape inference is brittle. Thus, for consistency
and simplicity, we require an infer_shape implementation when using
hybridization based on DC.

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended label(s): Feature

@szha
Member

szha commented Oct 6, 2019

Thanks for the proposal, @leezu. Since this is a major change, I have some questions regarding the plan.

First, should we restrict this mode to only apply to the new numpy arrays? Since the deferred compute mode won't support reverse shape inference, new blocks that implement the forward interface will not work without implementing the parameter shape inference logic in infer_shape. This also applies when migrating the existing Gluon blocks in our API. Since we plan to adopt numpy arrays in Gluon, the two changes can potentially happen at the same time.

Also, could you elaborate on what the changes are to the infer_shape, especially on how and when it's invoked during deferred initialization?

@asmushetzel
Contributor

Sounds really interesting. Can you elaborate a bit more about specific use cases that this enables or simplifies? Is there something that can't be done today that this would enable? Are there major pain points that this would address compared to hybrid-blocks? Etc..

@leezu
Contributor Author

leezu commented Oct 7, 2019

Thank you @szha and @asmushetzel for looking through the RFC.

Can you elaborate a bit more about specific use cases that this enables or simplifies? Is there something that can't be done today that this would enable? Are there major pain points that this would address compared to hybrid-blocks? Etc..

The RFC is not so much about extending what is possible as about improving the user experience. A major issue of the existing API is that mx.nd and mx.sym are distinct and partially incompatible. The issue of both being distinct is partially addressed by the existing HybridBlock, at the cost of making their incompatibility even more severe. Some of this is tracked in [Bug] Inconsistency between HybridBlock and Block.

Unifying symbolic and imperative mode with deferred compute also works towards [RFC] Introducing NumPy-compatible coding experience into MXNet. While with deferred compute we only trace a computational graph (as with the current symbolic API), a logical next step is to support parsing the AST of the user-provided implementation and hybridizing it directly without tracing. You can find more discussion on this in #14253. AST transformation also benefits from a unified interface, as separate imperative and symbolic frontends would be meaningless.

First, should we restrict this mode to only apply to the new numpy arrays?

It may be feasible to also support the normal ndarray interface. That said, I suggest treating such support as a bonus. Providing backwards compatibility adds complexity for the existing ndarray that doesn't apply to the new numpy arrays. The final decision could be taken later.

Since the deferred compute mode won't support reverse shape inference, new blocks that implement the forward interface will not work without implementing the parameter shape inference logic in infer_shape. This also applies when migrating the existing Gluon blocks in our API. Since we plan to adopt numpy arrays in Gluon, the two changes can potentially happen at the same time.

I agree that both should happen at the same time.

could you elaborate on what the changes are to the infer_shape, especially on how and when it's invoked during deferred initialization?

No conceptual change to the existing infer_shape API is required.
The current implementation works as follows during forward, if called imperatively:

https://github.com/apache/incubator-mxnet/blob/4940ec0e7408fad2443f921131cf1ada72724c38/python/mxnet/gluon/block.py#L1084-L1097

where _deferred_infer_shape calls infer_shape.
Exactly the same logic applies in the proposed deferred compute mode. In line 1091 a DeferredInitializationError is caught, which is then handled by the user's implementation of infer_shape. If the user did not implement infer_shape, we raise a warning explaining that infer_shape must be implemented, given the lack of general backward shape inference support.
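The control flow described here can be sketched without MXNet. `Parameter`, `Block`, `Dense` and the local `DeferredInitializationError` class are simplified, hypothetical stand-ins for the Gluon classes (the real code is at the link above): a parameter with an unknown dimension raises on first access, the block catches that, calls the user's `infer_shape`, and retries `forward`. Without `infer_shape`, this sketch warns and re-raises the original error.

```python
import warnings


class DeferredInitializationError(Exception):
    """Raised when a parameter is used before its shape is known."""


class Parameter:
    def __init__(self, shape):
        self.shape = shape    # 0 marks an unknown dimension (like in_units=0)
        self._data = None

    def data(self):
        if self._data is None:
            if 0 in self.shape:
                raise DeferredInitializationError(f"shape {self.shape} is incomplete")
            rows, cols = self.shape
            self._data = [[1.0] * cols for _ in range(rows)]   # toy "initializer"
        return self._data


class Block:
    def infer_shape(self, x):
        raise NotImplementedError

    def __call__(self, x):
        try:
            return self.forward(x)
        except DeferredInitializationError as err:
            try:
                self.infer_shape(x)      # user-provided parameter shape inference
            except NotImplementedError:
                warnings.warn("implement infer_shape(): backward shape inference "
                              "is not generally supported")
                raise err from None
            return self.forward(x)       # parameters can now be initialized


class Dense(Block):
    def __init__(self, units):
        self.weight = Parameter((units, 0))

    def infer_shape(self, x):
        self.weight.shape = (self.weight.shape[0], len(x[0]))

    def forward(self, x):   # x: list of rows; computes x @ W.T
        w = self.weight.data()
        return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


layer = Dense(3)
print(layer([[1.0, 2.0]]), layer.weight.shape)   # [[3.0, 3.0, 3.0]] (3, 2)
```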

zachgk added the RFC (Post requesting for comments) label Nov 4, 2019
@szha
Copy link
Member

szha commented Dec 8, 2019

How's this project going?

@SunDoge

SunDoge commented Jan 10, 2020

Is there any progress? I really like the static_shape part. Currently, Symbol has no shape attribute, which makes it hard to use some ops in HybridBlock, for example

def hybrid_forward(self, F, feat):
    _B, C, H, W = feat.shape
    x = F.linspace(-1, 1, H)

even though I know that C, H and W will never change and I never access the batch size B. I only need the shape once, and it could be cached. This RFC may fix it.

@apeforest
Contributor

This seems to be a big change to the existing operator modes (imperative and symbolic). Could you please provide more information?

AFAIK, the symbolic API already does deferred init, and the imperative API is provided to improve user experience. Based on this RFC, what's the advantage of this new deferred_compute mode? As a user, when should I use it and when not?

Another question. We all know deferred init causes a bad user experience when it comes to debugging. Would this RFC address the debuggability issue?

If it's about performance optimization, could we have some initial data of using this new deferred mode vs. existing imperative mode?

Thanks,

Lin

@leezu
Contributor Author

leezu commented Jan 28, 2020

This seems to be a big change to the existing operator mode (imperative and symbolic).

Essentially the motivation for deferred compute is to extend imperative mode to enable users to "construct a symbol" without using symbolic API. This addresses confusion around having two APIs and prevents divergence between imperative and symbolic APIs. There's no need to drop the existing imperative / symbolic APIs due to deferred compute.

Could you please provide more information.

Please ask a question and I'll answer ;)

AFAIK, the symbolic API already does deferred init, and the imperative API is provided to improve user experience. Based on this RFC, what's the advantage of this new deferred_compute mode? As a user, when should I use it and when not?

Based on deferred compute we can simplify the gluon.HybridBlock API so that it matches the gluon.Block API. For example, consider reimplementing Dense(HybridBlock) on top of a HybridBlock API extended with deferred compute:

from mxnet import gluon, npx
from mxnet.gluon import HybridBlock

class Dense(HybridBlock):
    def __init__(self, units, use_bias=True, flatten=True,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 in_units=0):
        super().__init__()
        self._flatten = flatten
        self._units = units
        self.weight = gluon.Parameter(shape=(units, in_units),
                                      init=weight_initializer, dtype=dtype,
                                      allow_deferred_init=True)
        if use_bias:
            self.bias = gluon.Parameter(shape=(units,),
                                        init=bias_initializer, dtype=dtype,
                                        allow_deferred_init=True)
        else:
            self.bias = None

    def forward(self, x):  # We allow users to override forward() directly.
        ctx = x.context
        # Only fetch bias data when the layer actually has a bias.
        bias = self.bias.data(ctx) if self.bias is not None else None
        return npx.FullyConnected(x, self.weight.data(ctx), bias,
                                  no_bias=self.bias is None, num_hidden=self._units,
                                  flatten=self._flatten, name='fwd')

HybridBlock can wrap the execution of forward into a deferred compute session and obtain a symbolic representation of the computation and pass it to CachedOp.

There would be no reason for users to use the deferred compute API explicitly.

Another question. We all know deferred init causes a bad user experience when it comes to debugging. Would this RFC address the debuggability issue?

This RFC is orthogonal to deferred init. When updating gluon.HybridBlock API based on deferred compute, one option is to require statically known shapes of weights at construction time if users implement def forward. For backwards compatibility we likely want to keep deferred init around for existing code relying on mx.sym and implementing def hybrid_forward.

The other option is to allow deferred initialization of weights and require users to implement infer_shape:

https://github.com/apache/incubator-mxnet/blob/910c608f682a47fc2c43375b5f5a426b563e5821/python/mxnet/gluon/block.py#L1073-L1075

This works around the failures of symbolic shape inference for deferred init in case of dynamic shape ops, while still allowing users to decide the shape of weight at first forward.

In the example above, it could look like:

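import math

class Dense(HybridBlock):
    def __init__(self, units, use_bias=True, flatten=True,
                 dtype='float32', weight_initializer=None, bias_initializer='zeros',
                 in_units=0):
        ...  # as above

    def infer_shape(self, x):
        # Called when the first forward pass hits the uninitialized weight.
        # With flatten=True all axes after the first are flattened into in_units;
        # otherwise in_units is the size of the last axis.
        in_units = math.prod(x.shape[1:]) if self._flatten else x.shape[-1]
        self.weight.shape = (self.weight.shape[0], in_units)

    def forward(self, x):
        ...  # as above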

If it's about performance optimization, could we have some initial data of using this new deferred mode vs. existing imperative mode?

There is an option to improve the performance of imperative mode by deferring the computation and optimizing the computational graph before performing the computation. But this is not the main motivation, and I haven't optimized for this use case (yet). In the gluon.HybridBlock case, we only run with deferred compute once to construct the symbolic graph and then hand over to CachedOp for optimized execution.

@szha
Member

szha commented Mar 29, 2020

This feature has been merged in #17530. Thanks for the great work @leezu

szha closed this as completed Mar 29, 2020