Get/set parameters and review of saving and loading #138
Conversation
Main purpose is to avoid saving things that cannot or should not be pickled (e.g. the environment, the replay buffer, pytorch variables, ...).
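For illustration, a minimal sketch of that idea, with a method name in line with the PR's naming (`SaveMixin` and `_get_save_data` are hypothetical, not the library's actual API):

```python
class SaveMixin:
    def _excluded_save_params(self):
        # Attributes that cannot or should not be pickled with the model.
        return ["env", "replay_buffer", "device"]

    def _get_save_data(self):
        # Hypothetical helper: copy the instance dict and drop excluded
        # entries, so only picklable state is written to disk.
        data = self.__dict__.copy()
        for name in self._excluded_save_params():
            data.pop(name, None)
        return data
```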
Would be nice to reduce the number of variables, but not sure if that is possible. For info, it's mainly @Artemis-Skade who worked on the saving/loading part. We also had some trouble with saving on GPU and loading on CPU...
Yes, anything that can be saved with
Yep, not sure about that, you choose ;)
agree ;)
I did not know that... but that would mean breaking changes for pytorch users of versions 1.4.x and 1.5.x
For get/set parameters, I think we maybe don't need them as we can access
```python
if tensors is not None:
    for name in tensors:
        recursive_setattr(model, name, tensors[name])
# py other pytorch variables back in place
```
typo?
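For context, `recursive_getattr`/`recursive_setattr` resolve dotted attribute paths such as `"policy.optimizer"`. A minimal sketch of what such helpers can look like (illustrative, not necessarily the library's exact implementation):

```python
import functools

def recursive_getattr(obj, name):
    # Resolve "a.b.c" as getattr(getattr(getattr(obj, "a"), "b"), "c").
    return functools.reduce(getattr, name.split("."), obj)

def recursive_setattr(obj, name, value):
    # Set the last attribute in the path on the object resolved
    # from the path prefix.
    prefix, _, attr = name.rpartition(".")
    setattr(recursive_getattr(obj, prefix) if prefix else obj, attr, value)
```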
Same thoughts occurred to me: `state_dict` is already very convenient to use. Two things worry me though: Should we allow loading parameters from numpy arrays directly (...)
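If numpy input were allowed, the conversion itself is simple; a hedged sketch (the helper name is an assumption, not part of the PR):

```python
import torch as th

def numpy_params_to_state_dict(np_params):
    # Convert a {name: ndarray} mapping into tensors that
    # nn.Module.load_state_dict accepts.
    return {name: th.as_tensor(arr) for name, arr in np_params.items()}

# Usage: module.load_state_dict(numpy_params_to_state_dict(arrays))
```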
Doing
Hmm,
I will be that guy and argue it would be equally easy to include this in
I think the get/set_params should follow the state dicts of objects specified by
Fair enough ;)
Oh, true. But as for now, this is only valid for SAC and it corresponds to a very special variable (entropy temperature).
Hmm, how about the other nn.Modules included in this list? In the above SAC example it has policy, actor, and critic, all with potentially different parameters. OnPolicyAlgorithms have policy and its optimizer (granted, including all parameters of policy would likely include the optimizer too): stable-baselines3/stable_baselines3/common/on_policy_algorithm.py, lines 243 to 247 at f5104a5
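Paraphrasing the referenced snippet, such a method simply lists the objects whose state dicts should be saved (the exact code at that commit may differ):

```python
def _get_torch_save_params(self):
    # Names of attributes whose state_dict is saved and restored.
    state_dicts = ["policy", "policy.optimizer"]
    # On-policy algorithms have no extra pytorch variables to store.
    return state_dicts, []
```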
Added the (long delayed...) functions. If it looks ok, I will check over docs (something seems to be failing there) and update where necessary.
```diff
-for name in params:
-    attr = recursive_getattr(model, name)
-    attr.load_state_dict(params[name])
+model.set_parameters(params, exact_match=True, device=device)
```
Why is `exact_match` hardcoded?
Using `exact_match=False` would mean some parameters were missing from the saved model file, which should not happen unless someone modifies the file. Spelling out the hardcoded parameter like this signals that we want to make sure every parameter is updated exactly as it was saved, and that nothing is missing.
ok, fair enough.
What can also happen is that parameters are renamed between versions (it happened to me after refactoring the continuous critic; the names of the parameters were not the same).
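A hedged sketch of the semantics under discussion (names are illustrative): `exact_match=True` demands that the loaded names cover the model's expected objects exactly, while `exact_match=False` tolerates missing entries, e.g. after a rename between versions:

```python
def check_params_match(given, expected, exact_match=True):
    # given/expected: iterables of parameter-object names.
    missing = set(expected) - set(given)
    unknown = set(given) - set(expected)
    if unknown:
        # Names the model does not know can never be assigned.
        raise ValueError(f"Unexpected parameters: {sorted(unknown)}")
    if exact_match and missing:
        raise ValueError(f"Missing parameters: {sorted(missing)}")
```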
Minor changes required, otherwise LGTM ;) Maybe one thing to add to the saved model: the SB3 version (could be checked later if needed, but it is good to have anyway).
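Stamping the version could look roughly like this (`data` here stands in for whatever dict feeds the zip archive, which is an assumption):

```python
import stable_baselines3

# Record the library version alongside the model so that
# incompatibilities can at least be detected at load time.
data["sb3_version"] = stable_baselines3.__version__
```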
Is it ready to review now?
""" | ||
if seed is None: | ||
return | ||
set_random_seed(seed, using_cuda=self.device == th.device("cuda")) |
Since #154 I realized that we should check `self.device.type`; should we fix it here or in a separate PR?
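The fix amounts to comparing the device type instead of the full device object, so that e.g. `th.device("cuda:0")` is also recognized; a sketch:

```python
# Before: an exact device comparison misses indexed devices like "cuda:0".
# set_random_seed(seed, using_cuda=self.device == th.device("cuda"))

# After: compare only the device type.
set_random_seed(seed, using_cuda=self.device.type == "cuda")
```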
Sounds like something that could be included here, although would you be able to add it if it is a quick thing? I am reading through the PR but only slowly digesting how it works.
Also any ideas what could be causing the flake8 linting error exactly? My flake8 is not catching it :/ Fixed.
LGTM =) (apart from minor remark)
So at the end,
Yeap, with the reasoning that these are rarely used, although if we keep using ... Also, the flake8 thing pointed out an error I still need to fix, so hold the brakes on merging :)
@Miffyli a bit late... this introduces a breaking change for previously saved policies, unless
Closes #116
Closes #70
Review over saving and loading of models, as well as (possibly) implementing `get_parameters` and `set_parameters` akin to stable-baselines2.

Changelog
- Renamed `get_torch_variables` -> `_get_torch_save_params`, and include docstring only in the original implementation.
- Renamed `excluded_save_params` -> `_excluded_save_params`, and include docstring only in the original implementation.
- Renamed `tensors` to `pytorch_variables` for clarity.
- Simplified `save_to_zip_file` by combining duplicate code.
- Added `get/set_parameters`, which use `_get_torch_save_params` to gather/set parameters of different objects.
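A usage sketch of the new functions (the exact final signatures may differ slightly from this illustration):

```python
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")
# Dict of state_dicts, keyed by the names from _get_torch_save_params.
params = model.get_parameters()
# Restore them, requiring that every expected object is present.
model.set_parameters(params, exact_match=True)
```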
Motivation and Context

Types of changes
Suggestions and TODOs

- `get_torch_variables`
- `get/set_parameters`.

Things to think about

- `np.savez` as with TF?

Checklist:

- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass. (required)