[RFC] Support a Trainer.train() API #10888

Closed
ananthsub opened this issue Dec 1, 2021 · 10 comments
Labels
discussion In a discussion stage feature Is an improvement or enhancement trainer

Comments

@ananthsub
Contributor

ananthsub commented Dec 1, 2021

🚀 Feature

Add a new entry point to the Trainer which runs only the training loop with no validation.

Motivation

This makes it clear that users who define only training_step and train_dataloader can call train without risking errors from unimplemented validation hooks (although the framework already checks for this today).

Another motivation is that users who do implement validation steps/dataloaders may sometimes want to run training without validation (for example, in the case of online training). Today, those users need to set limit_val_batches=0 before calling trainer.fit.

Finally, such a feature makes it easier to interleave train/validate/test/predict calls. For example, past requests have been made to run the test loop after each validation pass. In conjunction with #10444 this makes writing more complex patterns far simpler with Lightning.

This is slightly different from loop customization. In this case, I don't want to change any of the fundamental building blocks, but I may want to change the order/sequencing in which they're called.

t = Trainer(...)
m = MyLightningModule(...)

def my_custom_interleaving(t: Trainer, m: LightningModule):
    t.train(m, dry_run=True)  # make sure everything runs okay, fail fast if not
    t.train(m, max_epochs=1)
    t.validate(m)
    t.test(m)
    t.predict(m)
    t.fit(m, max_epochs=10)
    t.validate(m)
    t.test(m)
    t.predict(m)

Pitch

Offer a top-level function on the Trainer:

def train(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[str] = None,
) -> None:
    ...

Alternatives

One could try to work around this as follows:

t = Trainer(..., limit_val_batches=0)
t.fit(model)
t.limit_val_batches = 1.0
t.validate(model)
t.limit_val_batches = 0
...

However, this is somewhat clunky to write, and requires users to dig through the various trainer properties/attributes to reset state across calls, which is not straightforward.
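The save-and-restore bookkeeping in the workaround above can be factored into a small context manager. This is a minimal sketch in plain Python: `temporarily_set` is a hypothetical helper, and the `SimpleNamespace` object is a stand-in for a Trainer, not the real class. Whether a real Trainer picks up attribute changes made between calls is an assumption this sketch does not verify.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporarily_set(obj, **overrides):
    """Set attributes on obj for the duration of the block, then restore them."""
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        # Restore the original values even if the block raised.
        for name, value in saved.items():
            setattr(obj, name, value)


# Hypothetical stand-in for a Trainer; only the attribute matters here.
trainer = SimpleNamespace(limit_val_batches=0)

with temporarily_set(trainer, limit_val_batches=1.0):
    inside = trainer.limit_val_batches  # value seen while "validating"
after = trainer.limit_val_batches  # original value is back
```

Even with such a helper, the user still has to know which attributes to toggle, which is the discoverability problem the proposal is trying to avoid.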

Additional context



cc @Borda @justusschock @kaushikb11 @awaelchli @ananthsub @ninginthecloud @jjenniferdai @rohitgr7

@ananthsub ananthsub added the feature Is an improvement or enhancement label Dec 1, 2021
@carmocca
Contributor

carmocca commented Dec 2, 2021

Do you think users might want to implement different training logic when they call trainer.fit vs trainer.train? If yes, do you think our current on_train_* hooks should have been named on_fit_*?

I guess validation inside trainer.fit and trainer.validate have the same problem anyway. Users can check trainer.state.fn == "validate" to differentiate.

@ananthsub
Contributor Author

ananthsub commented Dec 3, 2021

Do you think users might want to implement different training logic when they call trainer.fit vs trainer.train? If yes, do you think our current on_train_* hooks should have been named on_fit_*?

No, I think the on_train_* hooks are the right call, as train is a fundamental building block (represented by RunningStage): https://github.com/PyTorchLightning/pytorch-lightning/blob/a28b4cd0c0bba30c21cae571e650877f66cf5588/pytorch_lightning/trainer/states.py#L56

Trainer functions are higher-level compositions which can run one or more RunningStages: https://github.com/PyTorchLightning/pytorch-lightning/blob/a28b4cd0c0bba30c21cae571e650877f66cf5588/pytorch_lightning/trainer/states.py#L34

As you note, if the user wants some logic to happen in validation hooks with trainer.fit but not with trainer.validate, then the trainer provides access to the state (running stage, fn) to distinguish the two.
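The relationship between entry points and running stages can be illustrated with a small sketch. The enums below are simplified stand-ins written for this example, not the definitions in the linked states.py, and `on_validation_end` here is a hypothetical hook body that takes the entry point as an argument rather than reading it from a trainer.

```python
from enum import Enum


class TrainerFn(str, Enum):
    # Simplified stand-in for the entry-point enum; values are illustrative.
    FITTING = "fit"
    VALIDATING = "validate"


class RunningStage(str, Enum):
    # Simplified stand-in for the per-loop stage enum.
    TRAINING = "train"
    VALIDATING = "validate"


# Which stages each entry point composes (illustrative mapping).
STAGES = {
    TrainerFn.FITTING: (RunningStage.TRAINING, RunningStage.VALIDATING),
    TrainerFn.VALIDATING: (RunningStage.VALIDATING,),
}


def on_validation_end(fn):
    """Hypothetical hook body that branches on the calling entry point."""
    if fn == "validate":  # same comparison style as trainer.state.fn == "validate"
        return "standalone validation"
    return "validation inside fit"
```

Under this model, a train entry point would simply be a new key that maps to the training stage alone, while the stage-level hooks stay unchanged.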

@ananthsub ananthsub changed the title Support a Trainer.train() API [RFC] Support a Trainer.train() API Dec 8, 2021
@ananthsub ananthsub added the discussion In a discussion stage label Dec 8, 2021
@carmocca
Contributor

Some questions:

  • Would you expect that trainer.train shares the results and progress tracking state with trainer.fit?
  • Would you ask for this feature if trainer.limit_val_batches was part of trainer.fit?
  • Does this addition mean trainer.fit should raise an exception when there's no validation?

@ananthsub
Contributor Author

Would you expect that trainer.train shares the results and progress tracking state with trainer.fit?

No, I'd expect train to operate independently, the same way trainer.validate keeps track of its own state.

Would you ask for this feature if trainer.limit_val_batches was part of trainer.fit?

The need would be far less, but I think a dedicated entry point is clearer for users and provides greater confidence that the framework doesn't initialize or check anything related to validation (including validation sanity checks). fit with limit_val_batches=0 would functionally be the same thing though.

Does this addition mean trainer.fit should raise an exception when there's no validation?

I think we could keep the current behavior of not checking validation when limit_val_batches=0 and the user calls trainer.fit. But by default, yes, we could throw an exception.

Then users have a really clear path for onboarding:

  • Implement the train hooks (training_step, train_dataloader, configure_optimizers) and call trainer.train()
  • Implement the validation hooks (validation_step, val_dataloader) and call trainer.validate
  • Implement the test hooks (test_step, test_dataloader) and call trainer.test
  • Implement the predict hooks (predict_step, predict_dataloader) and call trainer.predict
  • Implement both training and validation and call trainer.fit

One could argue that fit is best practice compared to train, since one should always have training and validation data split out. But the framework doesn't enforce this today: it checks whether the hooks are overridden and skips the corresponding stages if they're not.
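The onboarding path above amounts to a mapping from entry point to required hooks, plus a check for whether each hook is overridden. The following is a minimal sketch in plain Python: `BaseModule`, `REQUIRED_HOOKS`, `is_overridden` and `missing_hooks` are hypothetical names defined here for illustration, and the strict check for fit reflects the "could throw an exception" option discussed in the thread, not current framework behavior.

```python
class BaseModule:
    """Hypothetical stand-in for a base module with no-op hooks."""

    def training_step(self, batch): ...
    def train_dataloader(self): ...
    def configure_optimizers(self): ...
    def validation_step(self, batch): ...
    def val_dataloader(self): ...


# Hooks each entry point would require under the proposed onboarding path.
REQUIRED_HOOKS = {
    "train": ("training_step", "train_dataloader", "configure_optimizers"),
    "validate": ("validation_step", "val_dataloader"),
}
REQUIRED_HOOKS["fit"] = REQUIRED_HOOKS["train"] + REQUIRED_HOOKS["validate"]


def is_overridden(module, name):
    # A hook counts as overridden if the subclass replaced the base function.
    return getattr(type(module), name) is not getattr(BaseModule, name)


def missing_hooks(module, entry_point):
    return [h for h in REQUIRED_HOOKS[entry_point] if not is_overridden(module, h)]


class TrainOnly(BaseModule):
    def training_step(self, batch):
        return 0.0

    def train_dataloader(self):
        return []

    def configure_optimizers(self):
        return None


model = TrainOnly()
```

With this model, `missing_hooks(model, "train")` is empty while `missing_hooks(model, "fit")` lists the two validation hooks, so a train-only module passes the train check and fails the strict fit check.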

@stale

stale bot commented Jan 14, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Jan 14, 2022
@carmocca carmocca removed the won't fix This will not be worked on label Jan 14, 2022
@Borda
Member

Borda commented Jan 19, 2022

Add a new entry point to the Trainer which runs only the training loop with no validation.

Could you please provide a model use case?
Personally, I feel that it goes against the PL best-practice principle, which in most cases includes progress monitoring, and that should be done on the validation set, not on the training set...
Also, would you get the same functionality if you limit the number of validation batches to zero in the Trainer's init?

@carmocca
Contributor

carmocca commented Jan 19, 2022

It's a sensible proposal but has the big drawback of creating confusion with trainer.fit.
Especially because fit with no validation data will still be supported.
If fitting with no training data had been supported, we might not have added trainer.validate

However, I see the advantage for external loop customization or orchestrating multiple calls.

Would you implement the loop class used as a copy of the FitLoop and its children with all validation logic removed or would you just disable the validation data and use the FitLoop?

@tchaton
Contributor

tchaton commented Jan 21, 2022

After reflecting on this, @ananthsub I believe this shouldn't be added.

First, because the Trainer API is final, but more importantly because it would force bad practices on the user. I am 100% sure @williamFalcon and co. thought hard about this at the beginning, and the fact that this option doesn't exist was deliberate from the start.

IMO, the best practice induced by the trainer.fit default is to perform sanity checking. Furthermore, opting out is quite simple but should be the responsibility of the user, e.g. Trainer(limit_val_batches=0) or no validation dataloader.

+1 for added confusion.

@awaelchli @carmocca I would recommend closing this RFC.

Thanks @ananthsub for your time and effort proposing this :)

@gzerveas

gzerveas commented May 3, 2022

The existing solution, e.g. Trainer(limit_val_batches=0), is sufficient, but if you are wondering about possible use cases for not performing validation, here is one:
One first defines a custom train/validation split of the training set for optimizing hyperparameters, and once those are fixed, one wants to use the entire training set for training. Evaluation may happen on a separate test set at the end, or it may not be possible at all (e.g. a hidden test set).
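The two phases of this use case can be sketched with index splits in plain Python. `split_indices` is a hypothetical helper written for this example; the dataset size and the 20% validation fraction are arbitrary assumptions.

```python
import random


def split_indices(n, val_fraction, seed=0):
    """Shuffle range(n) and return (train_indices, val_indices)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_val = int(n * val_fraction)
    return indices[n_val:], indices[:n_val]


# Phase 1: hold out part of the training set to tune hyperparameters.
train_idx, val_idx = split_indices(100, 0.2)

# Phase 2: hyperparameters are fixed, so train on everything with no validation.
final_train_idx, final_val_idx = split_indices(100, 0.0)
```

In the second phase there is nothing left to validate on, which is the situation Trainer(limit_val_batches=0) covers today.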

@extragoya

The solution is indeed sufficient, although improved documentation would be welcome. For those interested, another use case is training models to reconstruct shapes or scenes, e.g., DeepSDF https://openaccess.thecvf.com/content_CVPR_2019/papers/Park_DeepSDF_Learning_Continuous_Signed_Distance_Functions_for_Shape_Representation_CVPR_2019_paper.pdf. I believe neural radiance fields would have a similar use case, and are very popular.


6 participants