Skip to content

Commit

Permalink
[MXNET-795] Fix a bug that CutSubgraph works only when each subgraph …
Browse files Browse the repository at this point in the history
…has its distinct name (apache#12106)

* Copy only when necessary

* Fix typo

* Add unittest
  • Loading branch information
junrushao authored and anirudh2290 committed Sep 19, 2018
1 parent e4ec86c commit 512e9d4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import absolute_import
import threading
import warnings
from collections import defaultdict

from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass

Expand All @@ -34,6 +35,7 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
The attributes to set for all symbol creations in the scope.
"""
_current = threading.local()
_subgraph_names = defaultdict(int)

def __init__(self, **kwargs):
self._old_scope = None
Expand Down
11 changes: 11 additions & 0 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def _cut_subgraph(subg):
syms.append(s)
return syms

def _get_unique_subgraph_name(subgraph_name):
attrs = AttrScope._current.value._attr
if attrs.get("__subgraph_name__", "") != "":
subgraph_name = "".join([attrs["__subgraph_name__"], "$", subgraph_name])
AttrScope._subgraph_names[subgraph_name] += 1
subgraph_name = subgraph_name + str(AttrScope._subgraph_names[subgraph_name] - 1)
return subgraph_name

# This construct a subgraph for given output nodes.
# If an output node is one of the input nodes, we call identity to make sure
# that outputs nodes are different from input nodes.
Expand Down Expand Up @@ -232,6 +240,7 @@ def check_data(inputs, in_type, msg):
# the python function, we need to prune the computation graph constructed from
# the function. One way of doing it is to mark the nodes in the computation graph
# with AttrScope and prune the nodes without the special attribute.
name = _get_unique_subgraph_name(name)
with AttrScope(__subgraph_name__=name):
if isinstance(data, list):
in_eles = [symbol.var(sym.name) for sym in data]
Expand Down Expand Up @@ -456,6 +465,7 @@ def _func_wrapper(loop_vars):
return list(step_output), list(new_loop_vars)

def _create_subgraph(graph_vars, graph_func, subgraph_name):
subgraph_name = _get_unique_subgraph_name(subgraph_name)
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
Expand Down Expand Up @@ -619,6 +629,7 @@ def _to_symbol_tuple(inputs, name):
return inputs

def _create_subgraph(graph_vars, graph_func, subgraph_name):
subgraph_name = _get_unique_subgraph_name(subgraph_name)
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_contrib_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import mxnet as mx
from mxnet import gluon
from numpy.testing import assert_allclose, assert_array_equal
from collections import defaultdict
from mxnet.test_utils import *
from mxnet.base import _as_list
from mxnet.attribute import AttrScope
from common import with_seed


Expand Down Expand Up @@ -1765,6 +1767,45 @@ def hybrid_forward(self, F, data):
assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3)


def test_scope():
class TestBlock1(gluon.HybridBlock):
def __init__(self, prefix=None, params=None):
super(TestBlock1, self).__init__(prefix=prefix, params=params)
def hybrid_forward(self, F, data):
(new_data, ) = F.contrib.cond(
data > 0.5,
then_func=lambda: data * 2,
else_func=lambda: data * 3,
name="my_cond",
)
return new_data
class TestBlock2(gluon.HybridBlock):
def __init__(self, prefix=None, params=None):
super(TestBlock2, self).__init__(prefix=prefix, params=params)
def hybrid_forward(self, F, data):
(new_data, ) = F.contrib.cond(
data > 0.5,
then_func=lambda: data * 2,
else_func=lambda: data * 3,
name="my_cond",
)
return new_data
AttrScope._subgraph_names = defaultdict(int)
data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
block1 = TestBlock1()
block1.initialize(ctx=default_context())
block1.hybridize()
_ = block1(data)
block2 = TestBlock2()
block2.initialize(ctx=default_context())
block2.hybridize()
_ = block2(data)
assert len(AttrScope._subgraph_names) == 3
assert AttrScope._subgraph_names['my_cond_else'] == 2
assert AttrScope._subgraph_names['my_cond_pred'] == 2
assert AttrScope._subgraph_names['my_cond_then'] == 2


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 512e9d4

Please sign in to comment.