
[WIP] Add structured result output #1989

Closed
wants to merge 177 commits into from

Conversation

williamFalcon
Contributor

@williamFalcon williamFalcon commented May 28, 2020

This PR maintains full backward compatibility, but adds an optional structured return type for the `*_step` methods: a dict with validation that manages what each step returns.

Old way (still supported)

def training_step(self, batch, batch_idx, ...):
    return {...}

New way

  # any loop
  def training_step(self, batch, batch_idx, ...):
      """
      Lightning calls this inside the training loop with the data from the training dataloader
      passed in as `batch`.
      """
      # forward pass
      x, y = batch
      y_hat = self(x)
      loss = F.cross_entropy(y_hat, y)

      # structure the return from the training loop
      return Result(loss)

Docs:

            # all options:
            def training_step(...):
                return Result(
                    minimize=loss,
                    checkpoint_on=loss,
                    early_stop_on=loss,
                    logs={'train_loss': loss},
                    progress_bar={'train_loss': loss}
                )

            # most of the time
            # will early stop and save checkpoints based on this metric by default
            return Result(loss)

            # to change what to early stop on
            return Result(loss, early_stop_on=accuracy)

            # to change what to checkpoint on
            return Result(loss, early_stop_on=accuracy, checkpoint_on=bleu_score)

            # shorthand for logging
            result = Result(loss)
            result.log('train_nce_loss', loss)

            # shorthand to put on progress bar
            result.to_bar('train_nce_loss', loss)

Additional benefits:

  • gets rid of the 'val_loss' magic. The user can now explicitly set what to stop or save on.
  • gets rid of the 'loss' key. The user can now call it whatever they want (we're just minimizing it).
  • adds error checking if the user passes in the wrong details.
  • clear separation of what each return item does.
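The validation and shorthand methods described above could be sketched roughly as follows. This is a hypothetical illustration of the idea, not the PR's actual implementation; the field names follow the docs in this PR description:

```python
# Hypothetical sketch of the validating dict described above;
# not the actual implementation in this PR.
class Result(dict):
    _ALLOWED = {'minimize', 'checkpoint_on', 'early_stop_on',
                'logs', 'progress_bar', 'hiddens'}

    def __init__(self, minimize=None, **kwargs):
        super().__init__()
        unknown = set(kwargs) - self._ALLOWED
        if unknown:
            # error checking when the user passes in the wrong details
            raise ValueError(f'unsupported Result keys: {sorted(unknown)}')
        if minimize is not None:
            self['minimize'] = minimize
        self.update(kwargs)

    def log(self, name, value):
        # shorthand for logging
        self.setdefault('logs', {})[name] = value

    def to_bar(self, name, value):
        # shorthand to put a value on the progress bar
        self.setdefault('progress_bar', {})[name] = value
```

With this shape, `Result(loss, wrong_key=1)` raises immediately instead of silently being ignored, which is the error-checking benefit listed above.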

@pep8speaks

pep8speaks commented May 28, 2020

Hello @williamFalcon! Thanks for updating this PR.

Line 86:36: E231 missing whitespace after ':'

Line 26:120: E501 line too long (123 > 119 characters)
Line 57:120: E501 line too long (123 > 119 characters)
Line 300:38: W292 no newline at end of file

Line 334:21: E303 too many blank lines (2)
Line 435:120: E501 line too long (120 > 119 characters)

Line 453:13: E265 block comment should start with '# '
Line 689:13: E731 do not assign a lambda expression, use a def

Line 54:35: W292 no newline at end of file

Line 62:28: E226 missing whitespace around arithmetic operator
Line 62:39: E226 missing whitespace around arithmetic operator
Line 223:41: W292 no newline at end of file

Comment last updated at 2020-06-08 13:54:38 UTC

@williamFalcon
Contributor Author

@tullie is this in line with what you were thinking?
Will ping you when done

@mergify mergify bot requested a review from a team May 28, 2020 17:00
@tullie
Contributor

tullie commented May 28, 2020

Why does StepResult need to be passed in as a training_step argument? That seems unintuitive to me.

Could we aim for something like:

def training_step(self, batch: Tensor, batch_idx: int):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    return StepResult(loss=loss, log=loss)

@Borda Borda added the feature Is an improvement or enhancement label May 28, 2020
@Borda Borda added this to the 0.8.0 milestone May 28, 2020
@williamFalcon
Contributor Author

williamFalcon commented May 28, 2020

@tullie ok, updated. check out the new API before i finish making all the deep changes

"""
        Result is an OrderedDict that provides type hints, allowed fields, and validation against bad user input.

        Use as the return value for:
        - training_step
        - validation_epoch_end
        - training_epoch_end

        .. note:: Plain dictionary returns are supported but are more prone to errors

        We automatically detach anything here for you to avoid holding references to graphs

        Args:
            minimize: Metric to minimize
            logs: dictionary that will be added to your logger(s)
            early_stop_on: Metric for early stopping. If not set, ``minimize`` is used by default.
            checkpoint_on: Metric for checkpointing. If not set, ``minimize`` is used by default.
            progress_bar: dictionary of values to add to the progress bar
            hiddens: tensor of hiddens to pass to next step when using TBPTT

        .. code-block:: python

            # all options:
            def training_step(...):
                return Result(
                    minimize=loss,
                    checkpoint_on=loss,
                    early_stop_on=loss,
                    logs={'train_loss': loss},
                    progress_bar={'train_loss': loss}
                )

            # most of the time
            # will early stop and save checkpoints based on this metric by default
            return Result(loss)

            # to change what to early stop on
            return Result(loss, early_stop_on=accuracy)

            # to change what to checkpoint on
            return Result(loss, early_stop_on=accuracy, checkpoint_on=bleu_score)

            # shorthand for logging
            result = Result(loss)
            result.log('train_nce_loss', loss)

            # shorthand to put on progress bar
            result.to_bar('train_nce_loss', loss)
        """

STILL SUPPORTED:

return {...}

TODO:

  • tests
  • docs
  • update early stopping behavior
  • update checkpoint behavior

@jeremyjordan
Contributor

jeremyjordan commented May 29, 2020

related to #1256, happy to see progress here!

as i mentioned in the other issue, it might be nice to use something like pydantic for the structured output since it provides very easy data validation for free
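The "validation for free" idea can be illustrated with stdlib dataclasses (pydantic would add richer coercion and error messages on top). Everything below is a hypothetical sketch whose field names mirror the discussion above, not code from this PR:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class StepResult:
    # Hypothetical sketch only; field names mirror the PR discussion above.
    minimize: Any
    early_stop_on: Optional[Any] = None
    checkpoint_on: Optional[Any] = None
    logs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # default the early-stop and checkpoint metrics to the minimized one,
        # matching the defaulting behavior described in the PR docs
        if self.early_stop_on is None:
            self.early_stop_on = self.minimize
        if self.checkpoint_on is None:
            self.checkpoint_on = self.minimize
```

A pydantic `BaseModel` would give the same defaulting plus type validation of each field without hand-written checks.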

@tullie
Contributor

tullie commented May 29, 2020

@williamFalcon yeah this is great. Huge improvement imo! I'll let you finish the TODOs and then do a closer sweep of all the code.

@Borda
Member

Borda commented May 29, 2020

as i mentioned in the other issue, it might be nice to use something like pydantic for the structured output since it provides very easy data validation for free

this looks very similar to Namespace except the validation and other features around... :]

@williamFalcon
Contributor Author

@jeremyjordan #1256 is awesome, must have missed that haha. Why don't i put together the v1 right now and you can take a stab at V2? happy to make this a joint PR. with this object we can now do whatever we want under the hood for validation :)

@jeremyjordan
Contributor

yeah, sounds great!

@mergify
Contributor

mergify bot commented May 29, 2020

This pull request is now in conflict... :(

5 similar comments

@Borda Borda added the Important label Jun 1, 2020
@Borda
Member

Borda commented Jun 1, 2020

is it still WIP? 🦝

@williamFalcon
Contributor Author

yes. but feel free to comment on the current api

Define parameters that only apply to this model
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--in_features', default=28 * 28, type=int)
Member

this can be parsed automatically from the model's init arguments...



class Result(Dict):

Member

adding docstring and doctest here would be great, I can do it later on

Comment on lines +169 to +170
minimize = self.__getitem__('minimize')
self.__setitem__('checkpoint_on', minimize)
Member

Suggested change
minimize = self.__getitem__('minimize')
self.__setitem__('checkpoint_on', minimize)
self.__setitem__('checkpoint_on', self.minimize)

Comment on lines +184 to +185
minimize = self.__getitem__('minimize')
self.__setitem__('early_stop_on', minimize)
Member

Suggested change
minimize = self.__getitem__('minimize')
self.__setitem__('early_stop_on', minimize)
self.__setitem__('early_stop_on', self.minimize)
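Both suggestions above rely on attribute-style access working on the dict (`self.minimize` instead of `self.__getitem__('minimize')`). A minimal sketch of that pattern, hypothetical rather than the PR's actual `Result` code:

```python
class AttrDict(dict):
    # Minimal sketch of the attribute-access pattern the suggestions above
    # rely on; hypothetical, not the PR's actual Result implementation.
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            # surface missing keys as AttributeError, as attribute
            # access protocols expect
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value
```

With this in place, `self['checkpoint_on'] = self.minimize` reads the same value `self.__getitem__('minimize')` would.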

Comment on lines 379 to 382
if __name__ == '__main__':
import torch
result = Result()
result.minimize = torch.tensor(1)
Member

Suggested change
if __name__ == '__main__':
import torch
result = Result()
result.minimize = torch.tensor(1)

this may go to doctest...
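A doctest version of that `__main__` block might look like the sketch below. It is hypothetical: it uses a plain int instead of `torch.tensor` and a stub `Result` so the example runs standalone (the real `Result` class lives in this PR):

```python
class Result(dict):
    # Stub with attribute-style assignment so the doctest below runs
    # standalone; the real Result class lives in this PR.
    def __setattr__(self, name, value):
        self[name] = value


def result_demo():
    """Doctest version of the ``__main__`` block above.

    >>> result = Result()
    >>> result.minimize = 1
    >>> result['minimize']
    1
    """


if __name__ == '__main__':
    import doctest
    doctest.testmod()
```

Running `doctest.testmod()` (or pytest with `--doctest-modules`) then exercises the example as part of the test suite instead of an ad-hoc `__main__` check.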

# ---------------------
# test dic return only
# ---------------------
model = DeterministicModel()
Member

this looks the same as the above...

@mergify mergify bot requested a review from a team June 1, 2020 22:49

@log_on_batch_end.setter
def log_on_batch_end(self, x):
if x is not None:
Member

in what case would you pass `None` here, or to any of the similar setters?

@mergify mergify bot requested a review from a team June 1, 2020 23:08
@williamFalcon williamFalcon marked this pull request as draft June 7, 2020 20:07
@import-antigravity

Is this still in active development? It would be really nice to have :)

@williamFalcon
Contributor Author

yes!! in #2615

@williamFalcon williamFalcon deleted the result branch July 22, 2020 17:54
Labels
feature Is an improvement or enhancement
6 participants