Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add unittest for foreach
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 10, 2018
1 parent 09af6d9 commit 54943d9
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/python/unittest/test_contrib_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,41 @@ def test_foreach_rnn():
check_foreach_rnn(cell_type, num_states)


@with_seed()
def test_cut_subgraph_foreach():
class TestLayer(gluon.HybridBlock):
def __init__(self, prefix=None, params=None):
super(TestLayer, self).__init__(prefix=prefix, params=params)

def hybrid_forward(self, F, inputs, states):
def step1(data, states):
return data + 1, states
out1, states1 = F.contrib.foreach(step1, inputs, states)
out2, states2 = F.contrib.foreach(step1, out1, states)
def step2(data, states):
return data + states[0], states1
out, states = F.contrib.foreach(step2, out2, states)
return out

data = mx.nd.normal(loc=0, scale=1, shape=(5, 10))
states = mx.nd.normal(loc=0, scale=1, shape=(10))
layer = TestLayer()
layer.initialize(ctx=default_context())
res1 = layer(data, [states])

with mx.autograd.record():
res1 = layer(data, [states])

layer = TestLayer()
layer.initialize(ctx=default_context())
layer.hybridize()
res2 = layer(data, [states])

with mx.autograd.record():
res2 = layer(data, [states])
assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)


@with_seed()
def test_cut_subgraph_while_loop():
class TestLayer(gluon.HybridBlock):
Expand Down

0 comments on commit 54943d9

Please sign in to comment.