
[WIP][MXNET-107] Fused LSTM implementation for CPU #10104

Merged · 39 commits · May 14, 2018
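For context, here is a minimal sketch of how the fused LSTM path added by this PR is exercised through the mx.sym.RNN symbol on CPU. It mirrors the calls used in the new unit test further down in this diff; the variable names and tensor sizes are illustrative only.

import mxnet as mx

# Symbolic inputs to the fused RNN operator (mode='lstm').
X = mx.sym.Variable('x')            # data: (seq_len, batch_size, input_size)
Params = mx.sym.Variable('params')  # flattened LSTM weights and biases
HX = mx.sym.Variable('state')       # initial hidden state
CX = mx.sym.Variable('state_cell')  # initial cell state

T, N, I, H = 5, 32, 800, 800        # illustrative sizes, matching the tests below
rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX,
                 state_size=H, num_layers=1, mode='lstm',
                 state_outputs=True, name='LSTM')

# simple_bind infers the remaining shapes from x; ctx=mx.cpu() selects the
# fused CPU implementation introduced in this PR.
exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
out = exe.forward(is_train=False)
out[0].wait_to_read()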
Commits (39)
fd24ed2
register RNN fused-API with nnvm, finish single-layer && undirection …
Mar 8, 2018
ba0fe6d
fix coding style and lint complains
TaoLv Mar 8, 2018
a3c34ab
add single-layer && unidirectional LSTM backward function
Mar 8, 2018
b5c1ef7
make interface universal for other RNN mode
Mar 9, 2018
73ed6dd
share intermediate result between forward and backward in a trick way
Mar 9, 2018
d72fe17
add comments for important parameters
Mar 12, 2018
d6811b5
modify testcase
Mar 14, 2018
d0306e5
Fix coding style and error message
TaoLv Mar 14, 2018
c2e7c8f
fix openmp collapse error
Mar 15, 2018
154aa3b
fix const
Mar 15, 2018
7c0cc29
remove rnn.cu and skip related testcases temporarily for building on GPU
Mar 15, 2018
b59f009
support multi-layer and bidirectional for lstm inference
Mar 17, 2018
26d32d2
remove some testcases in test_gluon_rnn.py to build on GPU
Mar 18, 2018
1b89cff
remove testcase between fp32 and fp64 temporarily
Mar 22, 2018
afd831d
retrigger ci
TaoLv Mar 22, 2018
ce818d3
fix some logs
Mar 26, 2018
f24ee4b
use a better way to share memory
Mar 26, 2018
d51dafd
fix cudnn registration
Mar 26, 2018
cdaadf7
fix invariant calculations and enable some gpu testcases
Mar 26, 2018
4161f3b
add thread local cache for cudnn rnn op
TaoLv Mar 26, 2018
f3dcb07
add thread local cache for rnn op
Mar 28, 2018
09f6e9a
fix bugs
Mar 28, 2018
c28bbc8
remove some testcases to check segmentfault
Mar 29, 2018
3370cb4
remove cudnn registration to check segmentfault
Mar 29, 2018
46af847
support multi-layer for LSTM Training
Mar 30, 2018
e42e7f9
modify lstm testcase
Apr 2, 2018
e5b8b51
add bidirectional support for lstm
Apr 3, 2018
8a67315
fix gluon and coding style
Apr 4, 2018
78edb41
fix bugs
Apr 4, 2018
f50f5c0
remove nnvm registration
Apr 8, 2018
35a4a4b
enable gpu testcases
Apr 9, 2018
19ef217
add detailed descriptions
Apr 9, 2018
b0cfcf8
add dropout check
Apr 10, 2018
b6b567e
fix workspace size
Apr 27, 2018
1471836
Merge remote-tracking branch 'upstream/master' into lstm
TaoLv May 8, 2018
a52b5ef
dropout is not supported, add unit test for it
TaoLv May 9, 2018
a60de72
Merge remote-tracking branch 'upstream/master' into lstm
TaoLv May 9, 2018
3c61b84
fix review comments
TaoLv May 12, 2018
aeb8e9d
Merge remote-tracking branch 'upstream/master' into lstm
TaoLv May 12, 2018
24 changes: 21 additions & 3 deletions tests/python/unittest/test_operator.py
@@ -52,7 +52,7 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H):
mod1.forward(batch, is_train=False)
mod2.forward(batch, is_train=False)
assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)

# check training
mod1.forward(batch, is_train=True)
mod2.forward(batch, is_train=True)
@@ -63,7 +63,7 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H):
mod2.backward(out_grads=[dy])
assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)

@with_seed(0)
@with_seed()
def test_lstm():
T, N, I, H = 5, 32, 800, 800
fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
@@ -73,7 +73,7 @@ def test_lstm():
stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
check_rnn_consistency(fused, stack, T, N, I, H)

@with_seed(0)
@with_seed()
def test_lstm_bidirectional():
T, N, I, H = 5, 20, 800, 800
fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
@@ -91,6 +91,24 @@ def test_lstm_bidirectional():

check_rnn_consistency(stack, fused, T, N, I, H)

# Currently, fused LSTM operator doesn't support dropout.
# Will change this test after dropout is supported
@with_seed()
def test_lstm_dropout():
X = mx.sym.Variable('x')
Params = mx.sym.Variable('params')
HX = mx.sym.Variable('state')
CX = mx.sym.Variable('state_cell')
T, N, I, H = 300, 20, 800, 800
rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX,
state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM')
exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
try:
out = exe.forward(is_train=False)
out[0].wait_to_read()
assert False # should not reach here
except mx.base.MXNetError as err:
Review comment (Contributor):
Excellent approach! This will ensure we don't forget to re-enable the test when we introduce dropout. Great job.

Review comment (TaoLv, Member, May 9, 2018):
Yes. Also to ensure the failure happens at the proper position and the correct error message is presented, following @reminisce's idea in PR 10844.

assert str(err).find('Dropout is not supported at the moment') != -1

def np_softmax(x, axis=-1):
# fix for old numpy on Travis not supporting keepdims