ModelCheckpoint save_function() not set? #4079

Closed
celsofranssa opened this issue Oct 11, 2020 · 2 comments
Labels
bug Something isn't working help wanted Open to be worked on

Comments

@celsofranssa

I am training a PL model using the following code snippet:

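    from pytorch_lightning import Trainer
    from pytorch_lightning import loggers as pl_loggers
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    # logger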
    tb_logger = pl_loggers.TensorBoardLogger(cfg.logs.path, name='rnn_exp')

    # checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        filepath=cfg.checkpoint.path + "encoder_rnn{epoch:02d}",
        save_top_k=1,
        mode="min",  # monitor is defined in val_step: EvalResult(checkpoint_on=val_loss)
    )

    # early stopping callback
    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        patience=cfg.val.patience,
        mode="min"
    )

    tokenizer = ...
    dm = MyDataModule(cfg, tokenizer)

    model = RNNEncoder(cfg)

    trainer = Trainer(
        fast_dev_run=False,
        max_epochs=cfg.train.max_epochs,
        gpus=1,
        logger=tb_logger,
        callbacks=[checkpoint_callback, early_stopping_callback]
    )

    # training
    dm.setup('fit')
    trainer.fit(model, datamodule=dm)

However, after the first epoch, training fails with the following error, which appears to be raised by the ModelCheckpoint callback:

    trainer.fit(model, datamodule=dm)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1239, in run_pretrain_routine
    self.train()
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 394, in train
    self.run_training_epoch()
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 516, in run_training_epoch
    self.run_evaluation(test_mode=False)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 603, in run_evaluation
    self.on_validation_end()
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 176, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py", line 27, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 380, in on_validation_end
    self._do_check_save(filepath, current, epoch, trainer, pl_module)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 421, in _do_check_save
    self._save_model(filepath, trainer, pl_module)
  File "/home/celso/projects/venvs/semantic_code_search/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 212, in _save_model
    raise ValueError(".save_function() not set")
ValueError: .save_function() not set

Could you tell me if I forgot to configure something?

celsofranssa added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Oct 11, 2020
@awaelchli
Contributor

awaelchli commented Oct 11, 2020

Currently, you need to pass the ModelCheckpoint via Trainer(checkpoint_callback=...).
#3990 will enable passing it through callbacks=[...] instead.
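The following is a toy model, not Lightning's actual code, of why the placement matters. It assumes, as the traceback and the reply above suggest, that in this Lightning version the trainer attaches its save function only to the callback given as `checkpoint_callback`, so a ModelCheckpoint placed in `callbacks` reaches `_save_model` with no save function. `ToyCheckpoint` and `ToyTrainer` are hypothetical names used only for illustration.

```python
from typing import Callable, List, Optional


class ToyCheckpoint:
    """Hypothetical stand-in for ModelCheckpoint: it can only save once a
    trainer has handed it a save function."""

    def __init__(self) -> None:
        self.save_function: Optional[Callable[[str], None]] = None

    def save(self, filepath: str) -> None:
        if self.save_function is None:
            # Same error text as in the traceback above.
            raise ValueError(".save_function() not set")
        self.save_function(filepath)


class ToyTrainer:
    """Hypothetical trainer that wires up only the callback passed as
    `checkpoint_callback`, not anything in the generic `callbacks` list."""

    def __init__(self, checkpoint_callback: Optional[ToyCheckpoint] = None,
                 callbacks: Optional[List[object]] = None) -> None:
        self.saved: List[str] = []
        self.callbacks = list(callbacks or [])
        if checkpoint_callback is not None:
            # Only this path assigns the save function.
            checkpoint_callback.save_function = self.save_checkpoint
            self.callbacks.append(checkpoint_callback)

    def save_checkpoint(self, filepath: str) -> None:
        self.saved.append(filepath)

    def on_validation_end(self) -> None:
        # Every checkpoint callback tries to save, wired up or not.
        for cb in self.callbacks:
            if isinstance(cb, ToyCheckpoint):
                cb.save("epoch=00.ckpt")


# Checkpoint passed in `callbacks`: the save function is never set.
broken = ToyTrainer(callbacks=[ToyCheckpoint()])
try:
    broken.on_validation_end()
except ValueError as err:
    print(err)  # .save_function() not set

# Checkpoint passed as `checkpoint_callback`: the trainer wires it up first.
fixed = ToyTrainer(checkpoint_callback=ToyCheckpoint())
fixed.on_validation_end()
print(fixed.saved)  # ['epoch=00.ckpt']
```

Applied to the original snippet, the workaround would presumably mean constructing the trainer with checkpoint_callback=checkpoint_callback and moving the ModelCheckpoint out of the list, leaving callbacks=[early_stopping_callback].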

@celsofranssa
Author

Thanks, @awaelchli.

I had just suspected as much. Thanks a lot for the help!
