
Formalize progress tracking inside of the trainer internals #6429

Closed
ananthsub opened this issue Mar 9, 2021 · 9 comments · Fixed by #8362
Labels: design, discussion, feature, help wanted, refactor

Comments

@ananthsub (Contributor) commented Mar 9, 2021

🚀 Feature

We should better enforce progress tracking across these dimensions:

  • Stage: training, evaluation (validation and test), and prediction loops
  • Granularity: batches, steps, epochs
  • Batches vs. steps: steps are optimizer steps (parameter updates) and apply to the training loop only. Steps will differ from batches when gradient accumulation is used, for instance (illustrated in the usage sketch after the TrainLoopProgress example below).

Motivation

  • Provide consistency across trainer stages for users
  • Better debugging
  • Create mid-epoch resumable state for training

Pitch

See the full example here

  1. Create a dataclass that each Loop in Lightning is responsible for maintaining. All of this state is local to a particular trainer rank. These dataclasses live as attributes on the corresponding loop.
@dataclass
class LoopProgress:
    # this also serves as the index of the current epoch
    total_epochs_processed: int = 0
    # this monotonically increases and is summed across epochs
    total_batches_processed: int = 0
    # this resets at the end of the epoch back to 0
    batches_processed_this_epoch: int = 0
    
    # convenience utils for accessing state
    def bump_batch(self, increment: int = 1):
        self.total_batches_processed += increment
        self.batches_processed_this_epoch += increment
    
    def bump_epoch(self, increment: int = 1):
        self.total_epochs_processed += increment
    
    def reset_batch_in_epoch(self):
        self.batches_processed_this_epoch = 0

The train loop extends this to track optimizer steps

@dataclass
class TrainLoopProgress(LoopProgress):
    total_optimizer_steps_processed: int = 0
    optimizer_steps_processed_this_epoch: int = 0

    def bump_step(self, increment: int = 1):
        self.total_optimizer_steps_processed += increment
        self.optimizer_steps_processed_this_epoch += increment
    
    def reset_step_in_epoch(self):
        self.optimizer_steps_processed_this_epoch = 0
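
For illustration only, here is a minimal sketch (not the actual Lightning training loop; the toy model, data, and accumulate_grad_batches value are assumptions) of how a training loop with gradient accumulation might drive this tracker, which is exactly where batch counts and optimizer-step counts diverge:

import torch

# Toy model, optimizer, and data; purely illustrative, not Lightning internals
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_dataloader = [torch.randn(8, 4) for _ in range(100)]  # 100 dummy batches
accumulate_grad_batches = 4

progress = TrainLoopProgress()

for epoch in range(2):
    for batch_idx, batch in enumerate(train_dataloader):
        loss = model(batch).mean()
        loss.backward()                      # gradients accumulate across batches
        progress.bump_batch()                # counted for every batch

        if (batch_idx + 1) % accumulate_grad_batches == 0:
            optimizer.step()                 # one parameter update
            optimizer.zero_grad()
            progress.bump_step()             # counted only on optimizer steps

    progress.bump_epoch()
    progress.reset_batch_in_epoch()
    progress.reset_step_in_epoch()

assert progress.total_batches_processed == 200         # 2 epochs x 100 batches
assert progress.total_optimizer_steps_processed == 50  # one step per 4 batches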
  2. The Trainer maintains its own tracker that has references to the individual loop progress trackers
@dataclass
class Progress:
    train_progress: TrainLoopProgress
    val_progress: LoopProgress
    test_progress: LoopProgress
    predict_progress: LoopProgress
  3. For convenience, we could offer synchronization utilities to sum the progress state across ranks to get totals (see the sketch after this list)
  4. We can also offer convenience utilities to get the totals across different stages
  5. We update the loops to populate this state, and make backwards-compatible changes to reference this state
  6. We save the progress state into checkpoints as part of the trainer state
  7. We handle loading this state when resuming from checkpoints
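
A rough sketch of what points 3, 4, 6, and 7 could look like; the function names below and the use of torch.distributed.all_reduce are assumptions for illustration, not a committed API:

from dataclasses import asdict, fields

import torch
import torch.distributed as dist


def sum_progress_across_ranks(progress: LoopProgress) -> LoopProgress:
    """Return a copy of `progress` with every counter summed over all ranks (assumes a CPU-friendly backend such as gloo)."""
    summed = type(progress)(**asdict(progress))
    if dist.is_available() and dist.is_initialized():
        for f in fields(summed):
            t = torch.tensor(getattr(summed, f.name))
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            setattr(summed, f.name, int(t.item()))
    return summed


def total_batches_across_stages(progress: Progress) -> int:
    """Convenience total over train/val/test/predict on this rank."""
    return sum(
        p.total_batches_processed
        for p in (
            progress.train_progress,
            progress.val_progress,
            progress.test_progress,
            progress.predict_progress,
        )
    )


def progress_state_dict(progress: Progress) -> dict:
    """Serialize the tracker so it can be saved as part of the trainer state in the checkpoint."""
    return {
        "train": asdict(progress.train_progress),
        "val": asdict(progress.val_progress),
        "test": asdict(progress.test_progress),
        "predict": asdict(progress.predict_progress),
    }


def load_progress_state_dict(state: dict) -> Progress:
    """Rebuild the tracker when resuming from a checkpoint."""
    return Progress(
        train_progress=TrainLoopProgress(**state["train"]),
        val_progress=LoopProgress(**state["val"]),
        test_progress=LoopProgress(**state["test"]),
        predict_progress=LoopProgress(**state["predict"]),
    )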
@ananthsub added the feature, help wanted, design, and discussion labels on Mar 9, 2021
@wittenator commented:

That sounds interesting! In your scenario, would the trainer keep track of the different metrics and then deposit the data in the model, so that the data survives the lifecycle of the trainer?

@carmocca (Contributor) commented:

I really like this. This could also serve as part of the interface for custom loops to interact with the trainer state. cc: @justusschock

But I wouldn't do it via a callback, I'd have it as a property trainer.tracker and each loop is responsible for updating it.

> total_batch_count

This can be a property calculated from the others, but what do you sum? Just training and validation? Also test?

> do we need to track test/predict?

We can have them, just because why not, even though they shouldn't impact the trainer state.

> Much of the accounting already exists via the progress bar callback, so we can use this as a starting point to see what's missing, and make sure the properties being updated are documented and made available accordingly

The progress bar could be generated from this data and the current metrics available.

> Or about implementing this via callbacks (as what's done today) vs adding it to the core train/evaluation/predict loops?

To me it needs to be out of the loops and out of the callbacks, into the trainer state. The loops would modify it and the callbacks would read it.

> Would the trainer keep track of the different metrics and then deposit the data in the model, so that the data survives the lifecycle of the trainer?

This is a great observation. We could also put a tracker in the model so that, with just a model checkpoint, you can know exactly how many batches it has seen. This model tracker would outlive the trainer, and the trainer tracker would add to the model state.

For example, loading a model trained for 1000 epochs and training for 10 more epochs would mean that trainer.tracker.train_epochs == 10 and model.tracker.train_epochs == 1010.
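
A minimal sketch of that idea; the ModelProgress dataclass and merge_into_model_tracker helper below are hypothetical names used only to illustrate accumulating counts across trainer runs:

from dataclasses import dataclass


@dataclass
class ModelProgress:
    """Hypothetical lifetime counters stored with the model, outliving any single Trainer."""
    train_epochs: int = 0
    train_batches: int = 0


def merge_into_model_tracker(model_tracker: ModelProgress, trainer_progress: TrainLoopProgress) -> None:
    """Add what this Trainer run has seen on top of the model's lifetime counts."""
    model_tracker.train_epochs += trainer_progress.total_epochs_processed
    model_tracker.train_batches += trainer_progress.total_batches_processed


# e.g. a model checkpointed after 1000 epochs, then trained for 10 more epochs:
model_tracker = ModelProgress(train_epochs=1000)
run_progress = TrainLoopProgress(total_epochs_processed=10)
merge_into_model_tracker(model_tracker, run_progress)
assert run_progress.total_epochs_processed == 10   # this trainer run only
assert model_tracker.train_epochs == 1010          # model lifetime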

@ananthsub (Contributor, Author) commented Mar 11, 2021

Yes, totally @carmocca, I should've refined this further:

  • Loops should be required to implement an interface that updates this tracker
  • The state should be tracked per loop across batches/steps/epochs
  • The state should be exposed to the trainer
  • The trainer can group together across different loop types (train/val/test/predict)
  • The trainer can re-expose this grouped tracking state to other components
  • Callbacks and the LightningModule/DataModule should read this state via the trainer (can we enforce this as read-only?). A rough sketch of such an interface is below.
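
A rough sketch of what a loop-facing interface plus a read-only trainer property could look like; all of the names here are assumptions, not an agreed-upon API:

from abc import ABC, abstractmethod
from types import MappingProxyType


class ProgressTrackingLoop(ABC):
    """Each loop owns its progress dataclass and is required to update it."""

    def __init__(self, progress: LoopProgress) -> None:
        self.progress = progress

    @abstractmethod
    def advance(self, batch) -> None:
        """Process one batch and update self.progress accordingly."""

    def on_epoch_end(self) -> None:
        self.progress.bump_epoch()
        self.progress.reset_batch_in_epoch()


class TrainerWithTracker:
    """Groups the per-loop trackers and re-exposes them to other components."""

    def __init__(self, loops: dict) -> None:
        self._loops = loops  # e.g. {"train": train_loop, "val": val_loop, ...}

    @property
    def tracker(self):
        # Callbacks and the LightningModule/DataModule read through this mapping;
        # MappingProxyType keeps the mapping itself read-only, although the
        # dataclasses it points at remain mutable by the loops.
        return MappingProxyType({name: loop.progress for name, loop in self._loops.items()})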

@ananthsub ananthsub changed the title Formalize progress tracking inside of the trainer internals [WIP] Formalize progress tracking inside of the trainer internals Mar 12, 2021
@willleeney commented:

Will this integrate with the supported loggers? What I would like to be able to do is pause training mid-epoch, then resume training at that point using a new instance of Trainer and the same logger as before, by passing resume_from_checkpoint=checkpoint_path.

@justusschock (Member) commented:

@willleeney No, this won't happen. This is just tracking the trainer states. For resuming to work correctly, it has to be implemented by each logger individually (which is often not trivial, since loggers handle it in very different ways). Also, this is not dumped to the checkpoint.

@willleeney commented:

@justusschock no worries, as long as it creates the mid-epoch resumable state for training, it'll be good enough for me :)

@wittenator commented Mar 17, 2021

@carmocca If there is anything I can help with, I would be happy to chime in since that is a blocker for my work anyway.

@willleeney commented:

@wittenator I currently have a quick workaround that solves my issue with the logger until this is done, but I still have an issue related to this. It would be really useful if the trainer could switch between multiple data loaders on request, executing a custom number of training batch steps, so that both data loaders don't iterate at the same time.

Sorry, I don't know if this is the right place to ask or if this is already implemented, but what I need is to be able to keep track of how far through each data loader I am, in order to iterate through a chosen number of batches.

@ananthsub ananthsub changed the title [WIP] Formalize progress tracking inside of the trainer internals Formalize progress tracking inside of the trainer internals Mar 20, 2021
@carmocca (Contributor) commented:

@wittenator @willleeney We plan to resolve all those issues, but it will be done in several steps so the best way you can help right now is to keep track of the PRs and review them 😄
