[WIP] Add structured result output #1989
Conversation
Hello @williamFalcon! Thanks for updating this PR.
Comment last updated at 2020-06-08 13:54:38 UTC
@tullie is this in line with what you were thinking?
Why does StepResult need to be passed in as a training_step argument? That seems unintuitive to me. Could we aim for something like:
@tullie ok, updated. Check out the new API before I finish making all the deep changes:
Result is an OrderedDict that gives type hints, allowed fields and validation for bad user input.
Use as the return value for:
- training_step
- validation_epoch_end
- training_epoch_end
.. note:: Plain dictionary returns are supported but are more prone to errors
We automatically detach anything here for you to avoid holding references to graphs
Args:
minimize: Metric to minimize
logs: dictionary that will be added to your logger(s)
early_stop_on: Metric for early stopping. If not set, minimize is used by default.
checkpoint_on: Metric for checkpointing. If not set, minimize is used by default.
progress_bar: dictionary of values to add to the progress bar
hiddens: tensor of hiddens to pass to next step when using TBPTT
.. code-block:: python
# all options:
def training_step(...):
return Result(
minimize=loss,
checkpoint_on=loss,
early_stop_on=loss,
logs={'train_loss': loss},
progress_bar={'train_loss': loss}
)
# most of the time
# will early stop and save checkpoints based on this metric by default
return Result(loss)
# to change what to early stop on
return Result(loss, early_stop_on=accuracy)
# to change what to checkpoint on
return Result(loss, early_stop_on=accuracy, checkpoint_on=bleu_score)
# shorthand for logging
result = Result(loss)
result.log('train_nce_loss', loss)
# shorthand to put on progress bar
result.to_bar('train_nce_loss', loss)
"""

STILL SUPPORTED:
TODO:
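To make the docstring above concrete, here is a minimal sketch of how such a `Result` dict could behave (a hypothetical stand-in, not the PR's actual implementation; field names follow the docstring):

```python
class Result(dict):
    """Minimal sketch of a structured step-result dict.

    early_stop_on / checkpoint_on fall back to the minimized metric,
    mirroring the defaults described in the docstring above.
    """

    def __init__(self, minimize=None, early_stop_on=None, checkpoint_on=None,
                 logs=None, progress_bar=None, hiddens=None):
        super().__init__()
        if minimize is not None:
            self['minimize'] = minimize
            # default early stopping / checkpointing to the minimized metric
            self['early_stop_on'] = minimize if early_stop_on is None else early_stop_on
            self['checkpoint_on'] = minimize if checkpoint_on is None else checkpoint_on
        self['logs'] = dict(logs or {})
        self['progress_bar'] = dict(progress_bar or {})
        if hiddens is not None:
            self['hiddens'] = hiddens

    def log(self, name, value):
        # shorthand for adding a value to the logger dict
        self['logs'][name] = value

    def to_bar(self, name, value):
        # shorthand for adding a value to the progress bar dict
        self['progress_bar'][name] = value
```

With this sketch, `Result(loss)` alone is enough to drive early stopping and checkpointing, which is the main ergonomic win the docstring advertises.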
related to #1256, happy to see progress here! As I mentioned in the other issue, it might be nice to use something like
@williamFalcon yeah this is great. Huge improvement imo! I'll let you finish the TODOs and then do a closer sweep of all the code.
this looks very similar to
@jeremyjordan #1256 is awesome, must have missed that haha. Why don't I put together the v1 right now and you can take a stab at v2? Happy to make this a joint PR. With this object we can now do whatever we want under the hood for validation :)
yeah, sounds great!
This pull request is now in conflict... :(
is it still WIP? 🦝
Yes, but feel free to comment on the current API.
Define parameters that only apply to this model
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--in_features', default=28 * 28, type=int)
this can be parsed automatically from the model's init arguments...
class Result(Dict):
adding docstring and doctest here would be great, I can do it later on
minimize = self.__getitem__('minimize')
self.__setitem__('checkpoint_on', minimize)
Suggested change:
- minimize = self.__getitem__('minimize')
- self.__setitem__('checkpoint_on', minimize)
+ self.__setitem__('checkpoint_on', self.minimize)
minimize = self.__getitem__('minimize')
self.__setitem__('early_stop_on', minimize)
Suggested change:
- minimize = self.__getitem__('minimize')
- self.__setitem__('early_stop_on', minimize)
+ self.__setitem__('early_stop_on', self.minimize)
if __name__ == '__main__':
    import torch
    result = Result()
    result.minimize = torch.tensor(1)
Suggested change (remove):
- if __name__ == '__main__':
-     import torch
-     result = Result()
-     result.minimize = torch.tensor(1)
this may go to doctest...
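A sketch of what moving that `__main__` snippet into a doctest could look like (the `Result` stub here is illustrative, not the PR's class, and uses a plain int instead of a torch tensor to keep it self-contained):

```python
class Result(dict):
    """Dict stub with attribute-style assignment.

    >>> result = Result()
    >>> result.minimize = 1
    >>> result['minimize']
    1
    """

    def __setattr__(self, key, value):
        # attribute writes land in the dict itself
        self[key] = value


if __name__ == '__main__':
    import doctest
    doctest.testmod()
```

Doctests keep the smoke test runnable via `python -m doctest` or pytest's `--doctest-modules`, instead of a one-off `__main__` block.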
tests/trainer/test_trainer.py (Outdated)
# ---------------------
# test dic return only
# ---------------------
model = DeterministicModel()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this looks the same as the above...
@log_on_batch_end.setter
def log_on_batch_end(self, x):
    if x is not None:
what is the case where you would pass None here, or to any similar setter?
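For context on the question above: one common reading of that pattern is that `None` acts as a sentinel meaning "leave the current default unchanged". A hedged sketch of that idiom (class and attribute names are illustrative, not the PR's actual code):

```python
class StepConfig:
    """Illustrative holder for a per-step flag with a default."""

    def __init__(self):
        self._log_on_batch_end = True  # default behavior

    @property
    def log_on_batch_end(self):
        return self._log_on_batch_end

    @log_on_batch_end.setter
    def log_on_batch_end(self, x):
        # None is ignored so callers can pass a value straight through
        # without clobbering the default
        if x is not None:
            self._log_on_batch_end = x
```

This lets calling code write `config.log_on_batch_end = user_value` unconditionally, even when `user_value` was never supplied.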
Is this still in active development? It would be really nice to have :)
yes!! in #2615
This PR maintains full backward compatibility. For those who want the option, it adds a structured dict with validation as the return value of each *_step, for managing what each step returns.
Old way (still supported)
New way
Docs:
Additional benefits:
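To illustrate the two styles side by side, here is a hedged sketch (the `Result` class is a stand-in stub, not Lightning's implementation, and the loss is a plain float placeholder):

```python
class Result(dict):
    """Stand-in stub for the proposed structured Result (illustrative only)."""

    def __init__(self, minimize=None, logs=None):
        super().__init__(minimize=minimize, logs=logs or {})


def training_step_old(batch, batch_idx):
    # old way (still supported): return a plain dict
    loss = 0.25  # placeholder for a real loss tensor
    return {'loss': loss, 'log': {'train_loss': loss}}


def training_step_new(batch, batch_idx):
    # new way: return a structured, validated Result
    loss = 0.25
    return Result(minimize=loss, logs={'train_loss': loss})
```

The plain-dict path keeps old code working unchanged, while the `Result` path gives validated fields instead of magic dict keys.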