Trainer only references accelerator (#6039)
* Trainer only references accelerator where it can

* Move teardown to the trainer, as it is responsible for the accelerator
SeanNaren authored Feb 17, 2021
1 parent 7189d67 commit b7c2e0a
Showing 4 changed files with 46 additions and 33 deletions.
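
Before the per-file diffs: the refactor follows a simple delegation pattern. The Accelerator gains thin wrappers (start_training, start_testing, start_predicting, pre_dispatch, post_dispatch) that forward to its training-type and precision plugins, so the Trainer only ever holds a reference to the accelerator. A minimal sketch of that pattern, with hypothetical Stub* classes standing in for Lightning's real plugins:

class StubTrainingTypePlugin:
    """Simplified stand-in for a TrainingTypePlugin."""

    def pre_dispatch(self) -> None:
        print("training-type plugin: pre_dispatch")

    def post_dispatch(self) -> None:
        print("training-type plugin: post_dispatch")

    def start_training(self, trainer: object) -> None:
        print("training-type plugin: start_training")


class StubPrecisionPlugin:
    """Simplified stand-in for a PrecisionPlugin."""

    def pre_dispatch(self) -> None:
        print("precision plugin: pre_dispatch")

    def post_dispatch(self) -> None:
        print("precision plugin: post_dispatch")


class Accelerator:
    """Thin facade: trainer code talks to this object, never to the plugins directly."""

    def __init__(self, training_type_plugin, precision_plugin) -> None:
        self.training_type_plugin = training_type_plugin
        self.precision_plugin = precision_plugin

    def start_training(self, trainer: object) -> None:
        self.training_type_plugin.start_training(trainer)

    def pre_dispatch(self) -> None:
        self.training_type_plugin.pre_dispatch()
        self.precision_plugin.pre_dispatch()

    def post_dispatch(self) -> None:
        self.training_type_plugin.post_dispatch()
        self.precision_plugin.post_dispatch()


accelerator = Accelerator(StubTrainingTypePlugin(), StubPrecisionPlugin())
accelerator.pre_dispatch()             # trainer-side code calls only the accelerator
accelerator.start_training(trainer=None)
accelerator.post_dispatch()
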
49 changes: 32 additions & 17 deletions pytorch_lightning/accelerators/accelerator.py
@@ -76,6 +76,25 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)

def start_training(self, trainer: 'Trainer'):
self.training_type_plugin.start_training(trainer)

def start_testing(self, trainer: 'Trainer'):
self.training_type_plugin.start_testing(trainer)

def start_predicting(self, trainer: 'Trainer'):
self.training_type_plugin.start_predicting(trainer)

def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.pre_dispatch()
self.precision_plugin.pre_dispatch()

def post_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()

@property
def model(self) -> torch.nn.Module:
"""Returns the model. This can also be a wrapped LightningModule.
@@ -224,23 +243,6 @@ def validation_step_end(self, output):
"""
return self.training_type_plugin.validation_step_end(output)

def predict(self, args):
"""The prediction step.
Args:
args: the arguments for the models predict step. Can consist of the following:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (int): Integer displaying index of this batch
optimizer_idx (int): When using multiple optimizers, this argument will also be present.
hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
"""
batch = self.to_device(args[0])
args[0] = batch
return self.training_type_plugin.predict(*args)

def backward(
self,
closure_loss: torch.Tensor,
@@ -380,6 +382,10 @@ def on_save(self, checkpoint):
def barrier(self, name: Optional[str] = None) -> None:
self.training_type_plugin.barrier(name=name)

def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes"""
return self.training_type_plugin.broadcast(obj, src)

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Function to gather a tensor from several distributed processes
@@ -399,3 +405,12 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
"""
return self.training_type_plugin.process_dataloader(dataloader)

@property
def results(self) -> Any:
"""
The results of the last training/testing run will be cached here.
In distributed training, we make sure to transfer the results to the appropriate master process.
"""
# TODO: improve these docs
return self.training_type_plugin.results
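
The broadcast, barrier, and results pass-throughs added above give trainer code a single accelerator-facing surface for the calls it previously made on the training-type plugin directly (see the data_loading.py and properties.py hunks below). A rough sketch under the assumption of a single-process plugin; the StubSingleProcessPlugin and AcceleratorFacade names are hypothetical, and real distributed plugins would back these methods with torch.distributed:

from typing import Any, Optional


class StubSingleProcessPlugin:
    """Hypothetical single-process stand-in for a TrainingTypePlugin."""

    def __init__(self) -> None:
        self.results: Any = None  # cached output of the last train/test run

    def barrier(self, name: Optional[str] = None) -> None:
        pass  # nothing to synchronize with only one process

    def broadcast(self, obj: object, src: int = 0) -> object:
        return obj  # with a single process, rank 0's object is returned unchanged


class AcceleratorFacade:
    """Pass-through layer mirroring the methods added in the diff above."""

    def __init__(self, training_type_plugin: StubSingleProcessPlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def barrier(self, name: Optional[str] = None) -> None:
        self.training_type_plugin.barrier(name=name)

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.training_type_plugin.broadcast(obj, src)

    @property
    def results(self) -> Any:
        return self.training_type_plugin.results


# Trainer-side call sites then look like the hunks below:
accelerator_backend = AcceleratorFacade(StubSingleProcessPlugin())
accelerator_backend.barrier("get_dataloaders")
dirpath = accelerator_backend.broadcast("lightning_logs/version_0")
print(dirpath, accelerator_backend.results)
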
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -399,7 +399,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
dataloader = self._flatten_dl_only(dataloader)

if self.accelerator_backend is not None:
self.training_type_plugin.barrier('get_dataloaders')
self.accelerator_backend.barrier('get_dataloaders')
return dataloader

def _flatten_dl_only(self, dataloaders):
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/properties.py
@@ -21,14 +21,14 @@
from torch.optim import Optimizer

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
@@ -138,7 +138,7 @@ def log_dir(self) -> Optional[str]:
else:
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

dirpath = self.training_type_plugin.broadcast(dirpath)
dirpath = self.accelerator_backend.broadcast(dirpath)
return dirpath

@property
@@ -365,7 +365,7 @@ def lightning_optimizers(self) -> List[LightningOptimizer]:

@property
def lightning_module(self) -> LightningModule:
return self.training_type_plugin.lightning_module
return self.accelerator_backend.lightning_module

@property
def optimizers(self) -> Optional[List[Optimizer]]:
22 changes: 10 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -22,7 +22,6 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
@@ -33,6 +32,7 @@
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
@@ -484,7 +484,7 @@ def fit(
# trainer.dispatch || LIGHTNING
# | ||
# start_training or start_testing or start_predicting call || FLOW
# from `accelerator.training_type_plugin` ||
# from `accelerator` ||
# | || DIRECTION
# run_train or run_test or run_predict call ||
# from `trainer` ||
@@ -532,26 +532,24 @@

self._set_running_stage(None, model)

return self.training_type_plugin.results or 1
return self.accelerator_backend.results or 1

def pre_dispatch(self):
self.training_type_plugin.pre_dispatch()
self.precision_plugin.pre_dispatch()
self.accelerator_backend.pre_dispatch()

def post_dispatch(self):
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()
self.accelerator_backend.post_dispatch()
self.accelerator_backend.teardown()

def dispatch(self):
if self.testing:
self.training_type_plugin.start_testing(self)
self.accelerator_backend.start_testing(self)

elif self.predicting:
self.training_type_plugin.start_predicting(self)
self.accelerator_backend.start_predicting(self)

else:
self.training_type_plugin.start_training(self)
self.accelerator_backend.start_training(self)

def train_or_test_or_predict(self):
if self.testing:
@@ -575,7 +573,7 @@ def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):

def _pre_training_routine(self):
# wait for all to join if on distributed
self.accelerator.training_type_plugin.barrier("setup_training")
self.accelerator.barrier("setup_training")

# register auto-resubmit when on SLURM
self.slurm_connector.register_slurm_signal_handlers()
@@ -948,7 +946,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
)
return {}
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
self.accelerator_backend.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
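
Tying the trainer.py hunks back to the flow-direction comment in fit(): the trainer decides what to run, hands control to the accelerator (which starts its training-type plugin), and on the way out calls post_dispatch() followed by teardown(), since teardown is now the trainer's responsibility. A compressed sketch of that call order; _StubAccelerator and MiniTrainer are hypothetical tracing stubs, not the real classes:

class _StubAccelerator:
    """Hypothetical stand-in used only to trace the call order."""

    def pre_dispatch(self) -> None:
        print("accelerator.pre_dispatch")

    def start_training(self, trainer: "MiniTrainer") -> None:
        print("accelerator.start_training -> training_type_plugin.start_training")

    def start_testing(self, trainer: "MiniTrainer") -> None:
        print("accelerator.start_testing -> training_type_plugin.start_testing")

    def post_dispatch(self) -> None:
        print("accelerator.post_dispatch")

    def teardown(self) -> None:
        print("accelerator.teardown (now invoked by the trainer)")


class MiniTrainer:
    def __init__(self, testing: bool = False) -> None:
        self.testing = testing
        self.accelerator_backend = _StubAccelerator()

    def dispatch(self) -> None:
        # Mirrors Trainer.dispatch() after this commit: only the accelerator is referenced.
        if self.testing:
            self.accelerator_backend.start_testing(self)
        else:
            self.accelerator_backend.start_training(self)

    def fit(self) -> None:
        self.accelerator_backend.pre_dispatch()
        self.dispatch()
        self.accelerator_backend.post_dispatch()
        self.accelerator_backend.teardown()


MiniTrainer().fit()
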
