Replace usage of copy.deepcopy() - Convolution/Batch Norm Fuser in FX #2645
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2645. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit c8436a4 with merge base f05f050. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
@@ -104,7 +104,9 @@ def fuse_conv_bn_eval(conv, bn):
     module `C` such that C(x) == B(A(x)) in inference mode.
     """
     assert(not (conv.training or bn.training)), "Fusion only for eval!"
-    fused_conv = copy.deepcopy(conv)
+    fused_conv = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)
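A short sketch of why this replacement is risky (the conv configuration here is hypothetical, chosen to show the difference): rebuilding a conv from only `in_channels`, `out_channels`, and `kernel_size` silently resets non-default settings such as `stride`, `padding`, and `bias`, which `copy.deepcopy()` preserves.

```python
import copy
import torch

# A conv with non-default settings (hypothetical example values).
conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)

# deepcopy keeps the full configuration.
via_deepcopy = copy.deepcopy(conv)

# The PR's constructor-based copy only forwards three arguments,
# so stride/padding/bias fall back to their defaults.
rebuilt = type(conv)(conv.in_channels, conv.out_channels, conv.kernel_size)

print(via_deepcopy.stride, rebuilt.stride)  # (2, 2) vs. (1, 1)
print(via_deepcopy.bias is None, rebuilt.bias is None)  # True vs. False
```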
This fix seems weird? The right thing to do feels like implementing a proper `__deepcopy__()` for nn modules? @albanD
This popular thread seems to validate this fix https://discuss.pytorch.org/t/deep-copying-pytorch-modules/13514 but I don't know if this is what we want people to actually do?
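For context, a minimal sketch of the copying pattern discussed in that thread (assuming you can re-create the module with the same constructor arguments; the `Linear` shapes here are hypothetical): construct a fresh instance and load the original's `state_dict`.

```python
import torch

src = torch.nn.Linear(4, 2)
dst = torch.nn.Linear(4, 2)  # must match src's architecture exactly

# Copy all parameters and buffers from src into dst.
dst.load_state_dict(src.state_dict())

x = torch.randn(1, 4)
print(torch.allclose(src(x), dst(x)))  # True: the two modules now agree
```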
We can save and load the model, as found in #2385. Other than this, is there another way I'm missing that would help me make a plausible fix?
The problem is that Module is a complex enough class that deepcopying it is very challenging (the same way we don't recommend you serialize it as-is, but only the state_dict).
deepcopy() works in most simple cases, but it is expected to fail sometimes.
If you only have a regular Conv2d kernel, though, doing a deepcopy or calling a new constructor is pretty much the same thing.
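A sketch of that equivalence claim (assuming a default-configured `Conv2d`; the shapes are hypothetical): for a plain conv layer, `deepcopy` and constructor-plus-`load_state_dict` produce numerically identical modules.

```python
import copy
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3)

# Approach 1: deepcopy the whole module.
via_deepcopy = copy.deepcopy(conv)

# Approach 2: new constructor with the same arguments, then copy parameters.
via_ctor = torch.nn.Conv2d(3, 8, kernel_size=3)
via_ctor.load_state_dict(conv.state_dict())

x = torch.randn(1, 3, 8, 8)
print(torch.allclose(via_deepcopy(x), via_ctor(x)))  # True
```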
How would you want me to proceed with the PR?
I think what @albanD is saying is that in this specific case deepcopy-ing a conv layer is just fine, i.e. the original code probably doesn't need to be changed.
Force-pushed from a0c5b6e to c8436a4.
@@ -150,7 +152,9 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc

 def fuse(model: torch.nn.Module) -> torch.nn.Module:
-    model = copy.deepcopy(model)
+    model, state_dict = type(model)(), model.state_dict()
That's only going to work for models that do not take any parameters to `__init__()`.
@svekars with this + https://github.com/pytorch/tutorials/pull/2645/files#r1391396959, I'm tempted to think that the original issue is probably irrelevant for this tutorial. Even if copy.deepcopy(model) may not be perfect, it's still better than any alternative that has been proposed so far. Perhaps we could close the original issue and still give credit to the contributor for their efforts?
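A minimal sketch of the failure mode noted above (the model class and its argument are hypothetical): `type(model)()` re-invokes `__init__` with no arguments, so it raises `TypeError` for any model whose `__init__` takes required parameters.

```python
import torch

class MyModel(torch.nn.Module):
    def __init__(self, hidden):  # required constructor argument
        super().__init__()
        self.fc = torch.nn.Linear(hidden, 1)

model = MyModel(16)

try:
    clone = type(model)()  # no arguments forwarded -> TypeError
except TypeError as exc:
    print("reconstruction failed:", exc)
```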
Closing this and the issue and will give half credit.
Fixes #2331
Description
Replacing the use of `copy.deepcopy()` in the Convolution/Batch Norm Fuser in FX tutorial with `load_state_dict`, as mentioned in Deep copying PyTorch modules.

cc @eellison @suo @gmagogsfm @jamesr66a @msaroufim @SherlockNoMad @albanD @sekyondaMeta @svekars @carljparker @NicolasHug @kit1980 @subramen