Refactor 2/n (#2708)
* refactor into gpu accelerator

* refactor into gpu accelerator

* refactor into gpu accelerator

* refactor into gpu accelerator

* refactor into gpu accelerator

* refactor into gpu accelerator

* refactor into gpu accelerator
williamFalcon committed Jul 25, 2020
1 parent e9ed9b7 commit b34217e
Showing 4 changed files with 140 additions and 64 deletions.
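For orientation before the diff: this step moves the TPU-specific logic out of Trainer.fit() and into a dedicated TPUAccelerator, mirroring the GPUAccelerator introduced in the previous refactor step. A minimal sketch of the resulting delegation pattern (the calls match the trainer.py hunk below; `trainer` and `model` stand in for the instances inside fit()):

# Orientation sketch only, not part of the commit.
if trainer.use_tpu:
    trainer.accelerator = TPUAccelerator(trainer)
    trainer.accelerator.setup()       # verify torch_xla is importable, pick 'fork' vs 'spawn'
    trainer.accelerator.train(model)  # single-core call or xmp.spawn across tpu_cores
    trainer.accelerator.teardown()    # reload spawned weights in the main process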
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -1 +1,2 @@
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
133 changes: 133 additions & 0 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -0,0 +1,133 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import _logger as log


try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


class TPUAccelerator(object):

    def __init__(self, trainer):
        self.trainer = trainer
        self.start_method = None

    def setup(self):
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        if not XLA_AVAILABLE:
            raise MisconfigurationException('No TPU devices found.')

        # COLAB_GPU is an env var available by default in Colab environments.
        self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'

    def teardown(self):

        # when training completes, load the weights back in main process
        self.__load_weights_on_main_process()

    def train(self, model):
        self.trainer.model = model

        # train
        if self.trainer.tpu_id is not None:
            self.tpu_train_in_process(self.trainer.tpu_id, model)
        else:
            xmp.spawn(
                self.tpu_train_in_process,
                args=(model,),
                nprocs=self.trainer.tpu_cores,
                start_method=self.start_method
            )

    def __load_weights_on_main_process(self):
        model = self.trainer.model

        # load weights if not interrupted
        if self.trainer.on_colab_kaggle and not self.trainer.testing:
            self.trainer.load_spawn_weights(model)

        self.trainer.model = model

    def tpu_train_in_process(self, tpu_core_idx, model):
        """
        Here we are inside each individual process
        """
        if not self.trainer.testing:
            self.trainer.setup('fit')
            model.setup('fit')

        # setup TPU training
        self.__setup_tpu_training(model)

        # Run the pretrain routine
        self.trainer.run_pretrain_routine(model)

        # save weights at the end of training
        self.__save_end_of_training_weights(model)

    def __save_end_of_training_weights(self, model):

        # when training ends on these platforms dump weights to get out of the main process
        if self.trainer.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            self.trainer.save_spawn_weights(model)

    def __setup_tpu_training(self, model):
        # use the default device from the process
        tpu_device = xm.xla_device()

        # if given an ordinal device, use this as the device
        if self.trainer.tpu_id is not None:
            tpu_device = xm.xla_device(self.trainer.tpu_id)

        # track the device and move model to it
        self.trainer._device = tpu_device
        model.to(self.trainer._device)

        # get the appropriate tpu ranks
        self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
        self.trainer.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        self.trainer.global_rank = self.trainer.tpu_local_core_rank
        rank_zero_only.rank = self.trainer.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # init 16 bit for TPU
        if self.trainer.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
                 f' global rank: {self.trainer.tpu_global_core_rank}')
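A note on the spawn path above (not part of the commit): xmp.spawn calls its target once per TPU process and prepends the process index to the supplied args, which is why train() passes args=(model,) while tpu_train_in_process receives (tpu_core_idx, model). A minimal sketch of that calling convention, using a hypothetical worker name:

# Sketch of the xmp.spawn convention assumed above; `_worker` is a hypothetical
# stand-in for TPUAccelerator.tpu_train_in_process.
import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(tpu_core_idx, model):
    # xmp.spawn supplies tpu_core_idx; `model` comes from args=(model,)
    print(f'core {tpu_core_idx} received model {type(model).__name__}')

# Spawned as: xmp.spawn(_worker, args=(model,), nprocs=8, start_method='spawn')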
40 changes: 0 additions & 40 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -179,46 +179,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)

    def tpu_train(self, tpu_core_idx, model):
        # call setup after the ddp process has connected
        if not self.testing:
            self.setup('fit')
            model.setup('fit')

        # put model on tpu
        self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
        model.to(self._device)

        # get the appropriate tpu ranks
        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        self.global_rank = self.tpu_local_core_rank
        rank_zero_only.rank = self.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # init 16 bit for TPU
        if self.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {self.tpu_local_core_rank},'
                 f' global rank: {self.tpu_global_core_rank}')

        # continue training routine
        self.run_pretrain_routine(model)

        # when training ends on these platforms dump weights to get out of the main process
        if self.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            self.save_spawn_weights(model)

    def dp_train(self, model):
        # call setup after the ddp process has connected
        if not self.testing:
30 changes: 6 additions & 24 deletions pytorch_lightning/trainer/trainer.py
@@ -51,7 +51,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -1077,29 +1077,11 @@ def fit(
            self.accelerator.setup(model)
            results = self.run_pretrain_routine(model)

        elif self.use_tpu:  # pragma: no-cover
            rank_zero_info(f'training on {self.tpu_cores} TPU cores')

            if not XLA_AVAILABLE:
                raise MisconfigurationException('No TPU devices found.')

            # COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if self.on_colab_kaggle else 'spawn'

            # track for predict
            self.model = model

            # train
            if self.tpu_id is not None:
                self.tpu_train(self.tpu_id, model)
            else:
                xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)

            # load weights if not interrupted
            if self.on_colab_kaggle and not self.testing:
                self.load_spawn_weights(model)

            self.model = model
        elif self.use_tpu:
            self.accelerator = TPUAccelerator(self)
            self.accelerator.setup()
            self.accelerator.train(model)
            self.accelerator.teardown()

        # ON CPU
        else:
