Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] LRScheduler batch option #626

Merged

Conversation

guybuk
Copy link
Contributor

@guybuk guybuk commented Apr 27, 2020

Closes #610

I could use some assistance with schedulers that update once an epoch:

optim._LRScheduler initializes with self.last_epoch=0. This means that if we're training for 0,1,2,3,4 epochs, the respective values for last_epoch at the end of every epoch (where we update LR) will be 1,2,3,4,5. This requires changes I made in the tests:

assert lr_policy.lr_scheduler_.last_epoch == max_epochs-1 => assert lr_policy.lr_scheduler_.last_epoch == max_epochs

The reason this seems to have worked until now is in fact because of what I think is a bug:

It seems like until now, inside on_epoch_end, epoch=len(net.history)-1=>epoch=0. This means that when self.lr_scheduler_.step(epoch) is called for the first time, the scheduler does not actually take a step because self.last_step==epoch==0, and last_epoch would stay at 0 for one extra epoch, which is why the test above passes (the one with max_epochs - 1).

I changed this code so that epoch=len(history) so the step would make sense, and replaced the order of history.record with lr_scheduler_.step so that from now on history.record will record the LR with which the step was performed rather than the LR with which the next step will be performed (which was this way until now).

Finally, I feel like as far as I can tell, all is in order. Except for ReduceLROnPlateau, which I'm struggling to understand how it works in general, and how the unique code for it inside on_epoch_end works.

Thanks!

@BenjaminBossan
Copy link
Collaborator

BenjaminBossan commented Apr 27, 2020

Thanks again for working on this @guybuk. I'll take a closer look when I have more time.

@thomasjpfan is the expert for LR schedulers, hopefully he has time to answer your questions.

@thomasjpfan thomasjpfan self-requested a review April 27, 2020 19:14
Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting the first epoch to -1 was intention because the API before pytorch 1.1 had it initialized to -1 and was changed in 1.1 with: pytorch/pytorch#7889. It was pretty weird before 1.1 and the change was for the better.

For skorch, since we are support torch > 1.1, this change should be okay.


@pytest.mark.parametrize('policy, kwargs', [
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3}),
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not be backwards compatible with TorchCyclicLR with default arguments. As in TorchCyclicLR would step every epoch if a user does not pass in step_every='batch'.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the implementation is fine and we just need to restore the if TorchCyclicLR and isinstance(self.lr_scheduler_, TorchCyclicLR):... segment?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that step_every is a parameter to LRScheduler, if we keep the isinstance(self.lr_scheduler_, TorchCyclicLR) piece, we will not be strictly following the step_every parameter.

We can have a step_every='auto' option that special cases the CosineAnnealingLR and TorchCyclicLR. Specifically, if the class is CosineAnnealingLR or TorchCyclicLR we use batch steps, otherwise we use epoch step. We will still allow the batch and epoch options for custom lr schdulers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would mean introducing the entire concept of an 'auto' option ,and having to be backwards compatible for it in the future. However all it'll do is special case cyclic (for the record, cosine annealing isn't hard coded to work with batches, only cyclic).

This basically means that if we don't want 'auto' to just be a cover up for the cyclic case and instead be meaningful, we will have to add reasonable recommendations for every scheduler we introduce (which will be the default auto options). That will most likely involve having to read the paper from which it was introduced, or it's documentation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a conundrum. Do you know why some schedulers prefer batch and other epoch? Is it some derived property or more arbitrary?

we will have to add reasonable recommendations for every scheduler we introduce

This is certainly not the way to go. Could we not just use whatever the default is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a conundrum. Do you know why some schedulers prefer batch and other epoch? Is it some derived property or more arbitrary?

Not for certain. Naturally if you wanted your lr to change smoothly you'd do it per batch rather than epoch. I'm just plugging in a scheduler as is written in its documentation.

This is certainly not the way to go. Could we not just use whatever the default is?

A scheduler doesn't have a default. It has a step function which the user can use whenever and wherever they'd like. Setting mode epoch or batch as defaults will break backwards compatibility for anything that isn't intended to use the default mode we choose.

Since the Cyclic scheduler is a special case for which skorch hacked a solution, I suggested leaving that solution as is (so that it won't conflict with this feature or have compatibility issues) until skorch makes an update that breaks compatibility, and just get rid of that hack then.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good points. Maybe it would be worth it to actually break backwards compatibility here for the sake of getting a clean solution. In fact, we could even detect when such a breakage happens and issue a warning to the user, with instructions how to revert to the old behavior. Would that be a compromise?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would that be a compromise?

I would be happy with a deprecation warning with a suggestion of using step_every='batch' in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -57,6 +57,10 @@ class LRScheduler(Callback):
Pass ``None`` to disable placing events in history.
**Note:** This feature works only for pytorch version >=1.4

step_every: str, (default='epoch'_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
step_every: str, (default='epoch'_
step_every: str, (default='epoch')

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@BenjaminBossan
Copy link
Collaborator

For skorch, since we are support torch > 1.1, this change should be okay.

We actually do support 1.1 to 1.4. However, I was thinking of adding 1.5, maybe we should drop 1.1 then? My impression is that since PyTorch introduces so few breaking changes, most people use the more recent versions, but I have no data to back that up.

@thomasjpfan
Copy link
Member

Sorry I mistyped. I meant to say that we support pytorch >= 1.1 (with the equality), we can depend on the new behavior. So we do not need to drop 1.1.

@guybuk
Copy link
Contributor Author

guybuk commented May 18, 2020

Looked at the ReduceLROnPlateau tests, and shifted the if statements in on_epoch_end around so the tests pass. I probably need someone who knows the feature to make sure it's still ok.

elif self.lr_scheduler_.mode == 'min':
score = np.inf
elif epoch:
score = net.history[-1, self.monitor]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm no expert on lr schedulers, so can't comment on the exact logic that should be applied here. However, what I'm seeing is that the logic might have a "hole", where score wouldn't be defined at all if none of the if or elif matches. In the logic before, score would always be defined.

Could it be that we don't actually need elif epoch? Previously, we accessed net.history[-2, self.monitor], i.e. the second to last row in the history. Therefore, we had to make sure that we're not in the first epoch (at least that's my interpretation of things). Now that we're accessing net.history[-1, self.monitor], we can probably never get an index error here, hence don't need the guard. Would that make sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Pushed a change.

if policy == TorchCyclicLR or policy == "TorchCyclicLR":
warnings.warn(
"The LRScheduler callback makes a step "
"every epoch by default from now on. To have the cyclic lr "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For deprecation, we must keep the current behavior and state that the behavior will change in a Future version:

Starting in 0.10, the LRScheduler will make a step every epoch by default. To have the cyclic lr schedule update every batch set step_every='batch'".

I am thinking of moving this logic into initialize:

def initialize(self):
	self.step_every_ = self.step_every
    if policy == TorchCyclicLR or policy == "TorchCyclicLR" and self.step_every=='epoch':
		self.step_every_ = 'batch'

And then use self.step_every_ in the other methods.

@BenjaminBossan How do you feel about this approach to deprecating the behavior? The simpler approach would be to update in __init__.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify the advantage of having both self.step_every and self.step_every_? It looks to me like having a second variable is unnecessary.

Once we decide whether to place this code in initialize or init, I'll make the changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Traditionally the sckit-learn API does not allow for changing the __init__ parameters during construction. But since this is a callback and not an estimator, we may not need to strictly follow this API contract.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to stick with the sklearn convention here and don't change the __init__ parameters. This means making the change during initialize (which is equivalent to fit in sklearn) and use the trailing underscore. The reason why we need this is because it makes it possible to use set_params on callbacks.

How do you feel about this approach to deprecating the behavior?

When it comes to deprecation: I don't have a strong opinion. From my point of view, it would also be okay to change the behavior to the new behavior, raise a warning and pass a message that explains how to get back to the old behavior. But deprecation is also totally fine with me.

Regardless of how, could you please add a # TODO: ... comment, @guybuk, so that we don't forget to eventually change the behavior and remove the warning?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When it comes to deprecation: I don't have a strong opinion. From my point of view, it would also be okay to change the behavior to the new behavior, raise a warning and pass a message that explains how to get back to the old behavior. But deprecation is also totally fine with me.

In this case, I am +0.5 with moving fast and go with the behavior change.

Let's change the behavior with a FutureWarning, we can remove the warning in 0.10. With the behavior change, we will not need to have a self.step_every_.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Let me know if there's anything else.

Also, thanks for holding my hand throughout the process. I hope it was worth you time 😃

"every epoch by default from now on. To have the cyclic lr "
"scheduler update every batch, "
"set step_every='batch'",
DeprecationWarning,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FutureWarning shows by default to users.

Suggested change
DeprecationWarning,
FutureWarning,

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHANGES.md needs to be updated with this new feature.

Thank you for working on this @guybuk !

@@ -123,6 +115,15 @@ def initialize(self):
self.policy_ = self._get_policy_cls()
self.lr_scheduler_ = None
self.batch_idx_ = 0
# TODO: Remove this warning on 0.10 release
if self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only need to warn if step_every is 'epoch'.

Suggested change
if self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR":
if self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR" and self.step_every == 'epoch':

We should also have a test to make sure the warning message is raised.

kwargs):
with pytest.warns(None) as record:
scheduler = LRScheduler(
policy, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not need to parametrize

Suggested change
policy, **kwargs)
TorchCyclicLR, base_lr=123, max_lr=999)

Comment on lines 119 to 120
if self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR" \
and self.step_every == 'epoch':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
if self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR" \
and self.step_every == 'epoch':
if (self.policy_ == TorchCyclicLR or self.policy_ == "TorchCyclicLR"
and self.step_every == 'epoch'):

classifier_data,
policy,
kwargs):
with pytest.warns(None) as record:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with pytest.warns(None) as record:
msg = "The LRScheduler now makes a step every epoch by default. "
with pytest.warns(FutureWarning, match=msg) as record:

After this change, we can remove the lines that come after assert len(record) == 1

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

CHANGES.md Outdated Show resolved Hide resolved
@thomasjpfan thomasjpfan changed the title [WIP] LRScheduler batch option [MRG] LRScheduler batch option May 26, 2020
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Copy link
Collaborator

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this @guybuk, I believe it's a very cool feature to have. It took a few iterations but I'm sure it was worth it. Also many thanks @thomasjpfan for your review work.

@BenjaminBossan BenjaminBossan merged commit 5ddae2a into skorch-dev:master May 27, 2020
BenjaminBossan added a commit that referenced this pull request Aug 30, 2020
This release of skorch contains a few minor improvements and some nice additions. As always, we fixed a few bugs and improved the documentation. Our [learning rate scheduler](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.LRScheduler) now optionally logs learning rate changes to the history; moreover, it now allows the user to choose whether an update step should be made after each batch or each epoch.

If you always longed for a metric that would just use whatever is defined by your criterion, look no further than [`loss_scoring`](https://skorch.readthedocs.io/en/latest/scoring.html#skorch.scoring.loss_scoring). Also, skorch now allows you to easily change the kind of nonlinearity to apply to the module's output when `predict` and `predict_proba` are called, by passing the `predict_nonlinearity` argument.

Besides these changes, we improved the customization potential of skorch. First of all, the `criterion` is now set to `train` or `valid`, depending on the phase -- this is useful if the criterion should act differently during training and validation. Next we made it easier to add custom modules, optimizers, and criteria to your neural net; this should facilitate implementing architectures like GANs. Consult the [docs](https://skorch.readthedocs.io/en/latest/user/neuralnet.html#subclassing-neuralnet) for more on this. Conveniently, [`net.save_params`](https://skorch.readthedocs.io/en/latest/net.html#skorch.net.NeuralNet.save_params) can now persist arbitrary attributes, including those custom modules.
As always, these improvements wouldn't have been possible without the community. Please keep asking questions, raising issues, and proposing new features. We are especially grateful to those community members, old and new, who contributed via PRs:

```
Aaron Berk
guybuk
kqf
Michał Słapek
Scott Sievert
Yann Dubois
Zhao Meng
```

Here is the full list of all changes:

### Added

- Added the `event_name` argument for `LRScheduler` for optional recording of LR changes inside `net.history`. NOTE: Supported only in Pytorch>=1.4
- Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params
- Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch.
- Added the `scoring` module with `loss_scoring` function, which computes the net's loss (using `get_loss`) on provided input data.
- Added a parameter `predict_nonlinearity` to `NeuralNet` which allows users to control the nonlinearity to be applied to the module output when calling `predict` and `predict_proba` (#637, #661)
- Added the possibility to save the criterion with `save_params` and with checkpoint callbacks
- Added the possibility to save custom modules with `save_params` and with checkpoint callbacks

### Changed

- Removed support for schedulers with a `batch_step()` method in `LRScheduler`.
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead;`net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
- Set train/validation on criterion if it's a PyTorch module (#621)
- Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without positional `y` in their signatures. This is useful when working with unsupervised data (#605).
- `to_numpy` is now able to unpack dicts and lists/tuples (#657, #658)
- When using `CrossEntropyLoss`, softmax is now automatically applied to the output when calling `predict` or `predict_proba`

### Fixed

- Fixed a bug where `CyclicLR` scheduler would update during both training and validation rather than just during training.
- Fixed a bug introduced by moving the `optimizer.zero_grad()` call outside of the train step function, making it incompatible with LBFGS and other optimizers that call the train step several times per batch (#636)
- Fixed pickling of the `ProgressBar` callback (#656)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cosine annealing called every epoch instead of every batch
3 participants