
fix: Enable manual optimization for TPUs #8458

Merged · 12 commits · Jul 22, 2021
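For context, what this PR unlocks from the user's side is a LightningModule that opts out of automatic optimization and drives backward/step/zero_grad itself, running on a TPU Trainer. A minimal sketch based on the hooks exercised by the test added in this PR — the layer, loss, and optimizer here are illustrative placeholders, not part of the change:

import torch
from pytorch_lightning import LightningModule, Trainer


class ManualModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of Lightning's automatic optimization loop
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        # the user calls backward / step / zero_grad explicitly
        self.manual_backward(loss)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# with this fix, the same module also runs under the TPU backend
trainer = Trainer(tpu_cores=8, max_epochs=1)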
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -510,6 +510,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DeepSpeed Windows support ([#8488](https://github.com/PyTorchLightning/pytorch-lightning/pull/8488))


- Enabled manual optimization for TPUs ([#8458](https://github.com/PyTorchLightning/pytorch-lightning/pull/8458))


- Fixed `accumulate_grad_batches` not being recomputed during model reload ([#5334](https://github.com/PyTorchLightning/pytorch-lightning/pull/5334))


2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -337,7 +337,7 @@ def model_to_device(self):

 def pre_backward(self, closure_loss: torch.Tensor) -> None:
     """Run before precision plugin executes backward"""
-    if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
+    if not self.lightning_module.automatic_optimization:
         prepare_for_backward(self.model, closure_loss)

def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
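The change above drops the `self.model.require_backward_grad_sync` guard. `require_backward_grad_sync` is an attribute of `torch.nn.parallel.DistributedDataParallel`, and — this is an assumption, not something the diff states — the TPU spawn plugin inherits this `pre_backward` hook from `DDPSpawnPlugin` while running the module without a DDP wrapper, so the old lookup could not work there under manual optimization. A tiny sketch of the attribute situation (plain torch, no Lightning involved):

import torch
from torch import nn

# plain modules do not carry DDP's gradient-sync flag
plain = nn.Linear(2, 2)
print(hasattr(plain, "require_backward_grad_sync"))  # False

# the flag only exists once the module is DDP-wrapped, e.g.
# ddp = nn.parallel.DistributedDataParallel(plain)  # needs an initialized
# print(ddp.require_backward_grad_sync)             # process group to run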
78 changes: 78 additions & 0 deletions tests/accelerators/test_tpu_backend.py
@@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import collections
from copy import deepcopy

import pytest
import torch
from torch import nn

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -186,3 +190,77 @@ def test_set_devices_if_none_tpu():

    trainer = Trainer(accelerator="tpu", tpu_cores=8)
    assert trainer.devices == 8


@RunIf(tpu=True)
def test_manual_optimization_tpus(tmpdir):

    class ManualOptimizationModel(BoringModel):

        count = 0
        called = collections.defaultdict(int)

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False

        @property
        def should_update(self):
            # step the optimizer only on every other batch
            return self.count % 2 == 0

        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            self.called["on_train_batch_start"] += 1
            # snapshot the weights before the batch runs
            self.weight_before = self.layer.weight.clone()

        def training_step(self, batch, batch_idx):
            self.called["training_step"] += 1
            opt = self.optimizers()
            output = self.layer(batch)
            loss = self.loss(batch, output)

            if self.should_update:
                # drive backward / step / zero_grad manually
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()
            return loss

        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            self.called["on_train_batch_end"] += 1
            weight_after = self.layer.weight.clone()
            if self.should_update:
                # the weights must change on batches where we stepped ...
                assert not torch.equal(self.weight_before, weight_after), self.count
            else:
                # ... and stay identical, with zeroed grads, on the others
                assert torch.equal(self.weight_before, weight_after)
                assert torch.all(self.layer.weight.grad == 0)
            self.count += 1

        def on_train_end(self):
            assert self.called["training_step"] == 5
            assert self.called["on_train_batch_start"] == 5
            assert self.called["on_train_batch_end"] == 5

    class TestManualOptimizationCallback(Callback):

        def on_train_end(self, trainer, pl_module):
            # 5 batches with updates on even counts -> 3 optimizer steps
            opt = pl_module.optimizers()
            assert opt._total_optimizer_step_calls == 3

    model = ManualOptimizationModel()
    model_copy = deepcopy(model)
    model.training_step_end = None
    model.training_epoch_end = None

    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=5,
        limit_test_batches=0,
        limit_val_batches=0,
        tpu_cores=8,
        callbacks=[TestManualOptimizationCallback()],
    )
    trainer.fit(model)

    for param, param_copy in zip(model.parameters(), model_copy.parameters()):
        # training must have moved every parameter away from its initial copy
        assert not torch.equal(param.cpu().data, param_copy.data)
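As a quick sanity check on the numbers the test asserts: with `limit_train_batches=5` and `should_update` true on even `count` values, steps land on batches 0, 2, and 4, which is where the callback's expected total of 3 comes from:

# counts 0..4 correspond to the 5 limited train batches;
# a step happens whenever count % 2 == 0
steps = sum(1 for count in range(5) if count % 2 == 0)
assert steps == 3  # matches opt._total_optimizer_step_calls in the callback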