diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py
index 25e513f42e..3fa6cd0eab 100644
--- a/src/sdk/pynni/nni/_graph_utils.py
+++ b/src/sdk/pynni/nni/_graph_utils.py
@@ -530,8 +530,15 @@ def _is_key_func(self, node_cpp):
             return True
         if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
             # We cannot merge the List/Tuple
-            # Construct/Unpack func into other nodes, else it
+            # Unpack func into other nodes, else it
             # may lead to a graph construction error.
+            # The reason why we do not also take the Construct node
+            # as a key node is that the `cat` operation node needs
+            # the last (previously) visited node to infer the mask.
+            # If we took the Construct node as a key node, the
+            # predecessor of the `cat` node would always be a
+            # Construct node, which means we could not infer the
+            # mask for the `cat` operation.
             return True
         return False
 
@@ -556,9 +563,10 @@ def unpack_manually(self):
             _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
             _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
             assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
-            for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
-                _debug_input = _input.debugName()
-                _debug_output = _output.debugName()
+            errmsg = '%s input number %d is inconsistent with the output number %d' % (
+                unpack_cpp, len(node.inputs), len(list(last_cpp.inputs())))
+            assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
+            for _debug_input, _debug_output in zip(node.inputs, node.outputs):
                 if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
                     # input_to_node[_debug_input] is a list of NodePyGroup, because
                     # one tensor can be used as input for multiple nodes at the same time.
@@ -570,8 +578,11 @@
                     self.input_to_node[_debug_input].remove(node)
                     # add the following nodes of _output into the input_to_node[_debug_input]
                     self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
-                if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
-                    # output_to_node[_debug_output] is a NodePyGroup, because one output
-                    # tensor only can be generated by one node.
-                    self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
+                # Just remove _debug_output from the graph index, so that
+                # the Construct and Tuple Unpack nodes are also skipped.
+                if _debug_output in self.input_to_node:
+                    for following_node in self.input_to_node[_debug_output]:
+                        _tmp_index = following_node.inputs.index(_debug_output)
+                        following_node.inputs[_tmp_index] = _debug_input
+
         self.unpacked = True
diff --git a/src/sdk/pynni/tests/test_graph_utils.py b/src/sdk/pynni/tests/test_graph_utils.py
index 92851bc91c..f6181d5482 100644
--- a/src/sdk/pynni/tests/test_graph_utils.py
+++ b/src/sdk/pynni/tests/test_graph_utils.py
@@ -15,7 +15,7 @@
 import unittest
 from unittest import TestCase, main
 
-from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph
+from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph, TUPLE_UNPACK_KIND
 
 class BackboneModel1(nn.Module):
     def __init__(self):
@@ -194,5 +194,98 @@ def forward(self, x):
         assert len(nodes) == 1
         node = nodes[0]
 
+    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
+    def test_module_unpack(self):
+        """
+        Test the tuple/list unpack function of TorchModuleGraph.
+        The following models come from issue 2756
+        (https://github.com/microsoft/nni/issues/2756).
+        MyModule has two successive tuple unpack operations
+        between B and C.
+        """
+        class CBR(nn.Module):
+            def __init__(self, i, o):
+                super(CBR, self).__init__()
+                self.conv1 = nn.Conv2d(i, o, kernel_size=1)
+                self.bn1 = nn.BatchNorm2d(o)
+                self.act1 = nn.ReLU()
+
+            def forward(self, x):
+                return self.act1(self.bn1(self.conv1(x)))
+
+        class A(nn.Module):
+            def __init__(self):
+                super(A, self).__init__()
+                self.conv1 = CBR(3, 6)
+                self.conv2 = CBR(6, 8)
+                self.conv3 = CBR(6, 12)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                x2 = self.conv2(x1)
+                x3 = self.conv3(x1)
+                return (x2, x3)
+
+        class B1(nn.Module):
+            def __init__(self):
+                super(B1, self).__init__()
+                self.conv1 = CBR(12, 32)
+                self.conv2 = CBR(32, 32)
+                self.conv3 = CBR(32, 32)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                x2 = self.conv2(x1)
+                x3 = self.conv3(x2)
+                return (x1, x2, x3)
+
+        class B(nn.Module):
+            def __init__(self):
+                super(B, self).__init__()
+                self.b = B1()
+
+            def forward(self, x):
+                return self.b(x[-1])
+
+        class C(nn.Module):
+            def __init__(self):
+                super(C, self).__init__()
+                self.conv1 = CBR(8, 32)
+                self.conv2 = CBR(12, 32)
+                self.conv3 = CBR(32, 32)
+                self.conv4 = CBR(32, 32)
+                self.conv5 = CBR(32, 32)
+
+            def forward(self, x):
+                return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
+                        self.conv4(x[3]), self.conv5(x[4]))
+
+        class MyModule(nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.a = A()
+                self.b = B()
+                self.c = C()
+
+            def forward(self, x):
+                x_a = self.a(x)
+                x_b = self.b(x_a)
+                x_c = self.c(x_a + x_b)
+                return x_c
+
+        dummy_input = torch.rand(1, 3, 28, 28)
+        model = MyModule()
+        graph = TorchModuleGraph(model, dummy_input)
+        graph.unpack_manually()
+        for node in graph.nodes_py.nodes_op:
+            # The inputs of the function nodes should not come
+            # from a TupleUnpack node, because all the TupleUnpack
+            # nodes have been removed (unpacked) manually.
+            for _input in node.inputs:
+                if _input in graph.output_to_node:
+                    predecessor = graph.output_to_node[_input]
+                    assert predecessor.op_type != TUPLE_UNPACK_KIND
+
+
 if __name__ == '__main__':
     main()
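Reviewer note: to make the new rewiring in unpack_manually easier to follow, here is a minimal, runnable sketch of the same idea in plain Python. Node, construct_inputs, unpack_outputs, and the toy input_to_node index below are simplified stand-ins invented for illustration, not NNI's actual NodePyGroup or graph index; only the rewiring pattern mirrors the hunk above.

    # Toy graph: a Tuple Construct node packs tensors c0/c1; the matching
    # Tuple Unpack node feeds conv3/conv4 through tensors u0/u1.
    # (All names here are hypothetical stand-ins, not NNI classes.)
    class Node:
        def __init__(self, name, inputs, outputs):
            self.name = name
            self.inputs = inputs    # debug names of the input tensors
            self.outputs = outputs  # debug names of the output tensors

    conv3 = Node('conv3', inputs=['u0'], outputs=['y0'])
    conv4 = Node('conv4', inputs=['u1'], outputs=['y1'])
    construct_inputs = ['c0', 'c1']  # inputs of the Construct node
    unpack_outputs = ['u0', 'u1']    # outputs of the Unpack node

    # Index from a tensor's debug name to the list of its consumer nodes.
    input_to_node = {'u0': [conv3], 'u1': [conv4]}

    assert len(construct_inputs) == len(unpack_outputs)
    for _debug_input, _debug_output in zip(construct_inputs, unpack_outputs):
        # Repoint every consumer of the unpack output at the matching
        # construct input, then drop the unpack output from the index,
        # so the Construct/Unpack pair is skipped during traversal.
        consumers = input_to_node.pop(_debug_output, [])
        for following_node in consumers:
            _tmp_index = following_node.inputs.index(_debug_output)
            following_node.inputs[_tmp_index] = _debug_input
        input_to_node.setdefault(_debug_input, []).extend(consumers)

    assert conv3.inputs == ['c0'] and conv4.inputs == ['c1']
    print(input_to_node)  # {'c0': [conv3], 'c1': [conv4]}

Because the consumers are rewired in place instead of remapping output_to_node, the pass also copes with two successive Construct/Unpack pairs, which is the issue 2756 scenario that test_module_unpack exercises.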