Refactor 2/n #2708

Merged · 7 commits · Jul 25, 2020
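Summary: step 2 of the accelerator refactor. The TPU training path moves out of the Trainer mixins into a dedicated `TPUAccelerator` with a `setup`/`train`/`teardown` lifecycle, mirroring the `GPUAccelerator` already imported in `trainer.py`: `distrib_parts.py` loses `tpu_train`, `Trainer.fit` loses its inline TPU branch, and both resurface as methods on the new class.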
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -1 +1,2 @@
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
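With this re-export, call sites can import both accelerators from the package root, which the `trainer.py` hunk below relies on:

from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator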
133 changes: 133 additions & 0 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -0,0 +1,133 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import _logger as log


try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


class TPUAccelerator(object):

    def __init__(self, trainer):
        self.trainer = trainer
        self.start_method = None

    def setup(self):
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        if not XLA_AVAILABLE:
            raise MisconfigurationException('No TPU devices found.')

        # COLAB_GPU is an env var available by default in Colab environments.
        self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'

    def teardown(self):

        # when training completes, load the weights back in main process
        self.__load_weights_on_main_process()

    def train(self, model):
        self.trainer.model = model

        # train
        if self.trainer.tpu_id is not None:
            self.tpu_train_in_process(self.trainer.tpu_id, model)
        else:
            xmp.spawn(
                self.tpu_train_in_process,
                args=(model,),
                nprocs=self.trainer.tpu_cores,
                start_method=self.start_method
            )

    def __load_weights_on_main_process(self):
        model = self.trainer.model

        # load weights if not interrupted
        if self.trainer.on_colab_kaggle and not self.trainer.testing:
            self.trainer.load_spawn_weights(model)

        self.trainer.model = model

    def tpu_train_in_process(self, tpu_core_idx, model):
        """
        Here we are inside each individual process
        """
        if not self.trainer.testing:
            self.trainer.setup('fit')
            model.setup('fit')

        # setup TPU training
        self.__setup_tpu_training(model)

        # Run the pretrain routine
        self.trainer.run_pretrain_routine(model)

        # save weights at the end of training
        self.__save_end_of_training_weights(model)

    def __save_end_of_training_weights(self, model):

        # when training ends on these platforms dump weights to get out of the main process
        if self.trainer.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            self.trainer.save_spawn_weights(model)

    def __setup_tpu_training(self, model):
        # use the default device from the process
        tpu_device = xm.xla_device()

        # if given an ordinal device, use this as the device
        if self.trainer.tpu_id is not None:
            tpu_device = xm.xla_device(self.trainer.tpu_id)

        # track the device and move model to it
        self.trainer._device = tpu_device
        model.to(self.trainer._device)

        # get the appropriate tpu ranks
        self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
        self.trainer.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        self.trainer.global_rank = self.trainer.tpu_local_core_rank
        rank_zero_only.rank = self.trainer.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # init 16 bit for TPU
        if self.trainer.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
                 f' global rank: {self.trainer.tpu_global_core_rank}')
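For context on the `start_method` chosen in `setup` above: `xmp.spawn` launches one process per core and passes each process its index as the first argument, which is why `tpu_train_in_process` accepts `tpu_core_idx`. A minimal standalone sketch of the same launch pattern (illustrative, outside Lightning; assumes an 8-core TPU):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def train_fn(index):
    # each spawned process owns one core and resolves its own XLA device
    device = xm.xla_device()
    print(f'process {index} / global ordinal {xm.get_ordinal()} on {device}')


if __name__ == '__main__':
    # 'spawn' re-imports __main__, which notebook kernels handle poorly;
    # that is why the accelerator above picks 'fork' on Colab/Kaggle
    xmp.spawn(train_fn, args=(), nprocs=8, start_method='spawn')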
40 changes: 0 additions & 40 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -179,46 +179,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)

    def tpu_train(self, tpu_core_idx, model):
        # call setup after the ddp process has connected
        if not self.testing:
            self.setup('fit')
            model.setup('fit')

        # put model on tpu
        self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
        model.to(self._device)

        # get the appropriate tpu ranks
        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        self.global_rank = self.tpu_local_core_rank
        rank_zero_only.rank = self.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # init 16 bit for TPU
        if self.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {self.tpu_local_core_rank},'
                 f' global rank: {self.tpu_global_core_rank}')

        # continue training routine
        self.run_pretrain_routine(model)

        # when training ends on these platforms dump weights to get out of the main process
        if self.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            self.save_spawn_weights(model)

    def dp_train(self, model):
        # call setup after the ddp process has connected
        if not self.testing:
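Note on the removal above: `tpu_train` is not dropped, it is redistributed. Its device, rank, optimizer, and bf16 setup becomes `TPUAccelerator.__setup_tpu_training`, the `setup('fit')` and `run_pretrain_routine` calls move into `tpu_train_in_process`, and the Colab/Kaggle weight dump becomes `__save_end_of_training_weights`, with `self.*` references rewritten to `self.trainer.*`.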
30 changes: 6 additions & 24 deletions pytorch_lightning/trainer/trainer.py
@@ -51,7 +51,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -1077,29 +1077,11 @@ def fit(
            self.accelerator.setup(model)
            results = self.run_pretrain_routine(model)

        elif self.use_tpu:  # pragma: no-cover
            rank_zero_info(f'training on {self.tpu_cores} TPU cores')

            if not XLA_AVAILABLE:
                raise MisconfigurationException('No TPU devices found.')

            # COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if self.on_colab_kaggle else 'spawn'

            # track for predict
            self.model = model

            # train
            if self.tpu_id is not None:
                self.tpu_train(self.tpu_id, model)
            else:
                xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)

            # load weights if not interrupted
            if self.on_colab_kaggle and not self.testing:
                self.load_spawn_weights(model)

            self.model = model
        elif self.use_tpu:
            self.accelerator = TPUAccelerator(self)
            self.accelerator.setup()
            self.accelerator.train(model)
            self.accelerator.teardown()

        # ON CPU
        else:
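End to end, the user-facing API is unchanged; only the internal routing differs. A usage sketch under the new flow (`MyModel` is a hypothetical `LightningModule` subclass; `tpu_cores` semantics as of this release):

import pytorch_lightning as pl

model = MyModel()  # hypothetical LightningModule subclass

# tpu_cores=8 leaves trainer.tpu_id unset, so TPUAccelerator.train
# launches tpu_train_in_process on every core via xmp.spawn
trainer = pl.Trainer(tpu_cores=8, precision=16)  # precision=16 sets XLA_USE_BF16
trainer.fit(model)

# tpu_cores=[1] pins a single core: trainer.tpu_id is set and the
# accelerator calls tpu_train_in_process directly, skipping xmp.spawn
trainer_single = pl.Trainer(tpu_cores=[1])
trainer_single.fit(model)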