
Use memcpy instead of set_data_handle when num_layer=1, direction=1
zixuanweeei committed Oct 29, 2019
1 parent 5eb89fc commit 9c91e45
Showing 1 changed file with 6 additions and 6 deletions.
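The gist of the change, before the file view: in the single uni-directional-layer branch the weights memory used to alias the framework-owned buffer via set_data_handle, while the other branches fill the buffer that MKL-DNN already owns (via ConcatWeights). With std::memcpy this branch behaves like the others, and the in-place LBR-GRU gate reordering further down in the diff then touches MKL-DNN's own copy rather than the original weight array. Because memcpy needs a byte count, wh_bytes is now computed up front in the first hunk and the later duplicate computation is dropped. Below is a minimal, standalone sketch of that distinction; it is not MXNet code, and the setup assumes the MKL-DNN v1.x C++ API (mkldnn.hpp).

// Minimal sketch, not MXNet code: aliasing an external buffer with
// set_data_handle() versus copying into the buffer a memory object owns.
// Assumes the MKL-DNN v1.x C++ API (header mkldnn.hpp, namespace mkldnn).
#include <cstring>
#include <vector>
#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);

  // Framework-owned weights, standing in for the NDArray data behind w_ptr.
  std::vector<float> framework_weights(4 * 8, 0.5f);
  const size_t bytes = framework_weights.size() * sizeof(float);

  memory::desc md({4, 8}, memory::data_type::f32, memory::format_tag::ab);

  // Aliasing: the memory object now points straight at framework_weights,
  // so a later in-place adjustment through this memory mutates that array.
  memory aliased(md, eng);
  aliased.set_data_handle(framework_weights.data());

  // Copying (what the commit switches to): the memory keeps its own buffer,
  // so in-place gate reordering leaves framework_weights untouched.
  memory owned(md, eng);
  std::memcpy(owned.get_data_handle(), framework_weights.data(), bytes);
  return 0;
}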
12 changes: 6 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -483,8 +483,10 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
char *weights_ptr = static_cast<char *>(w_ptr);
size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
param_.input_size * dtype_bytes; //* DIMS: ngates x state_size x input_size
size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
param_.state_size * dtype_bytes; //* DIMS: ngates x state_size x state_size
char *l2r_wx = weights_ptr;
char *l2r_wh = l2r_wx + wx_bytes; //* DIMS: ngates x state_size * state_size

if (param_.num_layer == 1 && param_.bidirectional) {
//* single bidirectional layer, concat weights on direction axis
@@ -494,8 +496,8 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi);
} else if (param_.num_layer == 1 && !param_.bidirectional) {
//* single uni-directional layer, no concatenate operator needed
weights_layer_r_->set_data_handle(l2r_wx);
weights_iter_r_->set_data_handle(l2r_wh);
std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes);
std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes);
} else if (param_.num_layer > 1 && !param_.bidirectional) {
//* concat fused multi-layer weights on layer axis
std::vector<void *> l2r_wx_ptrs;
@@ -514,8 +516,6 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
}

// Adjust gates order of LBR-GRU among concatenated memory inplace.
//* DIMS: ngates x state_size x state_size (ngates = 3, when mode == gru)
size_t wh_bytes = 3 * param_.state_size * param_.state_size * dtype_bytes;
char* fused_wx = static_cast<char*>(weights_layer_r_->get_data_handle());
char* fused_wh = static_cast<char*>(weights_iter_r_->get_data_handle());
if (param_.mode == rnn_enum::kGru) {
@@ -922,7 +922,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
weights_version_ = inputs[rnn_enum::kParams].version();
}

if (!initialized_ || ctx.is_train || fwd_trn_vec_.size() == 0) {
if (!initialized_ || ctx.is_train || fwd_inf_vec_.size() == 0) {
Init(ctx, inputs, req, outputs);
}
