
graphutils supports torch17 #3076

Merged
merged 100 commits on Nov 23, 2020

Conversation

chicm-ms
Contributor

@chicm-ms chicm-ms commented Nov 9, 2020

Something changed in PyTorch v1.6 which breaks graphutils:

  1. torch.onnx.set_training is removed (see the compatibility sketch after the example below).
  2. In the traced graph, some prim::Constant nodes can be shared as inputs of multiple nodes in torch 1.6.
    For example:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        
        x = self.fc1(x)
        return x

class MyModel2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodel = MyModel()

    def forward(self, x):
        return self.submodel(x)

In the traced graph of MyModel2, we found that conv1 and conv2 share a few prim::Constant input nodes, and bn1 and bn2 also share a few prim::Constant nodes as inputs; this breaks graphutils.
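A minimal sketch of how this sharing can be observed (assuming PyTorch 1.6+ and the MyModel2 class above; the exact set of shared constants may vary across versions):

import torch

model = MyModel2()
traced = torch.jit.trace(model, torch.randn(1, 1, 28, 28))

# count the consumers of each prim::Constant output; under PyTorch 1.6+
# some constants feed several nodes (e.g. both conv or both bn modules)
for node in traced.graph.nodes():
    if node.kind() == 'prim::Constant':
        uses = node.output().uses()
        if len(uses) > 1:
            print(node, '-> used by', len(uses), 'nodes')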
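And for item 1, a hedged compatibility sketch; torch.onnx.utils.select_model_mode_for_export with the TrainingMode enum is assumed to be the PyTorch 1.6+ replacement, and this is illustrative rather than this PR's exact fix:

import torch

model, dummy_input = MyModel2(), torch.randn(1, 1, 28, 28)

if hasattr(torch.onnx, 'set_training'):  # PyTorch < 1.6
    training_ctx = torch.onnx.set_training(model, False)
else:  # PyTorch >= 1.6: set_training was removed
    from torch.onnx.utils import select_model_mode_for_export
    training_ctx = select_model_mode_for_export(
        model, torch.onnx.TrainingMode.EVAL)

with training_ctx:
    traced = torch.jit.trace(model, dummy_input)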

@chicm-ms chicm-ms changed the title from "Graph utils support torch17" to "graphutils supports torch17" on Nov 9, 2020
@liuzhe-lz liuzhe-lz self-requested a review November 10, 2020 07:08
@liuzhe-lz liuzhe-lz mentioned this pull request Nov 10, 2020
else:
    inputs.append(input_name)
if input_name in output_to_node:
    for predecessor_node in output_to_node[input_name]:
Contributor

The element in output_to_node is a list? Does this mean one output can be generated by more than one node?

Contributor Author

An assertion has been added to ensure the list has at most one element.
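A hedged sketch of what such a check could look like (output_to_node and input_name as in the diff above; the message text is illustrative):

predecessor_nodes = output_to_node[input_name]
# each output name should be produced by at most one node
assert len(predecessor_nodes) <= 1, \
    'output %s is generated by multiple nodes' % input_name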

@@ -177,7 +177,6 @@ def channel_prune(model):
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)

@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
Contributor

Could you check if the test cases work with PyTorch 1.6+ in the following scripts?

  • test_graph_utils.py
  • test_compression_utils.py
  • test_pruners.py
  • test_dependecy_aware.py

Contributor Author

@chicm-ms chicm-ms Nov 22, 2020

PyTorch 1.6+ is turned on for:

test_compression_utils.py
test_pruners.py
test_dependecy_aware.py

For test_graph_utils.py, some expected files were generated by Yuge; they may need to be upgraded for PyTorch 1.6+.

@ultmaster, would you please help upgrade the protobuf test cases in test_graph_utils.py for PyTorch 1.6+?
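For illustration, a hedged sketch of such a version gate; it uses a parsed version comparison instead of the raw string compare in the diff above (as plain strings, '1.10.0' < '1.6.0'), and the test name here is hypothetical:

import unittest
from distutils.version import LooseVersion

import torch


class TestGraphUtils(unittest.TestCase):
    # hypothetical placeholder for the protobuf-comparison cases
    @unittest.skipIf(
        LooseVersion(torch.__version__) >= LooseVersion('1.6.0'),
        'expected protobuf files not yet regenerated for PyTorch 1.6+')
    def test_expected_protobuf(self):
        pass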

    node_group.append(predecessor_node)
    node_queue.put(predecessor_node)
else:
    inputs.add(input_name)
Contributor

If predecessor_node is already merged into this node group, then I guess its output should not be taken as an input of this node group?

Contributor Author

A check has been added; please review again.

Contributor

I guess I didn't explain myself clearly, sorry about that.
What I want to say is that maybe we can remove the else part of the following code:

if not self._is_key_func(predecessor_node):
    node_group.append(predecessor_node)
    node_queue.put(predecessor_node)
else:
    inputs.add(input_name)

Because whether input_name should be added to the inputs of this node group is decided only by whether predecessor_node is in nodes? It has nothing to do with whether predecessor_node is a key node/func? I guess? Please correct me if I'm wrong. Thanks~

Contributor Author

Removing the else does not work, as I tested. If the predecessor node is a key func, it will be merged into another node (there is only one key-func node within one merged node), so its output has to be treated as an input of this node group.
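For reference, a hedged sketch of the traversal under discussion, reconstructed from the diff fragments above; the standalone function and the is_key_func parameter are assumptions, the other names follow the PR:

from queue import Queue

def expand_node_group(start_node, output_to_node, is_key_func):
    # hypothetical standalone version of the grouping loop
    node_group, inputs = [start_node], set()
    visited = {start_node}
    node_queue = Queue()
    node_queue.put(start_node)
    while not node_queue.empty():
        curr_node = node_queue.get()
        for input_name in curr_node.inputs():
            if input_name not in output_to_node:
                # produced by no node in the graph: a model-level input
                inputs.add(input_name)
                continue
            predecessor_node = output_to_node[input_name][0]
            if not is_key_func(predecessor_node):
                # non-key predecessors are absorbed into this group;
                # shared ones (e.g. prim::Constant) are added only once
                if predecessor_node not in visited:
                    visited.add(predecessor_node)
                    node_group.append(predecessor_node)
                    node_queue.put(predecessor_node)
            else:
                # a key-func predecessor forms its own merged node, so
                # its output is an external input of this group
                inputs.add(input_name)
    return node_group, inputs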

Contributor

Got it! Thanks~
