Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 9, 2018
1 parent 031b8f1 commit d5ab088
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/python/unittest/test_contrib_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,40 @@ def test_foreach_rnn():
check_foreach_rnn(cell_type, num_states)


@with_seed()
def test_cut_subgraph_while_loop():
class TestLayer(gluon.HybridBlock):
def __init__(self, prefix=None, params=None):
super(TestLayer, self).__init__(prefix=prefix, params=params)
def hybrid_forward(self, F, data):
out1, data1 = F.contrib.while_loop(
cond=lambda i: i <= 5,
func=lambda i: (None, (i + 1, )),
loop_vars=(data, ),
max_iterations=10,
)
out2, data2 = F.contrib.while_loop(
cond=lambda i: i <= 10,
func=lambda i: (None, (i + 1, )),
loop_vars=data1[0],
max_iterations=10,
)
return data2[0]
data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
layer = TestLayer()
layer.initialize(ctx=default_context())
res1 = layer(data)
with mx.autograd.record():
res1 = layer(data)
layer = TestLayer()
layer.initialize(ctx=default_context())
layer.hybridize()
res2 = layer(data)
with mx.autograd.record():
res2 = layer(data)
assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit d5ab088

Please sign in to comment.