[blocked by #1756] Add decorator to auto-move data for inference #1526
Conversation
Hello @HenryJia! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-05-17 12:52:06 UTC
LGTM 🚀
Hi, thanks for looking into this.
My main concern is that the LightningModule no longer behaves like an nn.Module, since you overwrite __call__. I think the idea is that a LightningModule can be used just like an nn.Module outside of PL, and this PR kinda breaks that. At the minimum, I would show a warning that the transfer is done automatically, because otherwise the user may wonder why data transfers (CPU to GPU, for example) are slow.
Also, the assertion that the parameters must be on the same device is not correct.
I'm not completely overwriting it; I am using super to call nn.Module's __call__.
On second thought, yes, that makes sense, I'll get rid of that. I'll have to rethink how I detect the device that the model is on then.
@awaelchli I've now moved the data transfer code to pytorch_lightning.utilities to remove the issue of code duplication, and added warnings.
also missing changelog
@Borda done!
This pull request is now in conflict... :(
Codecov Report
@@           Coverage Diff            @@
##           master   #1526    +/-   ##
========================================
+ Coverage      88%     89%      +1%
========================================
  Files          69      70       +1
  Lines        4316    3833     -483
========================================
- Hits         3805    3415     -390
+ Misses        511     418      -93
This pull request is now in conflict... :(
@Borda I think I've dealt with all the issues you pointed out, would you mind reviewing again?
if callable(getattr(batch, 'to', None)):
    if warn_on_transfer:
        rank_zero_warn('Auto transferred data to device {}'.format(device))
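For context, here is a minimal sketch of what such a recursive transfer utility could look like. The name transfer_data_to_device and the warn_on_transfer flag follow the snippets in this thread, but the body below is illustrative rather than the PR's actual code, and the import location of rank_zero_warn may differ between versions.

from collections.abc import Mapping, Sequence

import torch

from pytorch_lightning.utilities import rank_zero_warn  # import path may vary by version


def transfer_data_to_device(batch, device_type, device_idx=None, warn_on_transfer=False):
    """Recursively move anything exposing a .to() method onto the given device."""
    device = torch.device(device_type) if device_idx is None else torch.device(device_type, device_idx)
    if callable(getattr(batch, 'to', None)):
        if warn_on_transfer:
            rank_zero_warn('Auto transferred data to device {}'.format(device))
        return batch.to(device)
    if isinstance(batch, Mapping):
        return {k: transfer_data_to_device(v, device_type, device_idx, warn_on_transfer) for k, v in batch.items()}
    if isinstance(batch, Sequence) and not isinstance(batch, str):
        return type(batch)(transfer_data_to_device(v, device_type, device_idx, warn_on_transfer) for v in batch)
    return batch  # anything else (ints, strings, ...) is returned untouched

Note that the warning sits inside the recursion, so a dict or list of tensors would trigger it once per leaf, which is exactly the point raised in the next comment.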
I would move the warning out to the __call__, because 1. this utility function is more general (it is used in other parts) and 2. this function is recursive, so if a dict of tensors is passed in, the warning would be shown multiple times.
@awaelchli I thought about this; the problem with that is that I effectively run into the same code duplication problem by trying to detect which device the data is on in __call__, since I'd need to recurse into whatever format the data is in again in almost exactly the same way.
Also, I believe rank_zero_warn will only warn once anyway, so that is not an issue.
output = model(x)  # Lightning will automove data here and warn you of it
"""
devices = [p.device for p in self.parameters()]
I'm not sure if always looping over all the params is a good idea. Can we maybe cache the devices somehow?
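One way this could be done, sketched purely as an illustration (not code from this PR): cache the device list once and invalidate it whenever the module is moved. Since .to(), .cuda() and .cpu() all funnel through nn.Module._apply, overriding _apply gives a convenient invalidation hook.

from torch import nn


class DeviceCachingMixin(nn.Module):
    """Hypothetical mixin that caches the devices its parameters live on."""

    def __init__(self):
        super().__init__()
        self._devices_cache = None

    def _apply(self, fn, *args, **kwargs):
        # .to()/.cuda()/.cpu() all route through _apply, so any move invalidates the cache
        self._devices_cache = None
        return super()._apply(fn, *args, **kwargs)

    @property
    def devices(self):
        if self._devices_cache is None:
            self._devices_cache = [p.device for p in self.parameters()]
        return self._devices_cache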
@mcarilli are we missing a simpler/cleaner way of doing this?
x = x.cpu()
model.cuda()
# this works in lightning
out = model(x)
# out is cuda tensor now
Maybe I don't understand what you're trying to do here, but it looks like you're only using devices[0], so why collect them all? Also, if the model params do reside on multiple devices, it's hard to predict which device the user actually wants the input data to reside on.
So, I only apply automatic data transfer in the simple case where the model resides on one device, as trying to auto-transfer data when the model is spread across multiple devices is very non-trivial and heavily dependent on model structure.
device = devices[0]
data = transfer_data_to_device(data, device.type, device.index, warn_on_transfer=True)
kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
return super(LightningModule, self).__call__(*data, **kwargs)
@Borda and I discussed this and we both agree that we shouldn't do this in the Module (at least not by default). In our opinion we should always be able to use a LightningModule as an nn.Module.
What I propose is the following:
We change this part to a decorator that can be added to forward and is attached automatically by the trainer (sorry for this coming so late).
In that case you could make the decorator a class that eventually also caches the devices :)
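For illustration, here is a rough sketch of what such a decorator might look like. The name auto_move_data is only a placeholder here, and it reuses the transfer_data_to_device utility sketched earlier in this thread; the eventual implementation may differ.

import functools


def auto_move_data(forward_fn):
    """Hypothetical decorator: move incoming data onto the module's device before calling forward."""

    @functools.wraps(forward_fn)
    def wrapper(self, *args, **kwargs):
        devices = list({p.device for p in self.parameters()})
        # only handle the simple case of a single-device model, as discussed above
        if len(devices) == 1:
            device = devices[0]
            args = transfer_data_to_device(args, device.type, device.index, warn_on_transfer=True)
            kwargs = transfer_data_to_device(kwargs, device.type, device.index, warn_on_transfer=True)
        return forward_fn(self, *args, **kwargs)

    return wrapper


# Applied at class-definition time it behaves like a normal method decorator,
# and the Trainer could attach it dynamically in the same way:
#
#     class LitModel(LightningModule):
#         @auto_move_data
#         def forward(self, x):
#             ...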
Love the idea of adding the decorator dynamically there!
This pull request is now in conflict... :(
@HenryJia maybe this should be a broader effort to do distributed inference? But we probably need a LightningModule method. Here's a brainstorm on how we might be able to solve distributed inference:
model = LightningModule.load_from_checkpoint(...)
model.init_distributed(backend='', gpus=2, etc...)
model(x)
Makes sense, I'll look at this again with fresh ideas at a later point; I'm a bit busy with other things right now.
This pull request is now in conflict... :(
@HenryJia this is great! Let's go with the suggestion by @justusschock and @Borda to build a decorator instead!
Sounds good, I'll get back on this in a couple of weeks' time when all my university exams are over.
Let us know if some of us should pick up your work and continue :)
I am afraid that letting it sleep and finishing it in a couple of weeks would be a bit difficult given the continuous development...
@williamFalcon This PR is continued over here: #1905
Sorry for the late reply. |
We can close it or you can merge the changes in my branch into yours if you'd like to make adjustments. Either one is fine with me. Note that my PR is rebased onto #1756, not master, but it also contains your commits. I think one thing we need to figure out is how we will best apply the decorator.
Btw, for changing the destination branch there is no need to close the PR; we can just change the destination branch =)
I read through your branch. It all looks good to me and better than where I left this. So I'll close this PR for now (unless anyone has any objections) |
Before submitting
What does this PR do?
Attempt to implement, fixes #1412
Currently only works for GPUs and not TPUs
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃