
Remove Trainer reference from lightning module and datamodule #7315

Open
edenlightning opened this issue May 2, 2021 · 2 comments

edenlightning commented May 2, 2021

Motivation: Remove trainer reference from LightningModule and LightningDataModule.

Benefits:

  • Better backwards compatibility: it should be easier to use the LightningModule of one version of Lightning with the Trainer from another version
  • Better API clarity: users today who call Trainer functions like fit/validate/predict inside of the LightningModule will run into issues with state management. The Trainer assumes these top-level APIs are not called from within another trainer context.

Possible solution: think about introducing a TrainerContext / TrainerState object to pass state and deprecate references to the trainer. This would be mostly read-only data that the LightningModule could leverage for various settings like progress (epoch/step count), distributed setting (global rank, local rank, data parallel/ddp/deepspeed/etc) and more.

The ambition is for the LightningModule to have a tightly controlled view of the Trainer, while the Trainer has full insight into the module.
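A rough sketch of what such a read-only view might look like (the field names below are purely illustrative, not a settled API):

from dataclasses import dataclass


@dataclass(frozen=True)
class TrainerContext:
    """Read-only snapshot of Trainer state exposed to the LightningModule / DataModule."""

    current_epoch: int
    global_step: int
    global_rank: int
    local_rank: int
    world_size: int
    strategy: str  # e.g. "dp", "ddp", "deepspeed"

A frozen dataclass (or a similar immutable structure) would make the read-only contract explicit.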

cc @Borda @tchaton @justusschock @awaelchli

@edenlightning added the feature, help wanted, and design labels on May 2, 2021
@edenlightning added this to the v1.4 milestone on May 2, 2021
@edenlightning modified the milestones: v1.4, v1.5 on Jul 6, 2021

dhfromkorea commented Aug 3, 2021

Hello @edenlightning, @ananthsub, may I ask a question about this proposal and get your advice?

Context
My organization has a use case where our algorithm requires calling Trainer functions like fit/validate/predict inside the LightningModule. When we did so naively, it changed the trainer state, leading to a completely different execution path once the inner top-level API call completed.

More concretely, for algorithms that need to do inference during training (not validation, but to support training steps), I think PL doesn't have good support today.

I implemented a thin wrapper (subclass) over pl.Trainer that adds a new API called predict_in_fit, in which a fresh Trainer instance is created and a model and a datamodule are passed to it. While this works well for our use case, it left me with a nagging feeling that we are working around the original design of the Trainer: "The Trainer assumes these top-level APIs are not called from within another trainer context".
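For reference, a stripped-down sketch of that wrapper (the name predict_in_fit and the constructor arguments are our own, not part of PL):

import pytorch_lightning as pl


class PredictInFitTrainer(pl.Trainer):
    """Thin wrapper: run inference mid-fit without disturbing the outer Trainer's state."""

    def predict_in_fit(self, model, datamodule):
        # Create a fresh Trainer so the outer fit() keeps its own state;
        # which settings to carry over from `self` depends on the use case.
        inner_trainer = pl.Trainer(logger=False)
        return inner_trainer.predict(model=model, datamodule=datamodule)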

My wish
Thus, I am looking for a better solution that will hopefully be made available officially in the next PL version.

Now, here comes a question.

> Possible solution: for v1.4, think about introducing a TrainerContext / TrainerState object

Could you help me understand this idea more concretely, possibly with an example? I would like to understand whether the idea would support our use case.

A direction I had in mind, if I were part of the core PL dev team, is something like:

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def on_train_dataloader(self):
        the_singleton_trainer_ref = self.trainer  # or via some other way if we want to remove the ref in LM.

        with new_trainer_context(the_singleton_trainer_ref) as new_trainer:
            # Do whatever we want.
            # new_trainer would have a fresh state but inherit all other data members.
            new_trainer.predict(model=self, datamodule=...)

            # We need to make sure the work inside the context manager is side-effect free:
            # 1. The original trainer state should be preserved, exactly as it was, once the context is closed.
            # 2. No side effect on the model. For instance, teardown should not move a model from GPU to CPU
            #    if the user doesn't want it. This behavior is not configurable today.
            # 3. No side effect on the data module.

The context manager could be implemented along these lines:

from contextlib import contextmanager
from typing import Iterator

import pytorch_lightning as pl


@contextmanager
def new_trainer_context(trainer: pl.Trainer) -> Iterator[pl.Trainer]:
    # How to do this efficiently would require internal knowledge of Trainer,
    # hence an internal custom API rather than copy.copy or copy.deepcopy.
    new_trainer = trainer.copy_for_new_context()

    yield new_trainer

    # The outer and inner trainer states are expected to differ here,
    # e.g. top-level context: fit; inner context: predict.
    assert trainer.state != new_trainer.state

    # Clean up as needed.

I am not an expert in PL but I have been reading its source code for a while. I'd love to hear your expert advice and help brainstorm.

Thank you for taking the time to read this.

ananthsub commented

@dhfromkorea

My apologies for the slow reply.

By Trainer Context, I do not mean a Python context manager. Rather, I mean a read-only view of the Trainer state. This would allow the LightningModule and DataModule to have access to information in a controlled manner without exposing the full Trainer object here.
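To make that concrete, the module side could look roughly like this (trainer_context and its fields are only illustrative; nothing like this exists yet):

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def on_train_epoch_start(self):
        # Hypothetical read-only snapshot, instead of reaching into self.trainer.
        ctx = self.trainer_context
        if ctx.global_rank == 0:
            self.print(f"starting epoch {ctx.current_epoch} at global step {ctx.global_step}")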

> More concretely, for algorithms that need to do inference during training (not validation, but to support training steps), I think PL doesn't have good support today.

Fundamentally, I think this should be solved at the loop level, though I would like to better understand the constraints of your use case (e.g. why iterating over a dataloader inside the LightningModule is not feasible for you, as opposed to needing to call trainer.predict). This is a huge challenge for precisely the reasons you mentioned: it is very difficult to manage storing and resuming trainer states appropriately. I am also curious whether there is another way to split up your use case into a pipeline rather than doing everything at once.
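For example, depending on your constraints, manually iterating a dataloader in a hook (sketched below with assumed names; predict_dataloader here stands in for whatever loader your datamodule exposes) might cover the inference-during-training need without spinning up a second Trainer:

import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def on_train_epoch_start(self):
        # Iterate a dataloader directly instead of calling trainer.predict.
        loader = self.trainer.datamodule.predict_dataloader()
        self.eval()
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(self.device)  # assumes a plain tensor batch
                preds = self(batch)
                # ... use `preds` to drive the training-time logic ...
        self.train()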
