Support projection feature for LSTM on CPU (Only Inference) #17702
Conversation
981f9b0 to 759285b
@@ -385,7 +382,9 @@ The definition of GRU here is slightly different from paper but compatible with
})
I don't think the projection support is clear in the documentation. Could you update the documentation with LSTMP support when projection_size is set? You can refer to https://github.com/apache/incubator-mxnet/blob/62a85f365b819829fedb60116f803e0c0a3c554c/python/mxnet/gluon/contrib/rnn/rnn_cell.py#L197
Thanks!
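For reference, a sketch of the LSTMP recurrence the documentation could spell out, following the standard projected-LSTM formulation used by the LSTMPCell linked above (the weight names here are illustrative): the cell output h_t is projected down to r_t, which has dimension projection_size and replaces the usual hidden state in both the recurrence and the layer output.

\begin{aligned}
i_t &= \sigma(W_{xi} x_t + W_{ri} r_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} x_t + W_{rf} r_{t-1} + b_f) \\
g_t &= \tanh(W_{xg} x_t + W_{rg} r_{t-1} + b_g) \\
o_t &= \sigma(W_{xo} x_t + W_{ro} r_{t-1} + b_o) \\
c_t &= f_t \circ c_{t-1} + i_t \circ g_t \\
h_t &= o_t \circ \tanh(c_t) \\
r_t &= W_{hr} h_t
\end{aligned}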
Sure. Thanks for pointing that out.
Done. Please review again. Thanks.
CI passed last time, and the latest commit only adds documentation for the projection feature, so it should have no impact on functionality. Let's wait for CI validation. Please review. Thanks. @ciyongch @pengzhao-intel
LGTM
Thank you for the contribution. Minor comments.
src/operator/rnn.cc
Outdated
@@ -385,7 +399,9 @@ The definition of GRU here is slightly different from paper but compatible with
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
#if MXNET_USE_MKLDNN == 1
Let's merge this check with the one at L407.
Note that if MKL-DNN is not used, FInferStorageType will not be registered.
Done. The default storage type inference function will be executed when FInferStorageType is not registered.
src/operator/rnn.cc
Outdated
@@ -427,7 +443,9 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
Same here; merge this check with the one at L450.
Done. Thanks.
src/operator/rnn_impl.h
Outdated
const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, (P ? P : H)));
Tensor<cpu, 2, DType> whr(w_ptr, Shape2(1, 1));
if (P > 0)
  whr = Tensor<cpu, 2, DType>(wh.dptr_ + P * 4 * H, Shape2(P, H));
Let's put this on the same line as `if (P > 0)`, or add `{ ... }` around it, like what you're doing at L236.
Put them on the same line, and did the same at L236.
* test solution for -Werror=maybe-uninitialized
* Check device type when create state
* Document the projection feature of LSTM for RNN operator
* Minor fix
Thanks for all of your great work. I will merge the PR after CI passes.
…7702)
* Support projection feature for LSTM on CPU
* test solution for -Werror=maybe-uninitialized
* Check device type when create state
* Document the projection feature of LSTM for RNN operator
* Minor fix
* Re-run CI
* Support projection feature for LSTM on CPU (Only Inference) (#17702)
  * Support projection feature for LSTM on CPU
  * test solution for -Werror=maybe-uninitialized
  * Check device type when create state
  * Document the projection feature of LSTM for RNN operator
  * Minor fix
  * Re-run CI
* Fix issue of zeros gradients w.r.t. RNN bias when num_layers > 1 (#17872)
  * Fix issue of zeros gradients w.r.t. RNN bias when num_layers > 1
  * Use nd.copy() to initialize parameters of new operator
  * Add check for output states
  * Initialize i2h/h2h_weights with zeros for rnn_relu/tanh, and reduce size
  * Split fused rnn layer test into tests of individual mode
  * Skip lstm and gru tests on CPU context without DNNL
…18038)
* Support projection feature for LSTM on CPU (Only Inference) (apache#17702)
  * Support projection feature for LSTM on CPU
  * test solution for -Werror=maybe-uninitialized
  * Check device type when create state
  * Document the projection feature of LSTM for RNN operator
  * Minor fix
  * Re-run CI
* Fix issue of zeros gradients w.r.t. RNN bias when num_layers > 1 (apache#17872)
  * Fix issue of zeros gradients w.r.t. RNN bias when num_layers > 1
  * Use nd.copy() to initialize parameters of new operator
  * Add check for output states
  * Initialize i2h/h2h_weights with zeros for rnn_relu/tanh, and reduce size
  * Split fused rnn layer test into tests of individual mode
  * Skip lstm and gru tests on CPU context without DNNL
Description
`gluon.rnn.LSTM` has an argument, `projection_size`, which enables the projection feature for the LSTM operator. Previously, this feature was not supported on the CPU backend; this PR adds it. Note that only the forward pass is ready in this PR: the backward pass and the `needs_grads` scenario are not adapted to this feature, and running into them throws an error.

@ciyongch @pengzhao-intel @TaoLv