Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to select individual TPU core to train on #1729

Merged
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8995fbc
added tpu_id
Apr 19, 2020
bd9e88c
train on individual tpu
Apr 26, 2020
1daadfa
parallel loader if tpu_id is None
May 3, 2020
e4d49d0
removed progress_bar_refresh_rate
May 4, 2020
0ed38cd
chlog
Borda May 5, 2020
725ef5d
replaced num_tpu_cores with tpu_cores
May 6, 2020
c0a4f9d
set tpu_id to None if int
May 6, 2020
f25d516
changed num_tpu_cores to tpu_cores in docs
May 6, 2020
a93c6bc
Merge branch 'master' into feature/1539_tpu_train_parallel
lezwon May 7, 2020
b22f485
updated docs
May 9, 2020
cdda262
Merge branch 'master' into feature/1539_tpu_train_parallel
lezwon May 9, 2020
0669ad2
updated __init__.py
May 9, 2020
2253b9f
Update pytorch_lightning/trainer/__init__.py
Borda May 10, 2020
67c5688
check if tpu_cores is a list
lezwon May 13, 2020
ec278d1
xla device conditional
May 10, 2020
100071b
num_tpu_cores deprecation
May 13, 2020
8adb0a9
removed duplicate warning
May 13, 2020
34f2209
Merge remote-tracking branch 'official/master' into feature/1539_tpu_…
May 13, 2020
f779d01
fixed pep8 error
May 13, 2020
dafe174
Revert "removed duplicate warning"
May 14, 2020
4c6958e
deprecated api update
May 14, 2020
5c0db30
fixed recursion error
May 14, 2020
c7a9b4e
fixed tests
May 14, 2020
83e5d99
fixed flake errors
May 14, 2020
230831e
Merge remote-tracking branch 'official/master' into feature/1539_tpu_…
May 14, 2020
59e0b49
removed current_tpu_index
May 14, 2020
f22d90d
Merge branch 'master' into feature/1539_tpu_train_parallel
williamFalcon May 17, 2020
940f70b
Update CHANGELOG.md
Borda May 17, 2020
ec300ee
Update trainer.py
Borda May 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).

### Changed

- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))

### Deprecated

### Removed
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,11 @@ trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32)

Or TPUs
```python
trainer = Trainer(num_tpu_cores=8)
# Distributes TPU core training
trainer = Trainer(tpu_cores=8)

# Single TPU core training
trainer = Trainer(tpu_cores=[1])
```

When you're done training, run the test accuracy
Expand Down
4 changes: 2 additions & 2 deletions docs/source/apex.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TPU 16-bit
.. testcode::

# DEFAULT
trainer = Trainer(num_tpu_cores=8, precision=32)
trainer = Trainer(tpu_cores=8, precision=32)

# turn on 16-bit
trainer = Trainer(num_tpu_cores=8, precision=16)
trainer = Trainer(tpu_cores=8, precision=16)
56 changes: 7 additions & 49 deletions docs/source/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ EXACTLY the same as you would a PyTorch Module.

Out:

.. code-block:: none
.. code-block:: python

torch.Size([1, 10])

Expand Down Expand Up @@ -519,50 +519,8 @@ First, change the runtime to TPU (and reinstall lightning).

Next, install the required xla library (adds support for PyTorch on TPUs)

.. code-block:: python

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "torch_xla==nightly" #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
(datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
print('Updating server-side XRT to {} ...'.format(CONFIG.server))
url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
XRT_VERSION=CONFIG.server,
)
print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

.. code-block::

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
of this program. This means that without taking any care you will download the dataset N times which
Expand Down Expand Up @@ -609,7 +567,7 @@ Now we can train the LightningModule on a TPU without doing anything else!
.. code-block:: python

model = LitMNIST()
trainer = Trainer(num_tpu_cores=8)
trainer = Trainer(tpu_cores=8)
trainer.fit(model)

You'll now see the TPU cores booting up.
Expand Down Expand Up @@ -696,7 +654,7 @@ while checking the validation set.
from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(num_tpu_cores=8)
trainer = Trainer(tpu_cores=8)
trainer.fit(model)

You may have noticed the words `Validation sanity check` logged. This is because Lightning runs 5 batches
Expand Down Expand Up @@ -747,7 +705,7 @@ Once you train your model simply call `.test()`.
from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(num_tpu_cores=8)
trainer = Trainer(tpu_cores=8)
trainer.fit(model)

# run test set
Expand All @@ -769,7 +727,7 @@ You can also run the test from a saved lightning model
.. code-block:: python

model = LitMNIST.load_from_checkpoint(PATH)
trainer = Trainer(num_tpu_cores=8)
trainer = Trainer(tpu_cores=8)
trainer.test(model)

.. note:: Lightning disables gradients, puts model in eval mode and does everything needed for testing.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/multi_gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ Lightning allows multiple ways of training
- DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines).
- DistributedDataParallel2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines).
- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime)
- TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod)
- TPUs (`tpu_cores=8|x`) (tpu or TPU pod)

.. note:: If you request multiple GPUs without setting a mode, ddp will be automatically used.

Expand Down
4 changes: 2 additions & 2 deletions docs/source/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ However, this time you need to specifically call test (this is done so you don't
# OPTION 2:
# test after loading weights
model = LitModel.load_from_checkpoint(PATH)
trainer = Trainer(num_tpu_cores=1)
trainer = Trainer(tpu_cores=1)
trainer.test()

Again, under the hood, lightning does the following in (pseudocode):
Expand Down Expand Up @@ -236,7 +236,7 @@ Without changing a SINGLE line of your code, you can now do the following with t
# train on TPUs using 16 bit precision with early stopping
# using only half the training data and checking validation every quarter of a training epoch
trainer = Trainer(
nb_tpu_cores=8,
tpu_cores=8,
precision=16,
early_stop_checkpoint=True,
train_percent_check=0.5,
Expand Down
63 changes: 17 additions & 46 deletions docs/source/tpu.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
TPU support
===========

Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs
Lightning supports running on TPUs. At this moment, TPUs are available
on Google Cloud (GCP), Google Colab and Kaggle Environments. For more information on TPUs
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

---------------
Expand Down Expand Up @@ -31,6 +31,7 @@ To access TPUs there are two main ways.

1. Using google colab.
2. Using Google Cloud (GCP).
3. Using Kaggle.

---------------

Expand All @@ -51,50 +52,10 @@ To get a TPU on colab, follow these steps:
4. Next, insert this code into the first cell and execute.
This will install the xla library that interfaces between PyTorch and the TPU.

.. code-block:: python

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0" #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
(datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
print('Updating server-side XRT to {} ...'.format(CONFIG.server))
url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
XRT_VERSION=CONFIG.server,
)
print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

.. code-block::

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

5. Once the above is done, install PyTorch Lightning (v 0.7.0+).

Expand Down Expand Up @@ -156,13 +117,23 @@ To use a full TPU pod skip to the TPU pod section.
import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8)
trainer = pl.Trainer(tpu_cores=8)
trainer.fit(my_model)

That's it! Your model will train on all 8 TPU cores.

---------------

Single TPU core training
----------------------------
Lightning supports training on a single TPU core. Just pass the TPU core ID [1-8] in a list.

.. code-block:: python

trainer = pl.Trainer(tpu_cores=[1])

---------------

Distributed Backend with TPU
----------------------------
The ```distributed_backend``` option used for GPUs does not apply to TPUs.
Expand Down Expand Up @@ -195,7 +166,7 @@ set the 16-bit flag.
import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8, precision=16)
trainer = pl.Trainer(tpu_cores=8, precision=16)
trainer.fit(my_model)

Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
Expand Down
31 changes: 23 additions & 8 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,22 @@ def on_train_end(self):

num_tpu_cores
^^^^^^^^^^^^^
Borda marked this conversation as resolved.
Show resolved Hide resolved
How many TPU cores to train on (1 or 8).
.. warning:: .. deprecated:: 0.7.6

Use `tpu_cores` instead. Will remove 0.9.0.

Example::

python -m torch_xla.distributed.xla_dist
--tpu=$TPU_POD_NAME
--conda-env=torch-xla-nightly
--env=XLA_USE_BF16=1
-- python your_trainer_file.py

tpu_cores
^^^^^^^^^
- How many TPU cores to train on (1 or 8).
- Which TPU core to train on [1-8]

A single TPU v2 or v3 has 8 cores. A TPU pod has
up to 2048 cores. A slice of a POD means you get as many cores
Expand All @@ -615,21 +630,21 @@ def on_train_end(self):
# your_trainer_file.py

# default used by the Trainer (ie: train on CPU)
trainer = Trainer(num_tpu_cores=None)
trainer = Trainer(tpu_cores=None)

# int: train on a single core
trainer = Trainer(num_tpu_cores=1)
trainer = Trainer(tpu_cores=1)

# list: train on a single selected core
trainer = Trainer(tpu_cores=[2])

# int: train on all cores few cores
trainer = Trainer(num_tpu_cores=8)
trainer = Trainer(tpu_cores=8)

# for 8+ cores must submit via xla script with
# a max of 8 cores specified. The XLA script
# will duplicate script onto each TPU in the POD
trainer = Trainer(num_tpu_cores=8)

# -1: train on all available TPUs
trainer = Trainer(num_tpu_cores=-1)
trainer = Trainer(tpu_cores=8)

To train on more than 8 cores (ie: a POD),
submit this script using the xla_dist script.
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,9 @@ def training_tqdm_dict(self):
rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3"
" and this method will be removed in v0.9.0", DeprecationWarning)
return self.progress_bar_dict

@property
def num_tpu_cores(self):
"""Back compatibility, will be removed in v0.9.0"""
rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
" and this argument will be removed in v0.9.0", DeprecationWarning)
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ class TrainerDPMixin(ABC):
root_gpu: ...
amp_level: str
precision: ...
current_tpu_idx: ...
proc_rank: int
tpu_local_core_rank: int
tpu_global_core_rank: int
Expand All @@ -398,6 +397,7 @@ class TrainerDPMixin(ABC):
data_parallel_device_ids: ...
logger: Union[LightningLoggerBase, bool]
progress_bar_callback: ...
tpu_id: int

@property
@abstractmethod
Expand Down Expand Up @@ -442,7 +442,8 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):
if device == 'tpu' and XLA_AVAILABLE:
# base case: object can be directly moved using `to`
if callable(getattr(batch, 'to', None)):
return batch.to(xm.xla_device())
xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
return batch.to(xla_device)

if device == 'gpu':
# base case: object can be directly moved using `cuda` or `to`
Expand Down Expand Up @@ -501,7 +502,8 @@ def single_gpu_train(self, model):

def tpu_train(self, tpu_core_idx, model):
# put model on tpu
model.to(xm.xla_device())
self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
model.to(self._device)

# get the appropriate tpu ranks
self.tpu_local_core_rank = xm.get_local_ordinal()
Expand All @@ -511,8 +513,6 @@ def tpu_train(self, tpu_core_idx, model):
if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()

# track current tpu
self.current_tpu_idx = tpu_core_idx
self.proc_rank = self.tpu_local_core_rank
rank_zero_only.rank = self.proc_rank

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class TrainerEvaluationLoopMixin(ABC):
val_dataloaders: DataLoader
use_tpu: bool
reload_dataloaders_every_epoch: ...
tpu_id: int

# Callback system
on_validation_batch_start: Callable
Expand Down Expand Up @@ -248,7 +249,7 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
dl_outputs = []

# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
if self.use_tpu and self.tpu_id is None:
device = xm.xla_device()
dataloader = xla_pl.ParallelLoader(dataloader, [device])
dataloader = dataloader.per_device_loader(device)
Expand Down
Loading