
Schedulers like get_linear_schedule_with_warmup need access to the length of the train dataset #1038

Closed
bilal2vec opened this issue Mar 4, 2020 · 21 comments
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on), won't fix (This will not be worked on)

@bilal2vec
Contributor

🐛 Bug

If you're using an LR scheduler that needs access to the number of batches in the train dataset, like @huggingface's get_linear_schedule_with_warmup, there's currently no way to access the dataset in configure_optimizers() because it looks like it is called before train_dataloader().

It would be nice to have some way to load the datasets before the optimizers and make the dataset available to other methods with something like self.train_dataset = train_dataset.

Code sample:

train_steps = int(len(train_dataset) / (batch_size * grad_steps) * epochs)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.1 * train_steps), num_training_steps=train_steps)
bilal2vec added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Mar 4, 2020
@festeh
Contributor

festeh commented Mar 4, 2020

Well, you can pass the train dataset or loader in the constructor so that it's available as a field. Any reason not to do so?
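For reference, a minimal sketch of that suggestion (illustrative names, not an official recipe): the train loader is handed to the module's constructor, so its length is already available when configure_optimizers() runs.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

class LitWithScheduler(pl.LightningModule):
    def __init__(self, model, train_loader: DataLoader, lr=3e-5, epochs=3, grad_steps=1):
        super().__init__()
        self.model = model
        self.train_loader = train_loader  # kept as a field, so its length is known here
        self.lr, self.epochs, self.grad_steps = lr, epochs, grad_steps

    def train_dataloader(self):
        return self.train_loader

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        train_steps = len(self.train_loader) // self.grad_steps * self.epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * train_steps),
            num_training_steps=train_steps,
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]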

@rmrao
Contributor

rmrao commented Mar 4, 2020

I would support this change - as far as I know it should just change the order in which the dataset / optimizer methods are called w/o impacting anything else.

Another note on these schedulers is that I believe you have to override optimizer_step to step the scheduler on each update, rather than per-epoch (which is the only thing lightning supports as of now).

@pableeto

pableeto commented Mar 6, 2020

@rmrao It seems the latest PL supports per-step LR schedulers, as in #941.

@rmrao
Contributor

rmrao commented Mar 6, 2020

Ah ok great!

@bilal2vec
Contributor Author

Fixed with #941, closing

@marrrcin

I think this issue should not be closed yet. From what I can see, the PR for #941 added support for granular LR stepping, but it does not cover usage like the get_linear_schedule_with_warmup mentioned in the first post of this issue, since creating such a schedule requires access to the number of epochs (or the total number of steps). Am I missing something?

Borda reopened this Mar 25, 2020
@Borda
Member

Borda commented Mar 25, 2020

@SkafteNicki pls ^^

@SkafteNicki
Member

Agree that PR #941 only covers the granular LR stepping.
Regarding the case with get_linear_schedule_with_warmup, I do not think lightning needs a specific feature to support this, since the user can already achieve this with a bit of code:

def configure_optimizers(self):
    optimizer = ...
    train_steps = len(self.train_dataloader()) * self.hparams.max_epochs
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * train_steps), num_training_steps=train_steps)
    return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

where the trainer is then initialized with pl.Trainer(max_epochs=self.hparams.max_epochs) (and probably early stopping disabled).
If we want this to be a fully supported feature, then we need to expose the trainer to the model i.e. make it an argument to configure_optimizers(self, trainer), which I do not think is a good idea.

@marrrcin

@SkafteNicki thanks for the response. That's almost exactly what I did:

    @lru_cache()
    def total_steps(self):
        return len(self.train_dataloader()) // self.hparams.accumulate_grad_batches * self.hparams.epochs

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.hparams.lr)
        lr_scheduler = get_linear_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=self.hparams.warmup_steps,
                    num_training_steps=self.total_steps(),
        )
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

If that's the "recommended" way of doing it then I'm fine with that :)

Borda closed this as completed Mar 26, 2020
@czheng94

Any chance we can revive this issue?
I think this feature still needs to be supported, especially when you are doing distributed_backend="ddp|ddp2|horovod".

In this case

total_steps = len(self.train_dataloader()) // self.hparams.accumulate_grad_batches * self.hparams.epochs // num_distributed_processes

And num_distributed_processes is usually not specified in the arguments when running on a SLURM cluster. In addition, when users choose a different distributed backend (e.g. ddp vs. horovod), the way to get num_distributed_processes also differs (or you can get it from the trainer).

I agree with @SkafteNicki that it's bad to pass the trainer into configure_optimizers(self, trainer).
What I'm imagining is that maybe we could provide a special LambdaLR scheduler class that is configured in the training loop, so that total_steps can be passed in as a parameter to the lambda.
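A rough sketch of what such a deferred scheduler might look like (purely hypothetical; no such class exists in Lightning): a thin wrapper that only builds the real torch.optim.lr_scheduler.LambdaLR once the training loop can supply total_steps.

from torch.optim.lr_scheduler import LambdaLR

class DeferredLambdaLR:
    # Hypothetical: holds a lambda factory until the training loop knows total_steps.
    def __init__(self, optimizer, lr_lambda_factory):
        self.optimizer = optimizer
        self.lr_lambda_factory = lr_lambda_factory  # maps total_steps -> lr_lambda

    def build(self, total_steps):
        # Would be called by the training loop once total_steps is known.
        return LambdaLR(self.optimizer, self.lr_lambda_factory(total_steps))

# Example factory: linear warmup over the first 10% of steps, then linear decay.
def linear_warmup_then_decay(total_steps):
    warmup = max(1, int(0.1 * total_steps))

    def lr_lambda(step):
        if step < warmup:
            return step / warmup
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup))

    return lr_lambda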

Borda reopened this May 13, 2020
@SkafteNicki
Member

I am not sure how deeply this should be integrated into lightning; it is, after all, a feature for specific types of schedulers (those that rely on knowing the total number of steps) in the specific case where the user does not know in advance how many distributed processes they will be allocated. I will leave this up to the core team.

That said, the Callback system in lightning already allows for doing this.
In the configure_optimizers method we start out by defining only the variables we know at that point, using functools.partial:

from functools import partial
import transformers
from torch import optim

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = partial(transformers.get_cosine_schedule_with_warmup,
                        optimizer=optimizer, num_cycles=0.5)
    return [optimizer], [{'scheduler': scheduler}]

Then we define a callback that passes in the remaining variables, which are available through the trainer, in the on_train_start method (so before any training actually starts):

class TransformerLrScheduler(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        for lr_scheduler in trainer.lr_schedulers:
            if hasattr(lr_scheduler['scheduler'], '__call__'):
                scheduler = lr_scheduler['scheduler']
                  
                n_train = len(pl_module.train_dataloader())
                n_accumulate_grad = trainer.accumulate_grad_batches
                n_max_epochs = trainer.max_epochs
                n_devices = trainer.num_gpus # or trainer.tpu_cores if tpu or 1 if cpu
                    
                num_training_steps = n_train // n_accumulate_grad * n_max_epochs // n_devices
                num_warmup_steps = int(0.1*num_training_steps)
                    
                # Here we actually define the lr scheduler
                scheduler = scheduler(num_warmup_steps=num_warmup_steps, 
                                      num_training_steps=num_training_steps)
    
                lr_scheduler['scheduler'] = scheduler

We of course then initialize the trainer with pl.Trainer(callbacks=[TransformerLrScheduler()]).

@edenlightning
Contributor

@williamFalcon thoughts?

edenlightning added the feature (Is an improvement or enhancement) label and removed the bug (Something isn't working) label on Jul 27, 2020
@SokolovYaroslav

I think we could easily fix the first mentioned issue if configure_optimizers() were simply called after train_dataloader(). In that case we could save the length of the dataloader into an attribute and use it in the configure_optimizers() method.
I don't like the previously mentioned solution with an explicit call to len(self.train_dataloader()), because I have to call prepare_data() first and construct the train dataloader one more unnecessary time.
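For illustration, a sketch of that idea inside a LightningModule, assuming the ordering were changed so that train_dataloader() runs before configure_optimizers() (which is exactly what is being requested here; it does not work with the current call order):

from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

# Both methods live on the LightningModule.
def train_dataloader(self):
    loader = DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True)
    self._num_train_batches = len(loader)  # cache the length for configure_optimizers()
    return loader

def configure_optimizers(self):
    optimizer = AdamW(self.parameters(), lr=self.hparams.lr)
    total_steps = (self._num_train_batches
                   // self.hparams.accumulate_grad_batches
                   * self.hparams.max_epochs)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]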

@stale

stale bot commented Oct 21, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale bot added the won't fix (This will not be worked on) label on Oct 21, 2020
stale bot closed this as completed Oct 28, 2020
@congchan

What is the progress now?
I am considering either the huggingface trainer or the lightning trainer, and I am a heavy user of the warmup functionality.

Putting the schedule, optimizer and model in Trainer.fit seems to be a more reasonable choice.

@shimengfeng

@congchan I am not sure whether what I did is best practice. I created a DataModule class and get the train dataloader size after loading the data. I pass the value into my LitModel class as the init param train_dataloader_size, then do the calculation within the model class and pass the result to my scheduler.

However, I am not sure how dp/ddp/horovod will change the calculation.
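A sketch of that DataModule pattern, under the assumption that the dataset length is known up front (names are illustrative, and the dp/ddp/horovod division is deliberately not handled here):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class MyDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, batch_size=32):
        super().__init__()
        self.train_dataset = train_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    @property
    def num_train_batches(self):
        # ceil division: batches per epoch, before gradient accumulation
        return (len(self.train_dataset) + self.batch_size - 1) // self.batch_size

# Usage: pass dm.num_train_batches into the LitModel as train_dataloader_size,
# then call trainer.fit(model, datamodule=dm).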

@sloth2012

The callback approach suggested by @SkafteNicki above doesn't work, since the scheduler returned from configure_optimizers is validated in https://github.com/PyTorchLightning/pytorch-lightning/blob/b1e3dcc607522b06e88d0cb086ab655b49c88b35/pytorch_lightning/trainer/optimizers.py#L178

@dsuthar-nvidia

Can this be re-opened? Why not do what @SokolovYaroslav is suggesting? If configure_optimizers() could be called after train_dataloader(), users could simply save the length in a variable.

@sai-prasanna

@Borda huggingface's trainer passes the total number of training steps when configuring the learning rate scheduler: https://github.com/huggingface/transformers/blob/b440b8d1ce404abc1fd953ec2e32e1817a4a4a77/src/transformers/trainer.py#L771 But this works only for dataloaders that are sized.

I think calling configure_optimizers() after the training loader is created, as @SokolovYaroslav suggested, makes sense for the Lightning trainer.
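As a small illustration of the sized-dataloader caveat, a helper along these lines (not an existing API) would have to return None when the loader has no length, e.g. when it wraps an IterableDataset:

def total_training_steps(dataloader, max_epochs, accumulate_grad_batches=1, num_processes=1):
    # Total number of optimizer steps, or None if the dataloader is not sized.
    try:
        batches_per_epoch = len(dataloader)
    except TypeError:
        # DataLoaders backed by an IterableDataset may not define __len__
        return None
    return batches_per_epoch // accumulate_grad_batches * max_epochs // num_processes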

@Yevgnen

Yevgnen commented Dec 2, 2021

Any update?

@igor17400

Any update on this?
