diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index f713c497a077..05401b2b80f9 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -483,8 +483,10 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
   char *weights_ptr = static_cast<char *>(w_ptr);
   size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
                     param_.input_size * dtype_bytes;  //* DIMS: ngates x state_size x input_size
+  size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
+                    param_.state_size * dtype_bytes;  //* DIMS: ngates x state_size x state_size
   char *l2r_wx = weights_ptr;
-  char *l2r_wh = l2r_wx + wx_bytes;  //* DIMS: ngates x state_size * state_size
+  char *l2r_wh = l2r_wx + wx_bytes;  //* DIMS: ngates x state_size x state_size
 
   if (param_.num_layer == 1 && param_.bidirectional) {
     //* single bidirectinal layer, concat weights on direction axis
@@ -494,8 +496,8 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
     ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi);
   } else if (param_.num_layer == 1 && !param_.bidirectional) {
     //* single uni-directional layer, no concatenate operator needed
-    weights_layer_r_->set_data_handle(l2r_wx);
-    weights_iter_r_->set_data_handle(l2r_wh);
+    std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes);
+    std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes);
   } else if (param_.num_layer > 1 && !param_.bidirectional) {
     //* concat fused multi-layer weights on layer axis
     std::vector<void *> l2r_wx_ptrs;
@@ -514,8 +516,6 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
   }
 
   // Adjust gates order of LBR-GRU among concatenated memory inplace.
-  //* DIMS: ngates x state_size x state_size (ngates = 3, when mode == gru)
-  size_t wh_bytes = 3 * param_.state_size * param_.state_size * dtype_bytes;
   char* fused_wx = static_cast<char *>(weights_layer_r_->get_data_handle());
   char* fused_wh = static_cast<char *>(weights_iter_r_->get_data_handle());
   if (param_.mode == rnn_enum::kGru) {
@@ -928,7 +928,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     weights_version_ = inputs[rnn_enum::kParams].version();
   }
 
-  if (!initialized_ || is_training || fwd_trn_vec_.size() == 0) {
+  if (!initialized_ || is_training || fwd_inf_vec_.size() == 0) {
     Init(ctx, inputs, req, outputs);
   }
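
Note on the central change: the patch replaces set_data_handle, which makes the MKL-DNN weight memory alias the caller's buffer, with std::memcpy into the buffer the memory object already owns. Judging from the surrounding code, the in-place LBR-GRU gate reordering that follows would otherwise write through the alias into the user's weight array. The stand-alone sketch below only illustrates that aliasing-versus-copying distinction using plain oneDNN (formerly MKL-DNN); the header name, dimensions, and variable names are assumptions for illustration and are not taken from the MXNet sources.

// Minimal sketch (not part of the patch), assuming oneDNN >= 1.x with dnnl.hpp.
#include <cstring>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // Hypothetical weights_layer dims (layers, directions, input_size, gates,
  // state_size), stored with the ldgoi physical tag as in the patch.
  dnnl::memory::desc wx_md({1, 1, 16, 4, 8},
                           dnnl::memory::data_type::f32,
                           dnnl::memory::format_tag::ldgoi);

  dnnl::memory weights_mem(wx_md, eng);  // buffer allocated and owned by oneDNN

  std::vector<float> user_weights(wx_md.get_size() / sizeof(float), 0.5f);

  // Old behaviour: alias the caller's buffer. Any later in-place reorder of
  // weights_mem (e.g. a GRU gate-order swap) would also rewrite user_weights.
  weights_mem.set_data_handle(user_weights.data());

  // Patched behaviour: keep the library-owned buffer and copy the bytes into
  // it, leaving user_weights untouched.
  dnnl::memory weights_copy(wx_md, eng);
  std::memcpy(weights_copy.get_data_handle(), user_weights.data(), wx_md.get_size());
  return 0;
}

The sketch should build against an installed oneDNN with something like g++ sketch.cc -ldnnl.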