Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX][1.8.x] Temporary fix for RNN with oneDNN seg faults/core dum…
Browse files Browse the repository at this point in the history
…ps (#19308)

* [1.8.x] Temporary fix for RNN with oneDNN  seg faults/core dumps

* fix sanity

* Fix typo and make compare function static inline
  • Loading branch information
bgawrych committed Oct 27, 2020
1 parent d1c2035 commit 7c86f48
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ inline int GetRnnGatesNum(int mode) {
}
}

// Bug in oneDNN <= 1.6 in memory descriptor comparision operators.
// for specific dims and strides in descriptors == operator can return `true`
// but get_size() function will return different size
// TODO(bgawrych): Remove with oneDNN 1.7 upgrade
static inline bool CheckMemDescEquality(const mkldnn::memory::desc &left,
const mkldnn::memory::desc &right) {
return left == right && left.get_size() == right.get_size();
}

void MKLDNNRnnLayerParam::SetDims() {
const int ngates = GetRnnGatesNum(mode);
//* NOTES: LBR-GRU's new gate formula needs two bias. So it has one more bias with LBR-GRU
Expand Down Expand Up @@ -590,13 +599,13 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) {
weights_iter_ = mkldnn_shared_mem_t(new memory(fwd_trn_.GetIterDesc(), cpu_engine));

// fill weights memory using the reordered weights of fwd_inference primitive
if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) {
if (CheckMemDescEquality(fwd.weights_layer_r_->get_desc(), fwd_trn_.GetLayerDesc())) {
weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle());
} else {
MKLDNNMemoryReorder(*fwd.weights_layer_r_, *weights_layer_);
}

if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) {
if (CheckMemDescEquality(fwd.weights_iter_r_->get_desc(), fwd_trn_.GetIterDesc())) {
weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle());
} else {
MKLDNNMemoryReorder(*fwd.weights_iter_r_, *weights_iter_);
Expand Down Expand Up @@ -720,15 +729,15 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& fwd)
const mkldnn::memory* valid_mem;
switch (kv.first) {
case MKLDNN_ARG_WEIGHTS_LAYER: {
if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) {
if (CheckMemDescEquality(bwd_.weights_layer_desc_, fwd.fwd_trn_.GetLayerDesc())) {
this->weights_layer_->set_data_handle(kv.second.get_data_handle());
} else {
MKLDNNMemoryReorder(*fwd.weights_layer_, *this->weights_layer_);
}
valid_mem = this->weights_layer_.get();
} break;
case MKLDNN_ARG_WEIGHTS_ITER: {
if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetIterDesc()) {
if (CheckMemDescEquality(bwd_.weights_iter_desc_, fwd.fwd_trn_.GetIterDesc())) {
this->weights_iter_->set_data_handle(kv.second.get_data_handle());
} else {
MKLDNNMemoryReorder(*fwd.weights_iter_, *this->weights_iter_);
Expand Down Expand Up @@ -762,14 +771,14 @@ void MKLDNNRnnBackward::SetWeightsGradsMem() {
this->diff_weights_iter_r_ = std::make_shared<mkldnn::memory>(
native_iter_desc, cpu_engine);

if (native_layer_desc == bwd_.diff_weights_layer_desc_) {
if (CheckMemDescEquality(native_layer_desc, bwd_.diff_weights_layer_desc_)) {
this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
bwd_.diff_weights_layer_desc_, cpu_engine, diff_weights_layer_r_->get_data_handle());
} else {
this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
bwd_.diff_weights_layer_desc_, cpu_engine);
}
if (native_iter_desc == bwd_.diff_weights_iter_desc_) {
if (CheckMemDescEquality(native_iter_desc, bwd_.diff_weights_iter_desc_)) {
this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
bwd_.diff_weights_iter_desc_, cpu_engine, diff_weights_iter_r_->get_data_handle());
} else {
Expand Down Expand Up @@ -821,10 +830,12 @@ void MKLDNNRnnBackward::SetDataGradsMem(
}

void MKLDNNRnnBackward::SetNativeWeightsGrads() const {
if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc()) {
if (!CheckMemDescEquality(this->diff_weights_layer_->get_desc(),
this->diff_weights_layer_r_->get_desc())) {
MKLDNNMemoryReorder(*this->diff_weights_layer_, *this->diff_weights_layer_r_);
}
if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc()) {
if (!CheckMemDescEquality(this->diff_weights_iter_->get_desc(),
this->diff_weights_iter_r_->get_desc())) {
MKLDNNMemoryReorder(*this->diff_weights_iter_, *this->diff_weights_iter_r_);
}
}
Expand All @@ -843,9 +854,11 @@ void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,

void* diff_weights_layer_ptr = this->diff_weights_layer_->get_data_handle();
void* diff_weights_iter_ptr = this->diff_weights_iter_->get_data_handle();
if (this->diff_weights_layer_->get_desc() != this->diff_weights_layer_r_->get_desc())
if (!CheckMemDescEquality(this->diff_weights_layer_->get_desc(),
this->diff_weights_layer_r_->get_desc()))
diff_weights_layer_ptr = this->diff_weights_layer_r_->get_data_handle();
if (this->diff_weights_iter_->get_desc() != this->diff_weights_iter_r_->get_desc())
if (!CheckMemDescEquality(this->diff_weights_iter_->get_desc(),
this->diff_weights_iter_r_->get_desc()))
diff_weights_iter_ptr = this->diff_weights_iter_r_->get_data_handle();

const int num_layer = param.num_layer;
Expand Down
31 changes: 31 additions & 0 deletions tests/python/mkl/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../unittest/'))
from common import with_seed
import itertools


def test_mkldnn_model():
Expand Down Expand Up @@ -724,6 +725,36 @@ def check_elemwise_add_training(stype):
for stype in stypes:
check_elemwise_add_training(stype)


@with_seed()
def test_rnn():
SEQ_LENGTH = [2**10, 2**5]
STATE_SIZE = [1, 2]
BATCH_SIZE = [4]
INPUT_SIZE = [4]
def batch_check(seq_length, state_size, batch_size, input_size):
modes_params = [('rnn_relu', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size),)),
('rnn_tanh', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size),)),
('gru', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size*3),))
]
for m, p in modes_params:
data = mx.np.random.normal(0, 1, (seq_length, batch_size, input_size))
state = mx.np.random.normal(0, 1, (1, batch_size, state_size))
data.attach_grad()
state.attach_grad()

with mx.autograd.record():
y = mx.npx.rnn(data=data, parameters=p, mode=m, \
state=state, state_size=state_size, num_layers=1)
assert y.shape == (seq_length, batch_size, state_size)
assert type(y[0]).__name__ == 'ndarray'
y.backward()
assert state.shape == (1, batch_size, state_size)
assert type(state[0]).__name__ == 'ndarray'

for sl, ss, bs, in_s in itertools.product(SEQ_LENGTH, STATE_SIZE, BATCH_SIZE, INPUT_SIZE):
batch_check(sl, ss, bs, in_s)

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 7c86f48

Please sign in to comment.