Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Successive unpack #2768

Merged
merged 5 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/sdk/pynni/nni/_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,15 @@ def _is_key_func(self, node_cpp):
return True
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
# We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it
# Unpack func into other nodes, else it
# may lead to a graph construction error.
# The reason why we do not also treat the Construct node
# as a key node is that a `cat` operation node needs its
# last (previous) visited node to infer the mask. If the
# Construct node were treated as a key node, the
# predecessor of the `cat` node would always be a Construct
# node, which means we could not infer the mask for the cat
# operation.
return True
return False

Expand All @@ -556,9 +563,13 @@ def unpack_manually(self):
_logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
_logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
_debug_input = _input.debugName()
_debug_output = _output.debugName()
errmsg = '%s Input number: %d if inconsistent with the output number %d' % (unpack_cpp, \
len(node.inputs), len(list(last_cpp.inputs())))

assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
Expand All @@ -570,10 +581,13 @@ def unpack_manually(self):
self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
# output_to_node[_debug_output] is a NodePyGroup, because one output
# tensor only can be generated by one node.
self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
# just remove the _debug_output from the graph index, so that we can also
# skip the Construct and Unpack nodes
if _debug_output in self.input_to_node:
for following_node in self.input_to_node[_debug_output]:
_tmp_index = following_node.inputs.index(_debug_output)
following_node.inputs[_tmp_index] = _debug_input
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much clearer for this logic



self.unpacked = True

Expand Down
98 changes: 97 additions & 1 deletion src/sdk/pynni/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest
from unittest import TestCase, main

from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph
from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph, TUPLE_UNPACK_KIND

class BackboneModel1(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -194,5 +194,101 @@ def forward(self, x):
assert len(nodes) == 1
node = nodes[0]

@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_module_unpack(self):
    """
    Test the tuple/list unpack logic of TorchModuleGraph.

    The models below come from issue 2756
    (https://github.com/microsoft/nni/issues/2756). MyModule ends up
    with two successive tuple-unpack operations between B and C.
    """
    class CBR(nn.Module):
        # Conv -> BatchNorm -> ReLU building block.
        def __init__(self, i, o):
            super(CBR, self).__init__()
            self.conv1 = nn.Conv2d(i, o, kernel_size=1)
            self.bn1 = nn.BatchNorm2d(o)
            self.act1 = nn.ReLU()

        def forward(self, x):
            out = self.conv1(x)
            out = self.bn1(out)
            return self.act1(out)

    class A(nn.Module):
        # One shared stem feeding two parallel branches.
        def __init__(self):
            super(A, self).__init__()
            self.conv1 = CBR(3, 6)
            self.conv2 = CBR(6, 8)
            self.conv3 = CBR(6, 12)

        def forward(self, x):
            stem = self.conv1(x)
            return (self.conv2(stem), self.conv3(stem))

    class B1(nn.Module):
        # Sequential chain that returns every intermediate result.
        def __init__(self):
            super(B1, self).__init__()
            self.conv1 = CBR(12, 32)
            self.conv2 = CBR(32, 32)
            self.conv3 = CBR(32, 32)

        def forward(self, x):
            first = self.conv1(x)
            second = self.conv2(first)
            third = self.conv3(second)
            return (first, second, third)

    class B(nn.Module):
        # Wraps B1; indexing into the input tuple produces an unpack op.
        def __init__(self):
            super(B, self).__init__()
            self.b = B1()

        def forward(self, x):
            return self.b(x[-1])

    class C(nn.Module):
        # Consumes all five tensors of the concatenated tuple.
        def __init__(self):
            super(C, self).__init__()
            self.conv1 = CBR(8, 32)
            self.conv2 = CBR(12, 32)
            self.conv3 = CBR(32, 32)
            self.conv4 = CBR(32, 32)
            self.conv5 = CBR(32, 32)

        def forward(self, x):
            branches = (self.conv1(x[0]), self.conv2(x[1]),
                        self.conv3(x[2]), self.conv4(x[3]),
                        self.conv5(x[4]))
            return branches

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.a = A()
            self.b = B()
            self.c = C()

        def forward(self, x):
            out_a = self.a(x)
            out_b = self.b(out_a)
            # Tuple concatenation: C receives all five tensors.
            return self.c(out_a + out_b)

    model = MyModule()
    graph = TorchModuleGraph(model, torch.rand(1, 3, 28, 28))
    graph.unpack_manually()
    # After manual unpacking, no function node may take its input from a
    # TupleUnpack node, because all such nodes were removed (unpacked).
    for node in graph.nodes_py.nodes_op:
        for _input in node.inputs:
            if _input not in graph.output_to_node:
                continue
            predecessor = graph.output_to_node[_input]
            assert predecessor.op_type != TUPLE_UNPACK_KIND


if __name__ == '__main__':
    # Run the unittest test runner when this file is executed directly.
    main()