From 512e9d41b048fa217111f34b348498a85a7b2021 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 22 Aug 2018 09:48:33 +0800 Subject: [PATCH] [MXNET-795] Fix a bug that CutSubgraph works only when each subgraph has its distinct name (#12106) * Copy only when necessary * Fix typo * Add unittest --- python/mxnet/attribute.py | 2 + python/mxnet/symbol/contrib.py | 11 +++++ .../unittest/test_contrib_control_flow.py | 41 +++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 17044ddaef06..1a7bd44c01d0 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -20,6 +20,7 @@ from __future__ import absolute_import import threading import warnings +from collections import defaultdict from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass @@ -34,6 +35,7 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)): The attributes to set for all symbol creations in the scope. """ _current = threading.local() + _subgraph_names = defaultdict(int) def __init__(self, **kwargs): self._old_scope = None diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 38195bd62ffa..f40a372fdbcd 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -124,6 +124,14 @@ def _cut_subgraph(subg): syms.append(s) return syms +def _get_unique_subgraph_name(subgraph_name): + attrs = AttrScope._current.value._attr + if attrs.get("__subgraph_name__", "") != "": + subgraph_name = "".join([attrs["__subgraph_name__"], "$", subgraph_name]) + AttrScope._subgraph_names[subgraph_name] += 1 + subgraph_name = subgraph_name + str(AttrScope._subgraph_names[subgraph_name] - 1) + return subgraph_name + # This construct a subgraph for given output nodes. # If an output node is one of the input nodes, we call identity to make sure # that outputs nodes are different from input nodes. @@ -232,6 +240,7 @@ def check_data(inputs, in_type, msg): # the python function, we need to prune the computation graph constructed from # the function. One way of doing it is to mark the nodes in the computation graph # with AttrScope and prune the nodes without the special attribute. + name = _get_unique_subgraph_name(name) with AttrScope(__subgraph_name__=name): if isinstance(data, list): in_eles = [symbol.var(sym.name) for sym in data] @@ -456,6 +465,7 @@ def _func_wrapper(loop_vars): return list(step_output), list(new_loop_vars) def _create_subgraph(graph_vars, graph_func, subgraph_name): + subgraph_name = _get_unique_subgraph_name(subgraph_name) with AttrScope(__subgraph_name__=subgraph_name): # create new variables with the same name, # them feed them to the given func @@ -619,6 +629,7 @@ def _to_symbol_tuple(inputs, name): return inputs def _create_subgraph(graph_vars, graph_func, subgraph_name): + subgraph_name = _get_unique_subgraph_name(subgraph_name) with AttrScope(__subgraph_name__=subgraph_name): # create new variables with the same name, # them feed them to the given func diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 76d0218775b4..54f22a8fd6a7 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -20,8 +20,10 @@ import mxnet as mx from mxnet import gluon from numpy.testing import assert_allclose, assert_array_equal +from collections import defaultdict from mxnet.test_utils import * from mxnet.base import _as_list +from mxnet.attribute import AttrScope from common import with_seed @@ -1765,6 +1767,45 @@ def hybrid_forward(self, F, data): assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=1e-3, atol=1e-3) +def test_scope(): + class TestBlock1(gluon.HybridBlock): + def __init__(self, prefix=None, params=None): + super(TestBlock1, self).__init__(prefix=prefix, params=params) + def hybrid_forward(self, F, data): + (new_data, ) = F.contrib.cond( + data > 0.5, + then_func=lambda: data * 2, + else_func=lambda: data * 3, + name="my_cond", + ) + return new_data + class TestBlock2(gluon.HybridBlock): + def __init__(self, prefix=None, params=None): + super(TestBlock2, self).__init__(prefix=prefix, params=params) + def hybrid_forward(self, F, data): + (new_data, ) = F.contrib.cond( + data > 0.5, + then_func=lambda: data * 2, + else_func=lambda: data * 3, + name="my_cond", + ) + return new_data + AttrScope._subgraph_names = defaultdict(int) + data = mx.nd.normal(loc=0, scale=1, shape=(1, )) + block1 = TestBlock1() + block1.initialize(ctx=default_context()) + block1.hybridize() + _ = block1(data) + block2 = TestBlock2() + block2.initialize(ctx=default_context()) + block2.hybridize() + _ = block2(data) + assert len(AttrScope._subgraph_names) == 3 + assert AttrScope._subgraph_names['my_cond_else'] == 2 + assert AttrScope._subgraph_names['my_cond_pred'] == 2 + assert AttrScope._subgraph_names['my_cond_then'] == 2 + + if __name__ == '__main__': import nose nose.runmodule()