Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Successive unpack (#2768)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin authored Aug 12, 2020
1 parent e7fccfb commit 5d2a59f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 9 deletions.
30 changes: 22 additions & 8 deletions src/sdk/pynni/nni/_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,15 @@ def _is_key_func(self, node_cpp):
return True
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
# We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it
# Unpack func into other nodes, else it
# may lead to a graph construction error.
# The reason why we donnot take the construct node
# also as a key node is that `cat` operation node need
# the last(previous) visited node to infer the mask. If
# we take the Construct node as the important node, the
# predecessor of the `cat` node will always be a construct
# node, which means we cannot infer the mask for the cat
# operation.
return True
return False

Expand All @@ -556,9 +563,13 @@ def unpack_manually(self):
_logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
_logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
_debug_input = _input.debugName()
_debug_output = _output.debugName()
errmsg = '%s Input number: %d if inconsistent with the output number %d' % (unpack_cpp, \
len(node.inputs), len(list(last_cpp.inputs())))

assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
Expand All @@ -570,10 +581,13 @@ def unpack_manually(self):
self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
# output_to_node[_debug_output] is a NodePyGroup, because one output
# tensor only can be generated by one node.
self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
# just remove the _debug_output from the grapgh index. So that we can also skip
# the construct and tuple
if _debug_output in self.input_to_node:
for following_node in self.input_to_node[_debug_output]:
_tmp_index = following_node.inputs.index(_debug_output)
following_node.inputs[_tmp_index] = _debug_input


self.unpacked = True

Expand Down
98 changes: 97 additions & 1 deletion src/sdk/pynni/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest
from unittest import TestCase, main

from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph
from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph, TUPLE_UNPACK_KIND

class BackboneModel1(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -194,5 +194,101 @@ def forward(self, x):
assert len(nodes) == 1
node = nodes[0]

@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_module_unpack(self):
"""
test the tuple/list unpack function of TorchModuleGraph.
Following models are from the issue 2756
https://github.com/microsoft/nni/issues/2756.
MyModule will have two successive tuple unpack operations
between the B and C.
"""
class CBR(nn.Module):
def __init__(self, i, o):
super(CBR, self).__init__()
self.conv1 = nn.Conv2d(i, o, kernel_size=1)
self.bn1 = nn.BatchNorm2d(o)
self.act1 = nn.ReLU()

def forward(self, x):
return self.act1(self.bn1(self.conv1(x)))


class A(nn.Module):
def __init__(self):
super(A, self).__init__()
self.conv1 = CBR(3, 6, )
self.conv2 = CBR(6, 8, )
self.conv3 = CBR(6, 12)

def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x1)
return (x2, x3)


class B1(nn.Module):
def __init__(self):
super(B1, self).__init__()
self.conv1 = CBR(12, 32)
self.conv2 = CBR(32, 32)
self.conv3 = CBR(32, 32)

def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
return (x1, x2, x3)

class B(nn.Module):
def __init__(self):
super(B, self).__init__()
self.b = B1()

def forward(self, x):
return self.b(x[-1])

class C(nn.Module):
def __init__(self):
super(C, self).__init__()
self.conv1 = CBR(8, 32)
self.conv2 = CBR(12, 32)
self.conv3 = CBR(32, 32)
self.conv4 = CBR(32, 32)
self.conv5 = CBR(32, 32)

def forward(self, x):
return(self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),self.conv4(x[3]),self.conv5(x[4]))

class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.a = A()
self.b = B()
# self.dummy = Dummy()
self.c = C()

def forward(self, x):
x_a = self.a(x)
x_b = self.b(x_a)
xc = self.c(x_a + x_b)
return xc

dummy_input = torch.rand(1, 3, 28, 28)
model = MyModule()
graph = TorchModuleGraph(model, dummy_input)
graph.unpack_manually()
for node in graph.nodes_py.nodes_op:
# The input of the function nodes should
# not come from the TupleUnpack node, because
# all the TupleUnpack nodes have been removed(unpacked)
# manually
for _input in node.inputs:
if _input in graph.output_to_node:
preprocessor = graph.output_to_node[_input]
assert preprocessor.op_type != TUPLE_UNPACK_KIND


if __name__ == '__main__':
main()

0 comments on commit 5d2a59f

Please sign in to comment.