Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 9, 2018
1 parent 3b449ef commit 2d7ffe2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import absolute_import
import threading
import warnings
from collections import defaultdict

from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass

Expand All @@ -34,6 +35,7 @@ class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
The attributes to set for all symbol creations in the scope.
"""
_current = threading.local()
_subgraph_names = defaultdict(int)

def __init__(self, **kwargs):
self._old_scope = None
Expand Down
11 changes: 11 additions & 0 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def _cut_subgraph(subg):
syms.append(s)
return syms

def _get_unique_subgraph_name(subgraph_name):
attrs = AttrScope._current.value._attr
if attrs.get("__subgraph_name__", "") != "":
subgraph_name = "".join([attrs["__subgraph_name__"], "$", subgraph_name])
subgraph_name = subgraph_name + str(AttrScope._subgraph_names.get(subgraph_name, 0))
AttrScope._subgraph_names[subgraph_name] += 1
return subgraph_name

# This construct a subgraph for given output nodes.
# If an output node is one of the input nodes, we call identity to make sure
# that outputs nodes are different from input nodes.
Expand Down Expand Up @@ -234,6 +242,7 @@ def check_data(inputs, in_type, msg):
# the python function, we need to prune the computation graph constructed from
# the function. One way of doing it is to mark the nodes in the computation graph
# with AttrScope and prune the nodes without the special attribute.
name = _get_unique_subgraph_name(name)
with AttrScope(__subgraph_name__=name):
if isinstance(data, list):
in_eles = [symbol.var(sym.name) for sym in data]
Expand Down Expand Up @@ -458,6 +467,7 @@ def _func_wrapper(loop_vars):
return list(step_output), list(new_loop_vars)

def _create_subgraph(graph_vars, graph_func, subgraph_name):
subgraph_name = _get_unique_subgraph_name(subgraph_name)
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
Expand Down Expand Up @@ -618,6 +628,7 @@ def _to_symbol_tuple(inputs, name):
return inputs

def _create_subgraph(graph_vars, graph_func, subgraph_name):
subgraph_name = _get_unique_subgraph_name(subgraph_name)
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
Expand Down

0 comments on commit 2d7ffe2

Please sign in to comment.