drop unused variable in API #6308

Merged (8 commits) on Mar 4, 2021
13 changes: 6 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -239,7 +239,7 @@ def save_checkpoint(self, trainer, pl_module):
self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)

# Mode 2: save the last checkpoint
- self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
+ self._save_last_checkpoint(trainer, monitor_candidates)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -291,8 +291,7 @@ def _del_model(self, filepath: str):
self._fs.rm(filepath)
log.debug(f"Removed checkpoint: {filepath}")

- def _save_model(self, filepath: str, trainer, pl_module):
-     # Todo: required argument `pl_module` is not used
+ def _save_model(self, filepath: str, trainer):
# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

@@ -481,7 +480,7 @@ def _monitor_candidates(self, trainer):
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
return monitor_candidates

- def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+ def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -505,9 +504,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):

if trainer.training_type_plugin.rpc_enabled:
# RPCPlugin manages saving all model states
- trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+ trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
else:
- self._save_model(last_filepath, trainer, pl_module)
+ self._save_model(last_filepath, trainer)
if (
self.last_model_path and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
@@ -574,7 +573,7 @@ def _update_best_and_save(
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
)
- self._save_model(filepath, trainer, pl_module)
+ self._save_model(filepath, trainer)

if del_filepath is not None and filepath != del_filepath:
self._del_model(del_filepath)
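The hunks above only touch internal helpers (`_save_model`, `_save_last_checkpoint`), so user-facing checkpointing is unchanged. For orientation, a minimal usage sketch, assuming the PyTorch Lightning API around the 1.2 line; the monitor key and the commented-out fit() call are illustrative:

```python
# Minimal usage sketch (assumes PyTorch Lightning ~1.2; "val_loss" and fit() are illustrative).
# save_last=True enables the "Mode 2" path above (_save_last_checkpoint also runs when no
# monitor is set); monitor/save_top_k drive the top-k bookkeeping in _save_top_k_checkpoints.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",  # metric tracked for the top-k checkpoints
    save_top_k=3,
    save_last=True,      # also keep a "last" checkpoint via _save_last_checkpoint
)
trainer = Trainer(callbacks=[checkpoint_cb])
# trainer.fit(model, datamodule=dm)  # model / dm defined elsewhere
```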
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/rpc.py
@@ -63,7 +63,7 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
rpc._set_rpc_timeout(self.rpc_timeout_sec)
self._is_rpc_initialized = True

- def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
"""
Override to save model to disk.
This is required as the main process will be required to handle aggregating model states from RPC processes.
@@ -72,7 +72,6 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
save_model_fn: The saving function to save final model.
last_filepath: The filepath to save the model to.
trainer: The trainer object.
- pl_module: The LightningModule.
"""
raise NotImplementedError

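For plugin authors, the override now takes three arguments. A sketch of a custom subclass under the new signature, modeled on the test plugin further down in this PR; the class name is made up, and the import path is simply the module edited above:

```python
# Sketch of a custom override with the new 3-argument signature (LoggingRPCPlugin is hypothetical).
# The LightningModule is no longer passed in; it stays reachable as trainer.lightning_module
# or self.lightning_module if the override needs it.
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin


class LoggingRPCPlugin(RPCPlugin):

    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
        # aggregate worker state here if needed, then delegate to the saving callable
        save_model_fn(last_filepath, trainer)
```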
10 changes: 5 additions & 5 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -266,17 +266,17 @@ def configure_ddp(self):
self._model.require_backward_grad_sync = False

@rank_zero_only
- def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
model = self.lightning_module
if not hasattr(model.sequential_module, "foreach_worker"):
return
- current_layers = pl_module.sequential_module
+ current_layers = model.sequential_module
model.sequential_module.foreach_worker(
save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
)
- pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
- save_model_fn(last_filepath, trainer, pl_module)
- pl_module.sequential_module = current_layers
+ model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
+ save_model_fn(last_filepath, trainer)
+ model.sequential_module = current_layers
Comment on lines +277 to +279

@Borda (Member, Author), Mar 2, 2021:
Not sure how much this makes sense, since the model is not used...
cc: @tchaton

Reply (Contributor):
Yes, it is used in save_model_fn, as the trainer will access lightning_module. So we only re-create the full model temporarily. But your changes should be fine.

def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
model.sequential_module.foreach_worker(
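The reviewer's point above, as a self-contained toy (all classes and names below are stand-ins, not Lightning APIs): the module held by the plugin and the one reachable through the trainer are the same object, so the temporary `sequential_module` swap is visible to `save_model_fn` even without an explicit `pl_module` argument.

```python
# Toy demonstration only; ToyModule/ToyTrainer are stand-ins, not Lightning classes.
class ToyModule:
    def __init__(self):
        self.sequential_module = "sharded layers"


class ToyTrainer:
    def __init__(self, module):
        self.lightning_module = module


def save_model_fn(filepath, trainer):
    # the save callable only needs the trainer; it finds the module itself
    print(f"saving {trainer.lightning_module.sequential_module!r} to {filepath}")


module = ToyModule()
trainer = ToyTrainer(module)
assert trainer.lightning_module is module  # same object, two access paths

current_layers = module.sequential_module
module.sequential_module = "temporarily re-created full model"  # as in rpc_save_model
save_model_fn("last.ckpt", trainer)                             # sees the re-created model
module.sequential_module = current_layers                       # restore sharded layers
```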
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -81,7 +81,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):

# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
- self.attach_datamodule(model, datamodule, 'fit')
+ self.attach_datamodule(model, datamodule)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -112,8 +112,7 @@ def attach_dataloaders(
if predict_dataloaders is not None:
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

- def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None:
-     # Todo: required argument `stage` is not used
+ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:

# We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)
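The datamodule resolution that `attach_datamodule` keeps, independent of the dropped `stage` argument, is a simple precedence rule. A toy sketch, with stand-in names, of the `datamodule or getattr(model, 'datamodule', None)` line shown above:

```python
# Toy sketch of the precedence in attach_datamodule (resolve_datamodule/ToyModel are stand-ins).
def resolve_datamodule(model, datamodule=None):
    # an explicitly passed datamodule wins; otherwise fall back to one attached to the model
    return datamodule or getattr(model, "datamodule", None)


class ToyModel:
    datamodule = "dm attached to the model"


assert resolve_datamodule(ToyModel(), "dm passed to .fit") == "dm passed to .fit"
assert resolve_datamodule(ToyModel()) == "dm attached to the model"
assert resolve_datamodule(object()) is None  # neither provided
```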
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -101,8 +101,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
current_hook_fx_name=hook_fx_name, on_step=on_step, on_epoch=on_epoch
)

- def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
-     # Todo: required argument `testing` is not used
+ def on_evaluation_batch_start(self, batch, dataloader_idx, num_dataloaders):
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
@@ -260,8 +259,7 @@ def track_metrics_deprecated(self, deprecated_eval_results):
self._track_callback_metrics(deprecated_eval_results)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)

- def evaluation_epoch_end(self, testing):
-     # Todo: required argument `testing` is not used
+ def evaluation_epoch_end(self):
# reset dataloader idx
model_ref = self.trainer.lightning_module
model_ref._current_dataloader_idx = None
6 changes: 1 addition & 5 deletions pytorch_lightning/trainer/connectors/slurm_connector.py
@@ -28,8 +28,6 @@ def register_slurm_signal_handlers(self):
signal.signal(signal.SIGTERM, self.term_handler)

def sig_handler(self, signum, frame): # pragma: no-cover
- # Todo: required argument `signum` is not used
- # Todo: required argument `frame` is not used
if self.trainer.is_global_zero:
# save weights
log.info('handling SIGUSR1')
@@ -59,7 +57,5 @@ def sig_handler(self, signum, frame): # pragma: no-cover
# close experiment to avoid issues
self.trainer.logger.close()

- def term_handler(self, signum, frame):
-     # Todo: required argument `signum` is not used
-     # Todo: required argument `frame` is not used
+ def term_handler(self, signum, frame): # pragma: no-cover
log.info("bypassing sigterm")
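Note that `signum` and `frame` stay in both handler signatures even though the bodies ignore them: Python's `signal` module always invokes handlers as `handler(signum, frame)`, which is why this PR drops only the Todo comments and keeps the parameters. A standalone sketch of the registration pattern used by `register_slurm_signal_handlers`:

```python
# Standalone sketch: the two-argument handler signature is mandated by the signal protocol,
# so the parameters must stay even when the body ignores them.
import signal


def term_handler(signum, frame):  # pragma: no-cover
    print("bypassing sigterm")


signal.signal(signal.SIGTERM, term_handler)
```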
6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -181,7 +181,7 @@ def evaluation_step_end(self, *args, **kwargs):

def evaluation_epoch_end(self):
# unset dataloder_idx in model
- self.trainer.logger_connector.evaluation_epoch_end(self.trainer.testing)
+ self.trainer.logger_connector.evaluation_epoch_end()

# call the model epoch end
deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders)
@@ -283,9 +283,7 @@ def _convert_to_numpy(v):

def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
# set dataloader_idx to model and track batch_size
- self.trainer.logger_connector.on_evaluation_batch_start(
-     self.trainer.testing, batch, dataloader_idx, self.num_dataloaders
- )
+ self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self.num_dataloaders)

if self.trainer.testing:
self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -880,7 +880,7 @@ def test(
)

# Attach datamodule to get setup/prepare_data added to model before the call to it below
- self.data_connector.attach_datamodule(model or self.lightning_module, datamodule, 'test')
+ self.data_connector.attach_datamodule(model or self.lightning_module, datamodule)

if model is not None:
results = self.__test_given_model(model, test_dataloaders)
@@ -989,7 +989,7 @@ def predict(

if datamodule is not None:
# Attach datamodule to get setup/prepare_data added to model before the call to it below
- self.data_connector.attach_datamodule(model, datamodule, 'predict')
+ self.data_connector.attach_datamodule(model, datamodule)

# attach data
if dataloaders is not None:
2 changes: 1 addition & 1 deletion tests/plugins/test_rpc_plugin.py
@@ -58,7 +58,7 @@ def __init__(self, **kwargs):
self.rpc_save_model_count = 0
self.worker_optimizer_step_count = 0

- def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
self.rpc_save_model_count += 1

def barrier(self, name: Optional[str] = None) -> None: