-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial changes * Add broken example for now * Fix reference * Fix format * Code runs * Fixes * Clear up files * Add tests, helpers, fixes * Small cleanups * Refactors based on review * Swap to special tests * Add special tests * Add source * Cleanups * Add logic to attach/detach model from devices * Fixes for tests * Fixes for tests * Move earlier * Cleanups * Add check for nvcc * Add tests, cleanups * Fix errors * fix * Try condition * Add missing annotation * Clearer * Clearer message * Fix variable * Cleanups * Add comment * CHANGELOG.md * Add simple selection test * Remove special=True to see what happens * Fix test * Update tests/accelerators/test_ipu.py Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> * Convert ipu_cores -> ipus * Add typing, fail earlier * simplify precision * Add test, add helper * fix accum * Update pytorch_lightning/plugins/training_type/ipu.py Co-authored-by: thomas chaton <thomas@grid.ai> * Use stages * Make sure warning message returned * thorw error * Add more tests, use fs * add comment * Clean * Address feedback, add IPU tests * Fixes * Fix signature * Add types * Remove autoround * Add docstring * ipu_cores -> ipus * Add test, remove unnecessary precision set * Add optimizer test * Add precision back with test * Address code review * Change to probs * Move some of the asserts earlier Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: thomas chaton <thomas@grid.ai>
- Loading branch information
1 parent
42c7f27
commit 96433d0
Showing
15 changed files
with
1,150 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
from torch.nn import functional as F | ||
|
||
import pytorch_lightning as pl | ||
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule | ||
|
||
|
||
class LitClassifier(pl.LightningModule): | ||
|
||
def __init__( | ||
self, | ||
hidden_dim: int = 128, | ||
learning_rate: float = 0.0001, | ||
): | ||
super().__init__() | ||
self.save_hyperparameters() | ||
|
||
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) | ||
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) | ||
|
||
def forward(self, x): | ||
x = x.view(x.size(0), -1) | ||
x = torch.relu(self.l1(x)) | ||
x = torch.relu(self.l2(x)) | ||
return x | ||
|
||
def training_step(self, batch, batch_idx): | ||
x, y = batch | ||
y_hat = self(x) | ||
loss = F.cross_entropy(y_hat, y) | ||
return loss | ||
|
||
def validation_step(self, batch, batch_idx): | ||
x, y = batch | ||
probs = self(x) | ||
# we currently return the accuracy as the validation_step/test_step is run on the IPU devices. | ||
# Outputs from the step functions are sent to the host device, where we calculate the metrics in | ||
# validation_epoch_end and test_epoch_end for the test_step. | ||
acc = self.accuracy(probs, y) | ||
return acc | ||
|
||
def test_step(self, batch, batch_idx): | ||
x, y = batch | ||
logits = self(x) | ||
acc = self.accuracy(logits, y) | ||
return acc | ||
|
||
def accuracy(self, logits, y): | ||
# currently IPU poptorch doesn't implicit convert bools to tensor | ||
# hence we use an explicit calculation for accuracy here. Once fixed in poptorch | ||
# we can use the accuracy metric. | ||
acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y) | ||
return acc | ||
|
||
def validation_epoch_end(self, outputs) -> None: | ||
# since the training step/validation step and test step are run on the IPU device | ||
# we must log the average loss outside the step functions. | ||
self.log('val_acc', torch.stack(outputs).mean(), prog_bar=True) | ||
|
||
def test_epoch_end(self, outputs) -> None: | ||
self.log('test_acc', torch.stack(outputs).mean()) | ||
|
||
def configure_optimizers(self): | ||
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) | ||
|
||
|
||
if __name__ == '__main__': | ||
dm = MNISTDataModule(batch_size=32) | ||
|
||
model = LitClassifier() | ||
|
||
trainer = pl.Trainer(max_epochs=2, ipus=8) | ||
|
||
trainer.fit(model, datamodule=dm) | ||
trainer.test(model, datamodule=dm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import Callable | ||
from typing import Any | ||
|
||
from torch.optim import Optimizer | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning.accelerators.accelerator import Accelerator | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
|
||
class IPUAccelerator(Accelerator): | ||
""" Accelerator for IPUs. """ | ||
|
||
def setup_optimizers(self, trainer: 'pl.Trainer') -> None: | ||
super().setup_optimizers(trainer) | ||
|
||
if len(self.optimizers) > 1: | ||
raise MisconfigurationException("IPUs currently only support one optimizer.") | ||
|
||
def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None: | ||
# Optimizer step is handled by the IPU accelerator. | ||
lambda_closure() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Optional, Union | ||
|
||
from torch import Tensor | ||
from torch.nn import Module | ||
from torch.optim import Optimizer | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin | ||
from pytorch_lightning.utilities import GradClipAlgorithmType | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
|
||
class IPUPrecisionPlugin(PrecisionPlugin): | ||
|
||
def __init__(self, precision: int) -> None: | ||
super().__init__() | ||
self.precision = precision | ||
|
||
def backward( | ||
self, | ||
model: 'pl.LightningModule', | ||
closure_loss: Tensor, | ||
optimizer: Optimizer, | ||
opt_idx: int, | ||
should_accumulate: bool, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tensor: | ||
# IPU internally manages bwd step. | ||
return closure_loss | ||
|
||
def clip_gradients( | ||
self, | ||
optimizer: Optimizer, | ||
clip_val: Union[int, float], | ||
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, | ||
model: Optional[Module] = None | ||
) -> None: | ||
"""Clips the gradients""" | ||
if clip_val is None: | ||
return | ||
|
||
clip_val = float(clip_val) | ||
if clip_val <= 0: | ||
return | ||
|
||
raise MisconfigurationException("IPUs currently do not support clipping gradients.") |
Oops, something went wrong.