
Keep initialization of H for all-weights and last-layer separate #72

Merged: 4 commits into main, Dec 21, 2021

Conversation

@runame (Collaborator) commented Dec 15, 2021

This is an alternative to #71: it addresses a bug resulting from #62, first reported here (note: the link points to a PR discussion in a private repository). For the last-layer LA flavors, the Hessian approximation H was first initialized for all weights, which leads to out-of-memory errors for larger models.

The advantage of the previous fix is that it keeps the classes for the all-weights and last-layer flavors more strictly separated, which might make it harder to introduce similar bugs in the future. However, it requires many more classes. @aleximmer and I agreed that this additional complexity is probably not worth it.

Changes:

  • Add better tests for the initialization of all the Laplace classes, including with a large model (Wide ResNet 50-2), as suggested in #69 (Add tests with larger model architectures).
  • Fix the H initialization bug. In most cases, posterior_precision now falls back to the prior before fit() is called for the first time. Exceptions: a last-layer flavor that doesn't get the last_layer_name argument passed, and low-rank Laplace. In these two cases H will be None, and accessing posterior_precision raises a descriptive error (see the sketch after this list).
  • Remove redundant code + minor fixes.
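
For illustration, here is a minimal sketch of the behavior described in the second bullet. It assumes the laplace package's Laplace factory function with its usual subset_of_weights/hessian_structure keywords; the toy model and the concrete exception type are placeholders, not taken from this PR's tests.

```python
import torch
from laplace import Laplace

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2)
)

# All-weights flavor: before fit(), posterior_precision falls back to the prior.
la = Laplace(model, 'classification', subset_of_weights='all', hessian_structure='diag')
print(la.posterior_precision)  # prior precision only, since H has not been fitted yet

# Last-layer flavor without last_layer_name: H stays None until fit() determines
# the last layer, so accessing posterior_precision raises a descriptive error.
la_ll = Laplace(model, 'classification', subset_of_weights='last_layer', hessian_structure='kron')
try:
    _ = la_ll.posterior_precision
except Exception as err:  # the concrete exception type is an assumption
    print(err)
```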

@runame added the bug (Something isn't working) label Dec 15, 2021
@runame added this to the NeurIPS Prerelease milestone Dec 15, 2021
@edaxberger (Collaborator) commented:
Looks great, and much simpler than the other solution indeed! I also tested it with WILDS and it works well.

Is it a problem that posterior_precision is not defined for last-layer with no last_layer_name passed and for low-rank? Does this mean that doing continual learning would be more difficult with these flavors (I guess even if so, we wouldn't want people to use those for CL anyway)?

@runame (Collaborator, Author) commented Dec 15, 2021

Thanks for testing it with WILDS!

I don't think the two exceptions are a problem:

  1. As you say, last-layer flavors should most likely not be used for continual learning anyway. It is also still possible to use them, either by passing the last_layer_name argument or with just a few more lines of code in the actual continual learning script (see the sketch after this list).
  2. Low-rank does not even support fitting repeatedly (without overriding H), hence it is not an option for continual learning anyway.
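
A rough sketch of point 1: passing last_layer_name so that the last-layer flavor is fully set up before the first fit(). The last_layer_name keyword, the toy model, and the layer name '2' (the final Linear module of the Sequential below) are illustrative assumptions, not code from this PR.

```python
import torch
from laplace import Laplace

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2)
)

# With last_layer_name given, H can be initialized for the last layer right away,
# so posterior_precision is defined (falling back to the prior) even before the
# first fit() -- which is what a continual-learning loop reading it between tasks needs.
la = Laplace(
    model, 'classification',
    subset_of_weights='last_layer', hessian_structure='kron',
    last_layer_name='2',
)
```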

@edaxberger (Collaborator) commented:

Yes, good points!

@aleximmer (Owner) commented:

We could additionally prohibit trying to do CL with these classes by adjusting the fit method to have no override argument, i.e. always behave as if override=True.

@runame (Collaborator, Author) commented Dec 17, 2021

I think that's also ok. I don't really have a use case in mind where one might want to use override=False with the last-layer flavors, and if that changes, we can easily enable the option again. Alternatively, we could raise an error like we currently do for low-rank Laplace, to avoid confusing the user (in principle there is no reason why the last-layer flavors shouldn't have an override argument).

@runame (Collaborator, Author) commented Dec 18, 2021

Now a descriptive error is raised when override=False for low-rank or last-layer Laplace approximations (roughly as in the sketch below).
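
A self-contained sketch of what such a guard could look like; the class, message, and exception type are assumptions for illustration, not the literal diff.

```python
class LastLayerLaplaceSketch:
    """Illustrative stand-in for a last-layer/low-rank Laplace flavor."""

    def __init__(self):
        self.H = None  # Hessian approximation, (re-)initialized on fit()

    def fit(self, train_loader, override=True):
        if not override:
            raise ValueError(
                'Last-layer and low-rank Laplace approximations do not support '
                '`override=False`; H is always re-initialized when calling fit().'
            )
        # ... accumulate curvature from train_loader into self.H here ...
        self.H = 0.0
```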

@edaxberger (Collaborator) commented:

Great, I agree that a descriptive error is more useful/clear than just not offering the option at all (and we might still add the feature at some point if we think it's useful at all). Happy to merge this in.

@runame (Collaborator, Author) commented Dec 20, 2021

I think @aleximmer wanted to take a closer look today. After that we can merge it.

@aleximmer (Owner) commented:

lgtm

@runame merged commit 7e42de8 into main Dec 21, 2021
@runame deleted the fix-H-init-alt branch December 21, 2021 09:27
@edaxberger mentioned this pull request Dec 21, 2021