diff --git a/.drone.yml b/.drone.yml index 5e6c08f7a8256..9774ffaaaecc7 100644 --- a/.drone.yml +++ b/.drone.yml @@ -32,6 +32,8 @@ steps: - pip --version - nvidia-smi - pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir + # when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0" + - pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed - pip list - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --color=yes --durations=25 # --flake8 - python -m pytest benchmarks pl_examples -v --color=yes --maxfail=2 --durations=0 # --flake8 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 6e6a6af863d27..78c89cdae7e05 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,6 +4,8 @@ Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. + +If we didn't discuss your PR in Github issues there's a high chance it will not be merged. --> Fixes # (issue) @@ -20,10 +22,14 @@ Fixes # (issue) ## PR review - - [ ] Is this pull request ready for review? (if not, please submit in draft mode) +Anyone in the community is free to review the PR once the tests have passed. +Before you start reviewing make sure you have read [Review guidelines](https://github.com/PyTorchLightning/pytorch-lightning/wiki/Review-guidelines). In in short, see following bullet-list: -Anyone in the community is free to review the PR once the tests have passed. -If we didn't discuss your PR in Github issues there's a high chance it will not be merged. + - [ ] Is this pull request ready for review? (if not, please submit in draft mode) + - [ ] Check that all items from **Before submitting** are resolved + - [ ] Make sure the title is self explanatory and the description concisely explains the PR + - [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified; _Bugfixes should be including in bug-fix release milestones (m.f.X) and features should be included in (m.X.b) releases._ + ## Did you have fun? Make sure you had fun coding 🙃 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index b14f40a6c4339..1395b7ede4b1d 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -29,7 +29,7 @@ jobs: # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1 + uses: pypa/gh-action-pypi-publish@v1.4.1 with: user: __token__ password: ${{ secrets.test_pypi_password }} diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index fb99ae2284f76..354f799df20b1 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -30,7 +30,7 @@ jobs: # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1 + uses: pypa/gh-action-pypi-publish@v1.4.1 with: user: __token__ password: ${{ secrets.test_pypi_password }} @@ -39,7 +39,7 @@ jobs: - name: Publish distribution 📦 to PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@v1 + uses: pypa/gh-action-pypi-publish@v1.4.1 with: user: __token__ password: ${{ secrets.pypi_password }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 42d78b40b0332..1f4defbd5cc30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458)) +- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) + + ### Changed diff --git a/benchmarks/test_parity.py b/benchmarks/test_parity.py index d2b30afb23946..d2bc97deff598 100644 --- a/benchmarks/test_parity.py +++ b/benchmarks/test_parity.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize('cls_model,max_diff', [ (ParityModuleRNN, 0.05), - (ParityModuleMNIST, 0.70) + (ParityModuleMNIST, 0.8) ]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_pytorch_parity(tmpdir, cls_model, max_diff): diff --git a/notebooks/01-mnist-hello-world.ipynb b/notebooks/01-mnist-hello-world.ipynb index 79bc9ebec9632..b0323458c228b 100644 --- a/notebooks/01-mnist-hello-world.ipynb +++ b/notebooks/01-mnist-hello-world.ipynb @@ -1,400 +1,448 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "01-mnist-hello-world.ipynb", - "provenance": [], - "collapsed_sections": [], - "authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i7XbLCXGkll9", - "colab_type": "text" - }, - "source": [ - "# Introduction to Pytorch Lightning ⚡\n", - "\n", - "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2LODD6w9ixlT", - "colab_type": "text" - }, - "source": [ - "### Setup \n", - "Lightning is easy to install. Simply ```pip install pytorch-lightning```" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zK7-Gg69kMnG", - "colab_type": "code", - "colab": {} - }, - "source": [ - "! pip install pytorch-lightning --quiet" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "w4_TYnt_keJi", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import os\n", - "\n", - "import torch\n", - "from torch import nn\n", - "from torch.nn import functional as F\n", - "from torch.utils.data import DataLoader, random_split\n", - "from torchvision.datasets import MNIST\n", - "from torchvision import transforms\n", - "import pytorch_lightning as pl\n", - "from pytorch_lightning.metrics.functional import accuracy" - ], - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EHpyMPKFkVbZ", - "colab_type": "text" - }, - "source": [ - "## Simplest example\n", - "\n", - "Here's the simplest most minimal example with just a training loop (no validation, no testing).\n", - "\n", - "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "V7ELesz1kVQo", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MNISTModel(pl.LightningModule):\n", - "\n", - " def __init__(self):\n", - " super(MNISTModel, self).__init__()\n", - " self.l1 = torch.nn.Linear(28 * 28, 10)\n", - "\n", - " def forward(self, x):\n", - " return torch.relu(self.l1(x.view(x.size(0), -1)))\n", - "\n", - " def training_step(self, batch, batch_nb):\n", - " x, y = batch\n", - " loss = F.cross_entropy(self(x), y)\n", - " return loss\n", - "\n", - " def configure_optimizers(self):\n", - " return torch.optim.Adam(self.parameters(), lr=0.02)" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hIrtHg-Dv8TJ", - "colab_type": "text" - }, - "source": [ - "By using the `Trainer` you automatically get:\n", - "1. Tensorboard logging\n", - "2. Model checkpointing\n", - "3. Training and validation loop\n", - "4. early-stopping" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4Dk6Ykv8lI7X", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Init our model\n", - "mnist_model = MNISTModel()\n", - "\n", - "# Init DataLoader from MNIST Dataset\n", - "train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", - "train_loader = DataLoader(train_ds, batch_size=32)\n", - "\n", - "# Initialize a trainer\n", - "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", - "\n", - "# Train the model ⚡\n", - "trainer.fit(mnist_model, train_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KNpOoBeIjscS", - "colab_type": "text" - }, - "source": [ - "## A more complete MNIST Lightning Module Example\n", - "\n", - "That wasn't so hard was it?\n", - "\n", - "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", - "\n", - "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", - "\n", - "---\n", - "\n", - "### Note what the following built-in functions are doing:\n", - "\n", - "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n", - " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", - " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", - "\n", - "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n", - " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n", - " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", - " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", - " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", - "\n", - "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n", - " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4DNItffri95Q", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class LitMNIST(pl.LightningModule):\n", - " \n", - " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", - "\n", - " super().__init__()\n", - "\n", - " # Set our init args as class attributes\n", - " self.data_dir = data_dir\n", - " self.hidden_size = hidden_size\n", - " self.learning_rate = learning_rate\n", - "\n", - " # Hardcode some dataset specific attributes\n", - " self.num_classes = 10\n", - " self.dims = (1, 28, 28)\n", - " channels, width, height = self.dims\n", - " self.transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))\n", - " ])\n", - "\n", - " # Define PyTorch model\n", - " self.model = nn.Sequential(\n", - " nn.Flatten(),\n", - " nn.Linear(channels * width * height, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Dropout(0.1),\n", - " nn.Linear(hidden_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Dropout(0.1),\n", - " nn.Linear(hidden_size, self.num_classes)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x = self.model(x)\n", - " return F.log_softmax(x, dim=1)\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " logits = self(x)\n", - " loss = F.nll_loss(logits, y)\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " logits = self(x)\n", - " loss = F.nll_loss(logits, y)\n", - " preds = torch.argmax(logits, dim=1)\n", - " acc = accuracy(preds, y)\n", - "\n", - " # Calling self.log will surface up scalars for you in TensorBoard\n", - " self.log('val_loss', loss, prog_bar=True)\n", - " self.log('val_acc', acc, prog_bar=True)\n", - " return loss\n", - "\n", - " def test_step(self, batch, batch_idx):\n", - " # Here we just reuse the validation_step for testing\n", - " return self.validation_step(batch, batch_idx)\n", - "\n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", - " return optimizer\n", - "\n", - " ####################\n", - " # DATA RELATED HOOKS\n", - " ####################\n", - "\n", - " def prepare_data(self):\n", - " # download\n", - " MNIST(self.data_dir, train=True, download=True)\n", - " MNIST(self.data_dir, train=False, download=True)\n", - "\n", - " def setup(self, stage=None):\n", - "\n", - " # Assign train/val datasets for use in dataloaders\n", - " if stage == 'fit' or stage is None:\n", - " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", - " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", - "\n", - " # Assign test dataset for use in dataloader(s)\n", - " if stage == 'test' or stage is None:\n", - " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.mnist_train, batch_size=32)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.mnist_val, batch_size=32)\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(self.mnist_test, batch_size=32)" - ], - "execution_count": 5, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Mb0U5Rk2kLBy", - "colab_type": "code", - "colab": {} - }, - "source": [ - "model = LitMNIST()\n", - "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", - "trainer.fit(model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nht8AvMptY6I", - "colab_type": "text" - }, - "source": [ - "### Testing\n", - "\n", - "To test a model, call `trainer.test(model)`.\n", - "\n", - "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PA151FkLtprO", - "colab_type": "code", - "colab": {} - }, - "source": [ - "trainer.test()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T3-3lbbNtr5T", - "colab_type": "text" - }, - "source": [ - "### Bonus Tip\n", - "\n", - "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IFBwCbLet2r6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "trainer.fit(model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8TRyS5CCt3n9", - "colab_type": "text" - }, - "source": [ - "In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "wizS-QiLuAYo", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Start tensorboard.\n", - "%load_ext tensorboard\n", - "%tensorboard --logdir lightning_logs/" - ], - "execution_count": null, - "outputs": [] - } - ] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "i7XbLCXGkll9" + }, + "source": [ + "# Introduction to Pytorch Lightning ⚡\n", + "\n", + "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2LODD6w9ixlT" + }, + "source": [ + "### Setup \n", + "Lightning is easy to install. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "zK7-Gg69kMnG" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "w4_TYnt_keJi" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import MNIST\n", + "from torchvision import transforms\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EHpyMPKFkVbZ" + }, + "source": [ + "## Simplest example\n", + "\n", + "Here's the simplest most minimal example with just a training loop (no validation, no testing).\n", + "\n", + "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "V7ELesz1kVQo" + }, + "outputs": [], + "source": [ + "class MNISTModel(pl.LightningModule):\n", + "\n", + " def __init__(self):\n", + " super(MNISTModel, self).__init__()\n", + " self.l1 = torch.nn.Linear(28 * 28, 10)\n", + "\n", + " def forward(self, x):\n", + " return torch.relu(self.l1(x.view(x.size(0), -1)))\n", + "\n", + " def training_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=0.02)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hIrtHg-Dv8TJ" + }, + "source": [ + "By using the `Trainer` you automatically get:\n", + "1. Tensorboard logging\n", + "2. Model checkpointing\n", + "3. Training and validation loop\n", + "4. early-stopping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4Dk6Ykv8lI7X" + }, + "outputs": [], + "source": [ + "# Init our model\n", + "mnist_model = MNISTModel()\n", + "\n", + "# Init DataLoader from MNIST Dataset\n", + "train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", + "train_loader = DataLoader(train_ds, batch_size=32)\n", + "\n", + "# Initialize a trainer\n", + "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", + "\n", + "# Train the model ⚡\n", + "trainer.fit(mnist_model, train_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KNpOoBeIjscS" + }, + "source": [ + "## A more complete MNIST Lightning Module Example\n", + "\n", + "That wasn't so hard was it?\n", + "\n", + "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", + "\n", + "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", + "\n", + "---\n", + "\n", + "### Note what the following built-in functions are doing:\n", + "\n", + "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n", + " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", + " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", + "\n", + "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n", + " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n", + " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", + " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", + " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", + "\n", + "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n", + " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4DNItffri95Q" + }, + "outputs": [], + "source": [ + "class LitMNIST(pl.LightningModule):\n", + " \n", + " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " # Set our init args as class attributes\n", + " self.data_dir = data_dir\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " # Hardcode some dataset specific attributes\n", + " self.num_classes = 10\n", + " self.dims = (1, 28, 28)\n", + " channels, width, height = self.dims\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # Define PyTorch model\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, self.num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + "\n", + " # Calling self.log will surface up scalars for you in TensorBoard\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " # Here we just reuse the validation_step for testing\n", + " return self.validation_step(batch, batch_idx)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer\n", + "\n", + " ####################\n", + " # DATA RELATED HOOKS\n", + " ####################\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Mb0U5Rk2kLBy" + }, + "outputs": [], + "source": [ + "model = LitMNIST()\n", + "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nht8AvMptY6I" + }, + "source": [ + "### Testing\n", + "\n", + "To test a model, call `trainer.test(model)`.\n", + "\n", + "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PA151FkLtprO" + }, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "T3-3lbbNtr5T" + }, + "source": [ + "### Bonus Tip\n", + "\n", + "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IFBwCbLet2r6" + }, + "outputs": [], + "source": [ + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8TRyS5CCt3n9" + }, + "source": [ + "In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wizS-QiLuAYo" + }, + "outputs": [], + "source": [ + "# Start tensorboard.\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir lightning_logs/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ", + "collapsed_sections": [], + "include_colab_link": true, + "name": "01-mnist-hello-world.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/notebooks/02-datamodules.ipynb b/notebooks/02-datamodules.ipynb index 3e027cd304c77..599cb1d6bd289 100644 --- a/notebooks/02-datamodules.ipynb +++ b/notebooks/02-datamodules.ipynb @@ -1,540 +1,588 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "02-datamodules.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2O5r7QvP8-rt", - "colab_type": "text" - }, - "source": [ - "# PyTorch Lightning DataModules ⚡\n", - "\n", - "With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.\n", - "\n", - "This notebook will walk you through how to start using Datamodules.\n", - "\n", - "The most up to date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n", - "\n", - "---\n", - "\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6RYMhmfA9ATN", - "colab_type": "text" - }, - "source": [ - "### Setup\n", - "Lightning is easy to install. Simply ```pip install pytorch-lightning```" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lj2zD-wsbvGr", - "colab_type": "code", - "colab": {} - }, - "source": [ - "! pip install pytorch-lightning --quiet" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8g2mbvy-9xDI", - "colab_type": "text" - }, - "source": [ - "# Introduction\n", - "\n", - "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eg-xDlmDdAwy", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import pytorch_lightning as pl\n", - "from pytorch_lightning.metrics.functional import accuracy\n", - "import torch\n", - "from torch import nn\n", - "import torch.nn.functional as F\n", - "from torch.utils.data import random_split, DataLoader\n", - "\n", - "# Note - you must have torchvision installed for this example\n", - "from torchvision.datasets import MNIST, CIFAR10\n", - "from torchvision import transforms" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DzgY7wi88UuG", - "colab_type": "text" - }, - "source": [ - "## Defining the LitMNISTModel\n", - "\n", - "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.\n", - "\n", - "Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢\n", - "\n", - "This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IQkW8_FF5nU2", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class LitMNIST(pl.LightningModule):\n", - " \n", - " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", - "\n", - " super().__init__()\n", - "\n", - " # We hardcode dataset specific stuff here.\n", - " self.data_dir = data_dir\n", - " self.num_classes = 10\n", - " self.dims = (1, 28, 28)\n", - " channels, width, height = self.dims\n", - " self.transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))\n", - " ])\n", - "\n", - " self.hidden_size = hidden_size\n", - " self.learning_rate = learning_rate\n", - "\n", - " # Build model\n", - " self.model = nn.Sequential(\n", - " nn.Flatten(),\n", - " nn.Linear(channels * width * height, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Dropout(0.1),\n", - " nn.Linear(hidden_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Dropout(0.1),\n", - " nn.Linear(hidden_size, self.num_classes)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x = self.model(x)\n", - " return F.log_softmax(x, dim=1)\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " logits = self(x)\n", - " loss = F.nll_loss(logits, y)\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " logits = self(x)\n", - " loss = F.nll_loss(logits, y)\n", - " preds = torch.argmax(logits, dim=1)\n", - " acc = accuracy(preds, y)\n", - " self.log('val_loss', loss, prog_bar=True)\n", - " self.log('val_acc', acc, prog_bar=True)\n", - " return loss\n", - "\n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", - " return optimizer\n", - "\n", - " ####################\n", - " # DATA RELATED HOOKS\n", - " ####################\n", - "\n", - " def prepare_data(self):\n", - " # download\n", - " MNIST(self.data_dir, train=True, download=True)\n", - " MNIST(self.data_dir, train=False, download=True)\n", - "\n", - " def setup(self, stage=None):\n", - "\n", - " # Assign train/val datasets for use in dataloaders\n", - " if stage == 'fit' or stage is None:\n", - " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", - " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", - "\n", - " # Assign test dataset for use in dataloader(s)\n", - " if stage == 'test' or stage is None:\n", - " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.mnist_train, batch_size=32)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.mnist_val, batch_size=32)\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(self.mnist_test, batch_size=32)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K7sg9KQd-QIO", - "colab_type": "text" - }, - "source": [ - "## Training the ListMNIST Model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QxDNDaus6byD", - "colab_type": "code", - "colab": {} - }, - "source": [ - "model = LitMNIST()\n", - "trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n", - "trainer.fit(model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dY8d6GxmB0YU", - "colab_type": "text" - }, - "source": [ - "# Using DataModules\n", - "\n", - "DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eJeT5bW081wn", - "colab_type": "text" - }, - "source": [ - "## Defining The MNISTDataModule\n", - "\n", - "Let's go over each function in the class below and talk about what they're doing:\n", - "\n", - "1. ```__init__```\n", - " - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n", - " - Defines a transform that will be applied across train, val, and test dataset splits.\n", - " - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n", - "\n", - "\n", - "2. ```prepare_data```\n", - " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", - " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", - "\n", - "3. ```setup```\n", - " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n", - " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", - " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n", - " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", - "\n", - "\n", - "4. ```x_dataloader```\n", - " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DfGKyGwG_X9v", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MNISTDataModule(pl.LightningDataModule):\n", - "\n", - " def __init__(self, data_dir: str = './'):\n", - " super().__init__()\n", - " self.data_dir = data_dir\n", - " self.transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))\n", - " ])\n", - "\n", - " # self.dims is returned when you call dm.size()\n", - " # Setting default dims here because we know them.\n", - " # Could optionally be assigned dynamically in dm.setup()\n", - " self.dims = (1, 28, 28)\n", - " self.num_classes = 10\n", - "\n", - " def prepare_data(self):\n", - " # download\n", - " MNIST(self.data_dir, train=True, download=True)\n", - " MNIST(self.data_dir, train=False, download=True)\n", - "\n", - " def setup(self, stage=None):\n", - "\n", - " # Assign train/val datasets for use in dataloaders\n", - " if stage == 'fit' or stage is None:\n", - " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", - " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", - "\n", - " # Assign test dataset for use in dataloader(s)\n", - " if stage == 'test' or stage is None:\n", - " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.mnist_train, batch_size=32)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.mnist_val, batch_size=32)\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(self.mnist_test, batch_size=32)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "H2Yoj-9M9dS7", - "colab_type": "text" - }, - "source": [ - "## Defining the dataset agnostic `LitModel`\n", - "\n", - "Below, we define the same model as the `LitMNIST` model we made earlier. \n", - "\n", - "However, this time our model has the freedom to use any input data that we'd like 🔥." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PM2IISuOBDIu", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class LitModel(pl.LightningModule):\n", - " \n", - " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", - "\n", - " super().__init__()\n", - "\n", - " # We take in input dimensions as parameters and use those to dynamically build model.\n", - " self.channels = channels\n", - " self.width = width\n", - " self.height = height\n", - " self.num_classes = num_classes\n", - " self.hidden_size = hidden_size\n", - " self.learning_rate = learning_rate\n", - "\n", - " self.model = nn.Sequential(\n", - " nn.Flatten(),\n", - " nn.Linear(channels * width * height, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Dropout(0.1),\n", - " nn.Linear(hidden_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Dropout(0.1),\n", - " nn.Linear(hidden_size, num_classes)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x = self.model(x)\n", - " return F.log_softmax(x, dim=1)\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " logits = self(x)\n", - " loss = F.nll_loss(logits, y)\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - "\n", - " x, y = batch\n", - " logits = self(x)\n", - " loss = F.nll_loss(logits, y)\n", - " preds = torch.argmax(logits, dim=1)\n", - " acc = accuracy(preds, y)\n", - " self.log('val_loss', loss, prog_bar=True)\n", - " self.log('val_acc', acc, prog_bar=True)\n", - " return loss\n", - "\n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", - " return optimizer" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "G4Z5olPe-xEo", - "colab_type": "text" - }, - "source": [ - "## Training the `LitModel` using the `MNISTDataModule`\n", - "\n", - "Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "kV48vP_9mEli", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Init DataModule\n", - "dm = MNISTDataModule()\n", - "# Init model from datamodule's attributes\n", - "model = LitModel(*dm.size(), dm.num_classes)\n", - "# Init trainer\n", - "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n", - "# Pass the datamodule as arg to trainer.fit to override model hooks :)\n", - "trainer.fit(model, dm)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WNxrugIGRRv5", - "colab_type": "text" - }, - "source": [ - "## Defining the CIFAR10 DataModule\n", - "\n", - "Lets prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "1tkaYLU7RT5P", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class CIFAR10DataModule(pl.LightningDataModule):\n", - "\n", - " def __init__(self, data_dir: str = './'):\n", - " super().__init__()\n", - " self.data_dir = data_dir\n", - " self.transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", - " ])\n", - "\n", - " self.dims = (3, 32, 32)\n", - " self.num_classes = 10\n", - "\n", - " def prepare_data(self):\n", - " # download\n", - " CIFAR10(self.data_dir, train=True, download=True)\n", - " CIFAR10(self.data_dir, train=False, download=True)\n", - "\n", - " def setup(self, stage=None):\n", - "\n", - " # Assign train/val datasets for use in dataloaders\n", - " if stage == 'fit' or stage is None:\n", - " cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", - " self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n", - "\n", - " # Assign test dataset for use in dataloader(s)\n", - " if stage == 'test' or stage is None:\n", - " self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.cifar_train, batch_size=32)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.cifar_val, batch_size=32)\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(self.cifar_test, batch_size=32)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BrXxf3oX_gsZ", - "colab_type": "text" - }, - "source": [ - "## Training the `LitModel` using the `CIFAR10DataModule`\n", - "\n", - "Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n", - "\n", - "The point here is that we can see that our `LitModel` has no problem using a different datamodule as its input data." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "sd-SbWi_krdj", - "colab_type": "code", - "colab": {} - }, - "source": [ - "dm = CIFAR10DataModule()\n", - "model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n", - "trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n", - "trainer.fit(model, dm)" - ], - "execution_count": null, - "outputs": [] - } - ] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2O5r7QvP8-rt" + }, + "source": [ + "# PyTorch Lightning DataModules ⚡\n", + "\n", + "With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`.\n", + "\n", + "This notebook will walk you through how to start using Datamodules.\n", + "\n", + "The most up to date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html).\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6RYMhmfA9ATN" + }, + "source": [ + "### Setup\n", + "Lightning is easy to install. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lj2zD-wsbvGr" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8g2mbvy-9xDI" + }, + "source": [ + "# Introduction\n", + "\n", + "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "eg-xDlmDdAwy" + }, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import random_split, DataLoader\n", + "\n", + "# Note - you must have torchvision installed for this example\n", + "from torchvision.datasets import MNIST, CIFAR10\n", + "from torchvision import transforms" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DzgY7wi88UuG" + }, + "source": [ + "## Defining the LitMNISTModel\n", + "\n", + "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.\n", + "\n", + "Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢\n", + "\n", + "This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IQkW8_FF5nU2" + }, + "outputs": [], + "source": [ + "class LitMNIST(pl.LightningModule):\n", + " \n", + " def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " # We hardcode dataset specific stuff here.\n", + " self.data_dir = data_dir\n", + " self.num_classes = 10\n", + " self.dims = (1, 28, 28)\n", + " channels, width, height = self.dims\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " # Build model\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, self.num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer\n", + "\n", + " ####################\n", + " # DATA RELATED HOOKS\n", + " ####################\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K7sg9KQd-QIO" + }, + "source": [ + "## Training the ListMNIST Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "QxDNDaus6byD" + }, + "outputs": [], + "source": [ + "model = LitMNIST()\n", + "trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "dY8d6GxmB0YU" + }, + "source": [ + "# Using DataModules\n", + "\n", + "DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eJeT5bW081wn" + }, + "source": [ + "## Defining The MNISTDataModule\n", + "\n", + "Let's go over each function in the class below and talk about what they're doing:\n", + "\n", + "1. ```__init__```\n", + " - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n", + " - Defines a transform that will be applied across train, val, and test dataset splits.\n", + " - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.\n", + "\n", + "\n", + "2. ```prepare_data```\n", + " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", + " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", + "\n", + "3. ```setup```\n", + " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n", + " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", + " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n", + " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", + "\n", + "\n", + "4. ```x_dataloader```\n", + " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DfGKyGwG_X9v" + }, + "outputs": [], + "source": [ + "class MNISTDataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './'):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # self.dims is returned when you call dm.size()\n", + " # Setting default dims here because we know them.\n", + " # Could optionally be assigned dynamically in dm.setup()\n", + " self.dims = (1, 28, 28)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "H2Yoj-9M9dS7" + }, + "source": [ + "## Defining the dataset agnostic `LitModel`\n", + "\n", + "Below, we define the same model as the `LitMNIST` model we made earlier. \n", + "\n", + "However, this time our model has the freedom to use any input data that we'd like 🔥." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PM2IISuOBDIu" + }, + "outputs": [], + "source": [ + "class LitModel(pl.LightningModule):\n", + " \n", + " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", + "\n", + " super().__init__()\n", + "\n", + " # We take in input dimensions as parameters and use those to dynamically build model.\n", + " self.channels = channels\n", + " self.width = width\n", + " self.height = height\n", + " self.num_classes = num_classes\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + "\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', acc, prog_bar=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "G4Z5olPe-xEo" + }, + "source": [ + "## Training the `LitModel` using the `MNISTDataModule`\n", + "\n", + "Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kV48vP_9mEli" + }, + "outputs": [], + "source": [ + "# Init DataModule\n", + "dm = MNISTDataModule()\n", + "# Init model from datamodule's attributes\n", + "model = LitModel(*dm.size(), dm.num_classes)\n", + "# Init trainer\n", + "trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)\n", + "# Pass the datamodule as arg to trainer.fit to override model hooks :)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WNxrugIGRRv5" + }, + "source": [ + "## Defining the CIFAR10 DataModule\n", + "\n", + "Lets prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1tkaYLU7RT5P" + }, + "outputs": [], + "source": [ + "class CIFAR10DataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './'):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + " ])\n", + "\n", + " self.dims = (3, 32, 32)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " CIFAR10(self.data_dir, train=True, download=True)\n", + " CIFAR10(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", + " self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.cifar_train, batch_size=32)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.cifar_val, batch_size=32)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.cifar_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BrXxf3oX_gsZ" + }, + "source": [ + "## Training the `LitModel` using the `CIFAR10DataModule`\n", + "\n", + "Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n", + "\n", + "The point here is that we can see that our `LitModel` has no problem using a different datamodule as its input data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "sd-SbWi_krdj" + }, + "outputs": [], + "source": [ + "dm = CIFAR10DataModule()\n", + "model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)\n", + "trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "include_colab_link": true, + "name": "02-datamodules.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/notebooks/03-basic-gan.ipynb b/notebooks/03-basic-gan.ipynb index a19153e133a5f..31555265938d8 100644 --- a/notebooks/03-basic-gan.ipynb +++ b/notebooks/03-basic-gan.ipynb @@ -1,424 +1,472 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "03-basic-gan.ipynb", - "provenance": [], - "collapsed_sections": [], - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "J37PBnE_x7IW", - "colab_type": "text" - }, - "source": [ - "# PyTorch Lightning Basic GAN Tutorial ⚡\n", - "\n", - "How to train a GAN!\n", - "\n", - "Main takeaways:\n", - "1. Generator and discriminator are arbitrary PyTorch modules.\n", - "2. training_step does both the generator and discriminator training.\n", - "\n", - "---\n", - "\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kg2MKpRmybht", - "colab_type": "text" - }, - "source": [ - "### Setup\n", - "Lightning is easy to install. Simply `pip install pytorch-lightning`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "LfrJLKPFyhsK", - "colab_type": "code", - "colab": {} - }, - "source": [ - "! pip install pytorch-lightning --quiet" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "BjEPuiVLyanw", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import os\n", - "from argparse import ArgumentParser\n", - "from collections import OrderedDict\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torchvision\n", - "import torchvision.transforms as transforms\n", - "from torch.utils.data import DataLoader, random_split\n", - "from torchvision.datasets import MNIST\n", - "\n", - "import pytorch_lightning as pl" - ], - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OuXJzr4G2uHV", - "colab_type": "text" - }, - "source": [ - "### MNIST DataModule\n", - "\n", - "Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DOY_nHu328g7", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class MNISTDataModule(pl.LightningDataModule):\n", - "\n", - " def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n", - " super().__init__()\n", - " self.data_dir = data_dir\n", - " self.batch_size = batch_size\n", - " self.num_workers = num_workers\n", - "\n", - " self.transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))\n", - " ])\n", - "\n", - " # self.dims is returned when you call dm.size()\n", - " # Setting default dims here because we know them.\n", - " # Could optionally be assigned dynamically in dm.setup()\n", - " self.dims = (1, 28, 28)\n", - " self.num_classes = 10\n", - "\n", - " def prepare_data(self):\n", - " # download\n", - " MNIST(self.data_dir, train=True, download=True)\n", - " MNIST(self.data_dir, train=False, download=True)\n", - "\n", - " def setup(self, stage=None):\n", - "\n", - " # Assign train/val datasets for use in dataloaders\n", - " if stage == 'fit' or stage is None:\n", - " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", - " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", - "\n", - " # Assign test dataset for use in dataloader(s)\n", - " if stage == 'test' or stage is None:\n", - " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n", - "\n", - " def test_dataloader(self):\n", - " return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tW3c0QrQyF9P", - "colab_type": "text" - }, - "source": [ - "### A. Generator" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "0E2QDjl5yWtz", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class Generator(nn.Module):\n", - " def __init__(self, latent_dim, img_shape):\n", - " super().__init__()\n", - " self.img_shape = img_shape\n", - "\n", - " def block(in_feat, out_feat, normalize=True):\n", - " layers = [nn.Linear(in_feat, out_feat)]\n", - " if normalize:\n", - " layers.append(nn.BatchNorm1d(out_feat, 0.8))\n", - " layers.append(nn.LeakyReLU(0.2, inplace=True))\n", - " return layers\n", - "\n", - " self.model = nn.Sequential(\n", - " *block(latent_dim, 128, normalize=False),\n", - " *block(128, 256),\n", - " *block(256, 512),\n", - " *block(512, 1024),\n", - " nn.Linear(1024, int(np.prod(img_shape))),\n", - " nn.Tanh()\n", - " )\n", - "\n", - " def forward(self, z):\n", - " img = self.model(z)\n", - " img = img.view(img.size(0), *self.img_shape)\n", - " return img" - ], - "execution_count": 4, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uyrltsGvyaI3", - "colab_type": "text" - }, - "source": [ - "### B. Discriminator" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ed3MR3vnyxyW", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class Discriminator(nn.Module):\n", - " def __init__(self, img_shape):\n", - " super().__init__()\n", - "\n", - " self.model = nn.Sequential(\n", - " nn.Linear(int(np.prod(img_shape)), 512),\n", - " nn.LeakyReLU(0.2, inplace=True),\n", - " nn.Linear(512, 256),\n", - " nn.LeakyReLU(0.2, inplace=True),\n", - " nn.Linear(256, 1),\n", - " nn.Sigmoid(),\n", - " )\n", - "\n", - " def forward(self, img):\n", - " img_flat = img.view(img.size(0), -1)\n", - " validity = self.model(img_flat)\n", - "\n", - " return validity" - ], - "execution_count": 5, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BwUMom3ryySK", - "colab_type": "text" - }, - "source": [ - "### C. GAN\n", - "\n", - "#### A couple of cool features to check out in this example...\n", - "\n", - " - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n", - " - Lightning will put your dataloader data on the right device automatically\n", - " - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n", - " - `type_as` is the way we recommend to do this.\n", - " - This example shows how to use multiple dataloaders in your `LightningModule`." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3vKszYf6y1Vv", - "colab_type": "code", - "colab": {} - }, - "source": [ - " class GAN(pl.LightningModule):\n", - "\n", - " def __init__(\n", - " self,\n", - " channels,\n", - " width,\n", - " height,\n", - " latent_dim: int = 100,\n", - " lr: float = 0.0002,\n", - " b1: float = 0.5,\n", - " b2: float = 0.999,\n", - " batch_size: int = 64,\n", - " **kwargs\n", - " ):\n", - " super().__init__()\n", - " self.save_hyperparameters()\n", - "\n", - " # networks\n", - " data_shape = (channels, width, height)\n", - " self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n", - " self.discriminator = Discriminator(img_shape=data_shape)\n", - "\n", - " self.validation_z = torch.randn(8, self.hparams.latent_dim)\n", - "\n", - " self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n", - "\n", - " def forward(self, z):\n", - " return self.generator(z)\n", - "\n", - " def adversarial_loss(self, y_hat, y):\n", - " return F.binary_cross_entropy(y_hat, y)\n", - "\n", - " def training_step(self, batch, batch_idx, optimizer_idx):\n", - " imgs, _ = batch\n", - "\n", - " # sample noise\n", - " z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n", - " z = z.type_as(imgs)\n", - "\n", - " # train generator\n", - " if optimizer_idx == 0:\n", - "\n", - " # generate images\n", - " self.generated_imgs = self(z)\n", - "\n", - " # log sampled images\n", - " sample_imgs = self.generated_imgs[:6]\n", - " grid = torchvision.utils.make_grid(sample_imgs)\n", - " self.logger.experiment.add_image('generated_images', grid, 0)\n", - "\n", - " # ground truth result (ie: all fake)\n", - " # put on GPU because we created this tensor inside training_loop\n", - " valid = torch.ones(imgs.size(0), 1)\n", - " valid = valid.type_as(imgs)\n", - "\n", - " # adversarial loss is binary cross-entropy\n", - " g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n", - " tqdm_dict = {'g_loss': g_loss}\n", - " output = OrderedDict({\n", - " 'loss': g_loss,\n", - " 'progress_bar': tqdm_dict,\n", - " 'log': tqdm_dict\n", - " })\n", - " return output\n", - "\n", - " # train discriminator\n", - " if optimizer_idx == 1:\n", - " # Measure discriminator's ability to classify real from generated samples\n", - "\n", - " # how well can it label as real?\n", - " valid = torch.ones(imgs.size(0), 1)\n", - " valid = valid.type_as(imgs)\n", - "\n", - " real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n", - "\n", - " # how well can it label as fake?\n", - " fake = torch.zeros(imgs.size(0), 1)\n", - " fake = fake.type_as(imgs)\n", - "\n", - " fake_loss = self.adversarial_loss(\n", - " self.discriminator(self(z).detach()), fake)\n", - "\n", - " # discriminator loss is the average of these\n", - " d_loss = (real_loss + fake_loss) / 2\n", - " tqdm_dict = {'d_loss': d_loss}\n", - " output = OrderedDict({\n", - " 'loss': d_loss,\n", - " 'progress_bar': tqdm_dict,\n", - " 'log': tqdm_dict\n", - " })\n", - " return output\n", - "\n", - " def configure_optimizers(self):\n", - " lr = self.hparams.lr\n", - " b1 = self.hparams.b1\n", - " b2 = self.hparams.b2\n", - "\n", - " opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n", - " opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n", - " return [opt_g, opt_d], []\n", - "\n", - " def on_epoch_end(self):\n", - " z = self.validation_z.type_as(self.generator.model[0].weight)\n", - "\n", - " # log sampled images\n", - " sample_imgs = self(z)\n", - " grid = torchvision.utils.make_grid(sample_imgs)\n", - " self.logger.experiment.add_image('generated_images', grid, self.current_epoch)" - ], - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ey5FmJPnzm_E", - "colab_type": "code", - "colab": {} - }, - "source": [ - "dm = MNISTDataModule()\n", - "model = GAN(*dm.size())\n", - "trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n", - "trainer.fit(model, dm)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "MlECc7cHzolp", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Start tensorboard.\n", - "%load_ext tensorboard\n", - "%tensorboard --logdir lightning_logs/" - ], - "execution_count": null, - "outputs": [] - } - ] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "J37PBnE_x7IW" + }, + "source": [ + "# PyTorch Lightning Basic GAN Tutorial ⚡\n", + "\n", + "How to train a GAN!\n", + "\n", + "Main takeaways:\n", + "1. Generator and discriminator are arbitrary PyTorch modules.\n", + "2. training_step does both the generator and discriminator training.\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kg2MKpRmybht" + }, + "source": [ + "### Setup\n", + "Lightning is easy to install. Simply `pip install pytorch-lightning`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LfrJLKPFyhsK" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "BjEPuiVLyanw" + }, + "outputs": [], + "source": [ + "import os\n", + "from argparse import ArgumentParser\n", + "from collections import OrderedDict\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import MNIST\n", + "\n", + "import pytorch_lightning as pl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "OuXJzr4G2uHV" + }, + "source": [ + "### MNIST DataModule\n", + "\n", + "Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DOY_nHu328g7" + }, + "outputs": [], + "source": [ + "class MNISTDataModule(pl.LightningDataModule):\n", + "\n", + " def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.batch_size = batch_size\n", + " self.num_workers = num_workers\n", + "\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ])\n", + "\n", + " # self.dims is returned when you call dm.size()\n", + " # Setting default dims here because we know them.\n", + " # Could optionally be assigned dynamically in dm.setup()\n", + " self.dims = (1, 28, 28)\n", + " self.num_classes = 10\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == 'fit' or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tW3c0QrQyF9P" + }, + "source": [ + "### A. Generator" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0E2QDjl5yWtz" + }, + "outputs": [], + "source": [ + "class Generator(nn.Module):\n", + " def __init__(self, latent_dim, img_shape):\n", + " super().__init__()\n", + " self.img_shape = img_shape\n", + "\n", + " def block(in_feat, out_feat, normalize=True):\n", + " layers = [nn.Linear(in_feat, out_feat)]\n", + " if normalize:\n", + " layers.append(nn.BatchNorm1d(out_feat, 0.8))\n", + " layers.append(nn.LeakyReLU(0.2, inplace=True))\n", + " return layers\n", + "\n", + " self.model = nn.Sequential(\n", + " *block(latent_dim, 128, normalize=False),\n", + " *block(128, 256),\n", + " *block(256, 512),\n", + " *block(512, 1024),\n", + " nn.Linear(1024, int(np.prod(img_shape))),\n", + " nn.Tanh()\n", + " )\n", + "\n", + " def forward(self, z):\n", + " img = self.model(z)\n", + " img = img.view(img.size(0), *self.img_shape)\n", + " return img" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uyrltsGvyaI3" + }, + "source": [ + "### B. Discriminator" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ed3MR3vnyxyW" + }, + "outputs": [], + "source": [ + "class Discriminator(nn.Module):\n", + " def __init__(self, img_shape):\n", + " super().__init__()\n", + "\n", + " self.model = nn.Sequential(\n", + " nn.Linear(int(np.prod(img_shape)), 512),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Linear(512, 256),\n", + " nn.LeakyReLU(0.2, inplace=True),\n", + " nn.Linear(256, 1),\n", + " nn.Sigmoid(),\n", + " )\n", + "\n", + " def forward(self, img):\n", + " img_flat = img.view(img.size(0), -1)\n", + " validity = self.model(img_flat)\n", + "\n", + " return validity" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BwUMom3ryySK" + }, + "source": [ + "### C. GAN\n", + "\n", + "#### A couple of cool features to check out in this example...\n", + "\n", + " - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n", + " - Lightning will put your dataloader data on the right device automatically\n", + " - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n", + " - `type_as` is the way we recommend to do this.\n", + " - This example shows how to use multiple dataloaders in your `LightningModule`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3vKszYf6y1Vv" + }, + "outputs": [], + "source": [ + " class GAN(pl.LightningModule):\n", + "\n", + " def __init__(\n", + " self,\n", + " channels,\n", + " width,\n", + " height,\n", + " latent_dim: int = 100,\n", + " lr: float = 0.0002,\n", + " b1: float = 0.5,\n", + " b2: float = 0.999,\n", + " batch_size: int = 64,\n", + " **kwargs\n", + " ):\n", + " super().__init__()\n", + " self.save_hyperparameters()\n", + "\n", + " # networks\n", + " data_shape = (channels, width, height)\n", + " self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n", + " self.discriminator = Discriminator(img_shape=data_shape)\n", + "\n", + " self.validation_z = torch.randn(8, self.hparams.latent_dim)\n", + "\n", + " self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n", + "\n", + " def forward(self, z):\n", + " return self.generator(z)\n", + "\n", + " def adversarial_loss(self, y_hat, y):\n", + " return F.binary_cross_entropy(y_hat, y)\n", + "\n", + " def training_step(self, batch, batch_idx, optimizer_idx):\n", + " imgs, _ = batch\n", + "\n", + " # sample noise\n", + " z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n", + " z = z.type_as(imgs)\n", + "\n", + " # train generator\n", + " if optimizer_idx == 0:\n", + "\n", + " # generate images\n", + " self.generated_imgs = self(z)\n", + "\n", + " # log sampled images\n", + " sample_imgs = self.generated_imgs[:6]\n", + " grid = torchvision.utils.make_grid(sample_imgs)\n", + " self.logger.experiment.add_image('generated_images', grid, 0)\n", + "\n", + " # ground truth result (ie: all fake)\n", + " # put on GPU because we created this tensor inside training_loop\n", + " valid = torch.ones(imgs.size(0), 1)\n", + " valid = valid.type_as(imgs)\n", + "\n", + " # adversarial loss is binary cross-entropy\n", + " g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n", + " tqdm_dict = {'g_loss': g_loss}\n", + " output = OrderedDict({\n", + " 'loss': g_loss,\n", + " 'progress_bar': tqdm_dict,\n", + " 'log': tqdm_dict\n", + " })\n", + " return output\n", + "\n", + " # train discriminator\n", + " if optimizer_idx == 1:\n", + " # Measure discriminator's ability to classify real from generated samples\n", + "\n", + " # how well can it label as real?\n", + " valid = torch.ones(imgs.size(0), 1)\n", + " valid = valid.type_as(imgs)\n", + "\n", + " real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n", + "\n", + " # how well can it label as fake?\n", + " fake = torch.zeros(imgs.size(0), 1)\n", + " fake = fake.type_as(imgs)\n", + "\n", + " fake_loss = self.adversarial_loss(\n", + " self.discriminator(self(z).detach()), fake)\n", + "\n", + " # discriminator loss is the average of these\n", + " d_loss = (real_loss + fake_loss) / 2\n", + " tqdm_dict = {'d_loss': d_loss}\n", + " output = OrderedDict({\n", + " 'loss': d_loss,\n", + " 'progress_bar': tqdm_dict,\n", + " 'log': tqdm_dict\n", + " })\n", + " return output\n", + "\n", + " def configure_optimizers(self):\n", + " lr = self.hparams.lr\n", + " b1 = self.hparams.b1\n", + " b2 = self.hparams.b2\n", + "\n", + " opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n", + " opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n", + " return [opt_g, opt_d], []\n", + "\n", + " def on_epoch_end(self):\n", + " z = self.validation_z.type_as(self.generator.model[0].weight)\n", + "\n", + " # log sampled images\n", + " sample_imgs = self(z)\n", + " grid = torchvision.utils.make_grid(sample_imgs)\n", + " self.logger.experiment.add_image('generated_images', grid, self.current_epoch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ey5FmJPnzm_E" + }, + "outputs": [], + "source": [ + "dm = MNISTDataModule()\n", + "model = GAN(*dm.size())\n", + "trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MlECc7cHzolp" + }, + "outputs": [], + "source": [ + "# Start tensorboard.\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir lightning_logs/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "include_colab_link": true, + "name": "03-basic-gan.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb index ae7424c7d4864..037b24e4ddd9d 100644 --- a/notebooks/04-transformers-text-classification.ipynb +++ b/notebooks/04-transformers-text-classification.ipynb @@ -1,543 +1,591 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "04-transformers-text-classification.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8ag5ANQPJ_j9" + }, + "source": [ + "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n", + "\n", + "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n", + "\n", + "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", + "\n", + " - [HuggingFace datasets](https://github.com/huggingface/datasets)\n", + " - [HuggingFace transformers](https://github.com/huggingface/transformers)" + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "8ag5ANQPJ_j9", - "colab_type": "text" - }, - "source": [ - "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n", - "\n", - "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n", - "\n", - "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", - "\n", - " - [HuggingFace datasets](https://github.com/huggingface/datasets)\n", - " - [HuggingFace transformers](https://github.com/huggingface/transformers)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fqlsVTj7McZ3", - "colab_type": "text" - }, - "source": [ - "### Setup" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "OIhHrRL-MnKK", - "colab_type": "code", - "colab": {} - }, - "source": [ - "!pip install pytorch-lightning datasets transformers" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "6yuQT_ZQMpCg", - "colab_type": "code", - "colab": {} - }, - "source": [ - "from argparse import ArgumentParser\n", - "from datetime import datetime\n", - "from typing import Optional\n", - "\n", - "import datasets\n", - "import numpy as np\n", - "import pytorch_lightning as pl\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "from transformers import (\n", - " AdamW,\n", - " AutoModelForSequenceClassification,\n", - " AutoConfig,\n", - " AutoTokenizer,\n", - " get_linear_schedule_with_warmup,\n", - " glue_compute_metrics\n", - ")" - ], - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9ORJfiuiNZ_N", - "colab_type": "text" - }, - "source": [ - "## GLUE DataModule" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "jW9xQhZxMz1G", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class GLUEDataModule(pl.LightningDataModule):\n", - "\n", - " task_text_field_map = {\n", - " 'cola': ['sentence'],\n", - " 'sst2': ['sentence'],\n", - " 'mrpc': ['sentence1', 'sentence2'],\n", - " 'qqp': ['question1', 'question2'],\n", - " 'stsb': ['sentence1', 'sentence2'],\n", - " 'mnli': ['premise', 'hypothesis'],\n", - " 'qnli': ['question', 'sentence'],\n", - " 'rte': ['sentence1', 'sentence2'],\n", - " 'wnli': ['sentence1', 'sentence2'],\n", - " 'ax': ['premise', 'hypothesis']\n", - " }\n", - "\n", - " glue_task_num_labels = {\n", - " 'cola': 2,\n", - " 'sst2': 2,\n", - " 'mrpc': 2,\n", - " 'qqp': 2,\n", - " 'stsb': 1,\n", - " 'mnli': 3,\n", - " 'qnli': 2,\n", - " 'rte': 2,\n", - " 'wnli': 2,\n", - " 'ax': 3\n", - " }\n", - "\n", - " loader_columns = [\n", - " 'datasets_idx',\n", - " 'input_ids',\n", - " 'token_type_ids',\n", - " 'attention_mask',\n", - " 'start_positions',\n", - " 'end_positions',\n", - " 'labels'\n", - " ]\n", - "\n", - " def __init__(\n", - " self,\n", - " model_name_or_path: str,\n", - " task_name: str ='mrpc',\n", - " max_seq_length: int = 128,\n", - " train_batch_size: int = 32,\n", - " eval_batch_size: int = 32,\n", - " **kwargs\n", - " ):\n", - " super().__init__()\n", - " self.model_name_or_path = model_name_or_path\n", - " self.task_name = task_name\n", - " self.max_seq_length = max_seq_length\n", - " self.train_batch_size = train_batch_size\n", - " self.eval_batch_size = eval_batch_size\n", - "\n", - " self.text_fields = self.task_text_field_map[task_name]\n", - " self.num_labels = self.glue_task_num_labels[task_name]\n", - " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", - "\n", - " def setup(self, stage):\n", - " self.dataset = datasets.load_dataset('glue', self.task_name)\n", - "\n", - " for split in self.dataset.keys():\n", - " self.dataset[split] = self.dataset[split].map(\n", - " self.convert_to_features,\n", - " batched=True,\n", - " remove_columns=['label'],\n", - " )\n", - " self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n", - " self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n", - "\n", - " self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n", - "\n", - " def prepare_data(self):\n", - " datasets.load_dataset('glue', self.task_name)\n", - " AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", - " \n", - " def train_dataloader(self):\n", - " return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n", - " \n", - " def val_dataloader(self):\n", - " if len(self.eval_splits) == 1:\n", - " return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n", - " elif len(self.eval_splits) > 1:\n", - " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", - "\n", - " def test_dataloader(self):\n", - " if len(self.eval_splits) == 1:\n", - " return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n", - " elif len(self.eval_splits) > 1:\n", - " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", - "\n", - " def convert_to_features(self, example_batch, indices=None):\n", - "\n", - " # Either encode single sentence or sentence pairs\n", - " if len(self.text_fields) > 1:\n", - " texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n", - " else:\n", - " texts_or_text_pairs = example_batch[self.text_fields[0]]\n", - "\n", - " # Tokenize the text/text pairs\n", - " features = self.tokenizer.batch_encode_plus(\n", - " texts_or_text_pairs,\n", - " max_length=self.max_seq_length,\n", - " pad_to_max_length=True,\n", - " truncation=True\n", - " )\n", - "\n", - " # Rename label to labels to make it easier to pass to model forward\n", - " features['labels'] = example_batch['label']\n", - "\n", - " return features" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jQC3a6KuOpX3", - "colab_type": "text" - }, - "source": [ - "#### You could use this datamodule with standalone PyTorch if you wanted..." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "JCMH3IAsNffF", - "colab_type": "code", - "colab": {} - }, - "source": [ - "dm = GLUEDataModule('distilbert-base-uncased')\n", - "dm.prepare_data()\n", - "dm.setup('fit')\n", - "next(iter(dm.train_dataloader()))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l9fQ_67BO2Lj", - "colab_type": "text" - }, - "source": [ - "## GLUE Model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gtn5YGKYO65B", - "colab_type": "code", - "colab": {} - }, - "source": [ - "class GLUETransformer(pl.LightningModule):\n", - " def __init__(\n", - " self,\n", - " model_name_or_path: str,\n", - " num_labels: int,\n", - " learning_rate: float = 2e-5,\n", - " adam_epsilon: float = 1e-8,\n", - " warmup_steps: int = 0,\n", - " weight_decay: float = 0.0,\n", - " train_batch_size: int = 32,\n", - " eval_batch_size: int = 32,\n", - " eval_splits: Optional[list] = None,\n", - " **kwargs\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.save_hyperparameters()\n", - "\n", - " self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n", - " self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n", - " self.metric = datasets.load_metric(\n", - " 'glue',\n", - " self.hparams.task_name,\n", - " experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n", - " )\n", - "\n", - " def forward(self, **inputs):\n", - " return self.model(**inputs)\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " outputs = self(**batch)\n", - " loss = outputs[0]\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx, dataloader_idx=0):\n", - " outputs = self(**batch)\n", - " val_loss, logits = outputs[:2]\n", - "\n", - " if self.hparams.num_labels >= 1:\n", - " preds = torch.argmax(logits, axis=1)\n", - " elif self.hparams.num_labels == 1:\n", - " preds = logits.squeeze()\n", - "\n", - " labels = batch[\"labels\"]\n", - "\n", - " return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n", - "\n", - " def validation_epoch_end(self, outputs):\n", - " if self.hparams.task_name == 'mnli':\n", - " for i, output in enumerate(outputs):\n", - " # matched or mismatched\n", - " split = self.hparams.eval_splits[i].split('_')[-1]\n", - " preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n", - " labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n", - " loss = torch.stack([x['loss'] for x in output]).mean()\n", - " self.log(f'val_loss_{split}', loss, prog_bar=True)\n", - " split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n", - " self.log_dict(split_metrics, prog_bar=True)\n", - " return loss\n", - "\n", - " preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n", - " labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n", - " loss = torch.stack([x['loss'] for x in outputs]).mean()\n", - " self.log('val_loss', loss, prog_bar=True)\n", - " self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n", - " return loss\n", - "\n", - " def setup(self, stage):\n", - " if stage == 'fit':\n", - " # Get dataloader by calling it - train_dataloader() is called after setup() by default\n", - " train_loader = self.train_dataloader()\n", - "\n", - " # Calculate total steps\n", - " self.total_steps = (\n", - " (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n", - " // self.hparams.accumulate_grad_batches\n", - " * float(self.hparams.max_epochs)\n", - " )\n", - "\n", - " def configure_optimizers(self):\n", - " \"Prepare optimizer and schedule (linear warmup and decay)\"\n", - " model = self.model\n", - " no_decay = [\"bias\", \"LayerNorm.weight\"]\n", - " optimizer_grouped_parameters = [\n", - " {\n", - " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", - " \"weight_decay\": self.hparams.weight_decay,\n", - " },\n", - " {\n", - " \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n", - " \"weight_decay\": 0.0,\n", - " },\n", - " ]\n", - " optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n", - "\n", - " scheduler = get_linear_schedule_with_warmup(\n", - " optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n", - " )\n", - " scheduler = {\n", - " 'scheduler': scheduler,\n", - " 'interval': 'step',\n", - " 'frequency': 1\n", - " }\n", - " return [optimizer], [scheduler]\n", - "\n", - " @staticmethod\n", - " def add_model_specific_args(parent_parser):\n", - " parser = ArgumentParser(parents=[parent_parser], add_help=False)\n", - " parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n", - " parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n", - " parser.add_argument(\"--warmup_steps\", default=0, type=int)\n", - " parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n", - " return parser" - ], - "execution_count": 5, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ha-NdIP_xbd3", - "colab_type": "text" - }, - "source": [ - "### ⚡ Quick Tip \n", - " - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3dEHnl3RPlAR", - "colab_type": "code", - "colab": {} - }, - "source": [ - "def parse_args(args=None):\n", - " parser = ArgumentParser()\n", - " parser = pl.Trainer.add_argparse_args(parser)\n", - " parser = GLUEDataModule.add_argparse_args(parser)\n", - " parser = GLUETransformer.add_model_specific_args(parser)\n", - " parser.add_argument('--seed', type=int, default=42)\n", - " return parser.parse_args(args)\n", - "\n", - "\n", - "def main(args):\n", - " pl.seed_everything(args.seed)\n", - " dm = GLUEDataModule.from_argparse_args(args)\n", - " dm.prepare_data()\n", - " dm.setup('fit')\n", - " model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n", - " trainer = pl.Trainer.from_argparse_args(args)\n", - " return dm, model, trainer" - ], - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PkuLaeec3sJ-", - "colab_type": "text" - }, - "source": [ - "# Training" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QSpueK5UPsN7", - "colab_type": "text" - }, - "source": [ - "## CoLA\n", - "\n", - "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "NJnFmtpnPu0Y", - "colab_type": "code", - "colab": {} - }, - "source": [ - "mocked_args = \"\"\"\n", - " --model_name_or_path albert-base-v2\n", - " --task_name cola\n", - " --max_epochs 3\n", - " --gpus 1\"\"\".split()\n", - "\n", - "args = parse_args(mocked_args)\n", - "dm, model, trainer = main(args)\n", - "trainer.fit(model, dm)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_MrNsTnqdz4z", - "colab_type": "text" - }, - "source": [ - "## MRPC\n", - "\n", - "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "LBwRxg9Cb3d-", - "colab_type": "code", - "colab": {} - }, - "source": [ - "mocked_args = \"\"\"\n", - " --model_name_or_path distilbert-base-cased\n", - " --task_name mrpc\n", - " --max_epochs 3\n", - " --gpus 1\"\"\".split()\n", - "\n", - "args = parse_args(mocked_args)\n", - "dm, model, trainer = main(args)\n", - "trainer.fit(model, dm)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iZhbn0HzfdCu", - "colab_type": "text" - }, - "source": [ - "## MNLI\n", - "\n", - " - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n", - "\n", - " - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n", - "\n", - "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "AvsZMOggfcWW", - "colab_type": "code", - "colab": {} - }, - "source": [ - "mocked_args = \"\"\"\n", - " --model_name_or_path distilbert-base-uncased\n", - " --task_name mnli\n", - " --max_epochs 1\n", - " --gpus 1\n", - " --limit_train_batches 10\n", - " --progress_bar_refresh_rate 20\"\"\".split()\n", - "\n", - "args = parse_args(mocked_args)\n", - "dm, model, trainer = main(args)\n", - "trainer.fit(model, dm)" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fqlsVTj7McZ3" + }, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OIhHrRL-MnKK" + }, + "outputs": [], + "source": [ + "!pip install pytorch-lightning datasets transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "6yuQT_ZQMpCg" + }, + "outputs": [], + "source": [ + "from argparse import ArgumentParser\n", + "from datetime import datetime\n", + "from typing import Optional\n", + "\n", + "import datasets\n", + "import numpy as np\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from transformers import (\n", + " AdamW,\n", + " AutoModelForSequenceClassification,\n", + " AutoConfig,\n", + " AutoTokenizer,\n", + " get_linear_schedule_with_warmup,\n", + " glue_compute_metrics\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9ORJfiuiNZ_N" + }, + "source": [ + "## GLUE DataModule" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jW9xQhZxMz1G" + }, + "outputs": [], + "source": [ + "class GLUEDataModule(pl.LightningDataModule):\n", + "\n", + " task_text_field_map = {\n", + " 'cola': ['sentence'],\n", + " 'sst2': ['sentence'],\n", + " 'mrpc': ['sentence1', 'sentence2'],\n", + " 'qqp': ['question1', 'question2'],\n", + " 'stsb': ['sentence1', 'sentence2'],\n", + " 'mnli': ['premise', 'hypothesis'],\n", + " 'qnli': ['question', 'sentence'],\n", + " 'rte': ['sentence1', 'sentence2'],\n", + " 'wnli': ['sentence1', 'sentence2'],\n", + " 'ax': ['premise', 'hypothesis']\n", + " }\n", + "\n", + " glue_task_num_labels = {\n", + " 'cola': 2,\n", + " 'sst2': 2,\n", + " 'mrpc': 2,\n", + " 'qqp': 2,\n", + " 'stsb': 1,\n", + " 'mnli': 3,\n", + " 'qnli': 2,\n", + " 'rte': 2,\n", + " 'wnli': 2,\n", + " 'ax': 3\n", + " }\n", + "\n", + " loader_columns = [\n", + " 'datasets_idx',\n", + " 'input_ids',\n", + " 'token_type_ids',\n", + " 'attention_mask',\n", + " 'start_positions',\n", + " 'end_positions',\n", + " 'labels'\n", + " ]\n", + "\n", + " def __init__(\n", + " self,\n", + " model_name_or_path: str,\n", + " task_name: str ='mrpc',\n", + " max_seq_length: int = 128,\n", + " train_batch_size: int = 32,\n", + " eval_batch_size: int = 32,\n", + " **kwargs\n", + " ):\n", + " super().__init__()\n", + " self.model_name_or_path = model_name_or_path\n", + " self.task_name = task_name\n", + " self.max_seq_length = max_seq_length\n", + " self.train_batch_size = train_batch_size\n", + " self.eval_batch_size = eval_batch_size\n", + "\n", + " self.text_fields = self.task_text_field_map[task_name]\n", + " self.num_labels = self.glue_task_num_labels[task_name]\n", + " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", + "\n", + " def setup(self, stage):\n", + " self.dataset = datasets.load_dataset('glue', self.task_name)\n", + "\n", + " for split in self.dataset.keys():\n", + " self.dataset[split] = self.dataset[split].map(\n", + " self.convert_to_features,\n", + " batched=True,\n", + " remove_columns=['label'],\n", + " )\n", + " self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n", + " self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n", + "\n", + " self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n", + "\n", + " def prepare_data(self):\n", + " datasets.load_dataset('glue', self.task_name)\n", + " AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", + " \n", + " def train_dataloader(self):\n", + " return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n", + " \n", + " def val_dataloader(self):\n", + " if len(self.eval_splits) == 1:\n", + " return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n", + " elif len(self.eval_splits) > 1:\n", + " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", + "\n", + " def test_dataloader(self):\n", + " if len(self.eval_splits) == 1:\n", + " return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n", + " elif len(self.eval_splits) > 1:\n", + " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", + "\n", + " def convert_to_features(self, example_batch, indices=None):\n", + "\n", + " # Either encode single sentence or sentence pairs\n", + " if len(self.text_fields) > 1:\n", + " texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n", + " else:\n", + " texts_or_text_pairs = example_batch[self.text_fields[0]]\n", + "\n", + " # Tokenize the text/text pairs\n", + " features = self.tokenizer.batch_encode_plus(\n", + " texts_or_text_pairs,\n", + " max_length=self.max_seq_length,\n", + " pad_to_max_length=True,\n", + " truncation=True\n", + " )\n", + "\n", + " # Rename label to labels to make it easier to pass to model forward\n", + " features['labels'] = example_batch['label']\n", + "\n", + " return features" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jQC3a6KuOpX3" + }, + "source": [ + "#### You could use this datamodule with standalone PyTorch if you wanted..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JCMH3IAsNffF" + }, + "outputs": [], + "source": [ + "dm = GLUEDataModule('distilbert-base-uncased')\n", + "dm.prepare_data()\n", + "dm.setup('fit')\n", + "next(iter(dm.train_dataloader()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "l9fQ_67BO2Lj" + }, + "source": [ + "## GLUE Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gtn5YGKYO65B" + }, + "outputs": [], + "source": [ + "class GLUETransformer(pl.LightningModule):\n", + " def __init__(\n", + " self,\n", + " model_name_or_path: str,\n", + " num_labels: int,\n", + " learning_rate: float = 2e-5,\n", + " adam_epsilon: float = 1e-8,\n", + " warmup_steps: int = 0,\n", + " weight_decay: float = 0.0,\n", + " train_batch_size: int = 32,\n", + " eval_batch_size: int = 32,\n", + " eval_splits: Optional[list] = None,\n", + " **kwargs\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters()\n", + "\n", + " self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n", + " self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n", + " self.metric = datasets.load_metric(\n", + " 'glue',\n", + " self.hparams.task_name,\n", + " experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n", + " )\n", + "\n", + " def forward(self, **inputs):\n", + " return self.model(**inputs)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " outputs = self(**batch)\n", + " loss = outputs[0]\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx, dataloader_idx=0):\n", + " outputs = self(**batch)\n", + " val_loss, logits = outputs[:2]\n", + "\n", + " if self.hparams.num_labels >= 1:\n", + " preds = torch.argmax(logits, axis=1)\n", + " elif self.hparams.num_labels == 1:\n", + " preds = logits.squeeze()\n", + "\n", + " labels = batch[\"labels\"]\n", + "\n", + " return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n", + "\n", + " def validation_epoch_end(self, outputs):\n", + " if self.hparams.task_name == 'mnli':\n", + " for i, output in enumerate(outputs):\n", + " # matched or mismatched\n", + " split = self.hparams.eval_splits[i].split('_')[-1]\n", + " preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n", + " labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n", + " loss = torch.stack([x['loss'] for x in output]).mean()\n", + " self.log(f'val_loss_{split}', loss, prog_bar=True)\n", + " split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n", + " self.log_dict(split_metrics, prog_bar=True)\n", + " return loss\n", + "\n", + " preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n", + " labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n", + " loss = torch.stack([x['loss'] for x in outputs]).mean()\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n", + " return loss\n", + "\n", + " def setup(self, stage):\n", + " if stage == 'fit':\n", + " # Get dataloader by calling it - train_dataloader() is called after setup() by default\n", + " train_loader = self.train_dataloader()\n", + "\n", + " # Calculate total steps\n", + " self.total_steps = (\n", + " (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n", + " // self.hparams.accumulate_grad_batches\n", + " * float(self.hparams.max_epochs)\n", + " )\n", + "\n", + " def configure_optimizers(self):\n", + " \"Prepare optimizer and schedule (linear warmup and decay)\"\n", + " model = self.model\n", + " no_decay = [\"bias\", \"LayerNorm.weight\"]\n", + " optimizer_grouped_parameters = [\n", + " {\n", + " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": self.hparams.weight_decay,\n", + " },\n", + " {\n", + " \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n", + " \"weight_decay\": 0.0,\n", + " },\n", + " ]\n", + " optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n", + "\n", + " scheduler = get_linear_schedule_with_warmup(\n", + " optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n", + " )\n", + " scheduler = {\n", + " 'scheduler': scheduler,\n", + " 'interval': 'step',\n", + " 'frequency': 1\n", + " }\n", + " return [optimizer], [scheduler]\n", + "\n", + " @staticmethod\n", + " def add_model_specific_args(parent_parser):\n", + " parser = ArgumentParser(parents=[parent_parser], add_help=False)\n", + " parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n", + " parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n", + " parser.add_argument(\"--warmup_steps\", default=0, type=int)\n", + " parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n", + " return parser" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ha-NdIP_xbd3" + }, + "source": [ + "### ⚡ Quick Tip \n", + " - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3dEHnl3RPlAR" + }, + "outputs": [], + "source": [ + "def parse_args(args=None):\n", + " parser = ArgumentParser()\n", + " parser = pl.Trainer.add_argparse_args(parser)\n", + " parser = GLUEDataModule.add_argparse_args(parser)\n", + " parser = GLUETransformer.add_model_specific_args(parser)\n", + " parser.add_argument('--seed', type=int, default=42)\n", + " return parser.parse_args(args)\n", + "\n", + "\n", + "def main(args):\n", + " pl.seed_everything(args.seed)\n", + " dm = GLUEDataModule.from_argparse_args(args)\n", + " dm.prepare_data()\n", + " dm.setup('fit')\n", + " model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n", + " trainer = pl.Trainer.from_argparse_args(args)\n", + " return dm, model, trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PkuLaeec3sJ-" + }, + "source": [ + "# Training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QSpueK5UPsN7" + }, + "source": [ + "## CoLA\n", + "\n", + "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NJnFmtpnPu0Y" + }, + "outputs": [], + "source": [ + "mocked_args = \"\"\"\n", + " --model_name_or_path albert-base-v2\n", + " --task_name cola\n", + " --max_epochs 3\n", + " --gpus 1\"\"\".split()\n", + "\n", + "args = parse_args(mocked_args)\n", + "dm, model, trainer = main(args)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_MrNsTnqdz4z" + }, + "source": [ + "## MRPC\n", + "\n", + "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LBwRxg9Cb3d-" + }, + "outputs": [], + "source": [ + "mocked_args = \"\"\"\n", + " --model_name_or_path distilbert-base-cased\n", + " --task_name mrpc\n", + " --max_epochs 3\n", + " --gpus 1\"\"\".split()\n", + "\n", + "args = parse_args(mocked_args)\n", + "dm, model, trainer = main(args)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iZhbn0HzfdCu" + }, + "source": [ + "## MNLI\n", + "\n", + " - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n", + "\n", + " - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n", + "\n", + "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "AvsZMOggfcWW" + }, + "outputs": [], + "source": [ + "mocked_args = \"\"\"\n", + " --model_name_or_path distilbert-base-uncased\n", + " --task_name mnli\n", + " --max_epochs 1\n", + " --gpus 1\n", + " --limit_train_batches 10\n", + " --progress_bar_refresh_rate 20\"\"\".split()\n", + "\n", + "args = parse_args(mocked_args)\n", + "dm, model, trainer = main(args)\n", + "trainer.fit(model, dm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "04-transformers-text-classification.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/05-trainer-flags-overview.ipynb b/notebooks/05-trainer-flags-overview.ipynb index 4589887ecb986..f1f93104f4552 100644 --- a/notebooks/05-trainer-flags-overview.ipynb +++ b/notebooks/05-trainer-flags-overview.ipynb @@ -1,2871 +1,2919 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "05-trainer-flags-overview.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "goRmGIRI5cfC" - }, - "source": [ - "# Introduction to Lightning Flags ⚡🚩\n", - "\n", - "In this notebook, we'll go over the flags available in the `Trainer` object. Note that not everything will work in the Colab environment (multi-gpu, etc). This notebook accompanies the Trainer videos we'll be putting out.\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jKj5lgdr5j48" - }, - "source": [ - "--- \n", - "### Setup \n", - "First thing first, we need to install Lightning. Simply ```pip install pytorch-lightning```" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "UGjilEHk4vb7" - }, - "source": [ - "! pip install pytorch-lightning" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "zaVUShmQ5n8Y" - }, - "source": [ - "import os\n", - "\n", - "from argparse import ArgumentParser\n", - "import torch\n", - "from torch import nn\n", - "from torch.nn import functional as F\n", - "from torch.utils.data import DataLoader\n", - "from torch.utils.data import random_split\n", - "from torchvision.datasets import MNIST\n", - "from torchvision import transforms\n", - "import pytorch_lightning as pl\n", - "from pytorch_lightning.metrics.functional import accuracy\n", - "\n", - "from torchvision.datasets.mnist import MNIST\n", - "from torchvision import transforms" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "6tgkS8IYZwY_" - }, - "source": [ - "# ------------\n", - "# data\n", - "# ------------\n", - "pl.seed_everything(1234)\n", - "batch_size = 32\n", - "\n", - "# Init DataLoader from MNIST Dataset\n", - "\n", - "dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", - "mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())\n", - "mnist_train, mnist_val = random_split(dataset, [55000, 5000])\n", - "\n", - "train_loader = DataLoader(mnist_train, batch_size=batch_size)\n", - "val_loader = DataLoader(mnist_val, batch_size=batch_size)\n", - "test_loader = DataLoader(mnist_test, batch_size=batch_size)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gEulmrbxwaYL" - }, - "source": [ - "### Simple AutoEncoder Model\n", - "\n", - "Were gonna define a simple Lightning model so we can play with all the settings of the Lightning Trainer.\n", - "\n", - "LightningModule is simply pure Pytorch reorganized into hooks, that represents all the steps in the training process.\n", - "\n", - "You can use LightningModule hooks to control every part of your model, but for the purpose of this video we will use a very simple MNIST classifier, a model that takes 28*28 grayscale images of hand written images, and can predict the digit between 0-9.\n", - "\n", - "The LightningModule can encompass a single model, like an image classifier, or a deep learning system composed of multiple models, like this auto encoder that contains an encoder and a decoder.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "x-34xKCI40yW" - }, - "source": [ - "class LitAutoEncoder(pl.LightningModule):\n", - "\n", - " def __init__(self, batch_size=32, lr=1e-3):\n", - " super().__init__()\n", - " self.encoder = nn.Sequential(\n", - " nn.Linear(28 * 28, 64),\n", - " nn.ReLU(),\n", - " nn.Linear(64, 3)\n", - " )\n", - " self.decoder = nn.Sequential(\n", - " nn.Linear(3, 64),\n", - " nn.ReLU(),\n", - " nn.Linear(64, 28 * 28)\n", - " )\n", - " self.batch_size=batch_size\n", - " self.learning_rate=lr\n", - "\n", - " def forward(self, x):\n", - " # in lightning, forward defines the prediction/inference actions\n", - " embedding = self.encoder(x)\n", - " return embedding\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " x = x.view(x.size(0), -1)\n", - " z = self.encoder(x)\n", - " x_hat = self.decoder(z)\n", - " loss = F.mse_loss(x_hat, x)\n", - " self.log('train_loss', loss)\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " x = x.view(x.size(0), -1)\n", - " z = self.encoder(x)\n", - " x_hat = self.decoder(z)\n", - " loss = F.mse_loss(x_hat, x)\n", - " self.log('val_loss', loss)\n", - " \n", - " def test_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " x = x.view(x.size(0), -1)\n", - " z = self.encoder(x)\n", - " x_hat = self.decoder(z)\n", - " loss = F.mse_loss(x_hat, x)\n", - " self.log('test_loss', loss)\n", - "\n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", - " return optimizer" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VbxcRCrxiYly" - }, - "source": [ - "You'll notice the LightningModule doesn't have epoch and batch loops, we're not calling model.train() and model.eval(), and no mentions of CUDA or hardware. That's because it is all automated by the Lightning Trainer. All the engineering boilerplate is automated by the trainer: \n", - "\n", - "* Training loops\n", - "* Evaluation and test loops\n", - "* Calling model.train(), model.eval(), no_grad at the right time\n", - "* CUDA or to_device calls\n", - "\n", - "It also allows you to train your models on different hardware like GPUs and TPUs without changing your code!\n", - "\n", - "\n", - "### To use the lightning trainer simply:\n", - "\n", - "1. init your LightningModule and datasets\n", - "\n", - "2. init lightning trainer\n", - "\n", - "3. call trainer.fit\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "HOk9c4_35FKg" - }, - "source": [ - "#####################\n", - "# 1. Init Model\n", - "#####################\n", - "\n", - "model = LitAutoEncoder()\n", - "\n", - "#####################\n", - "# 2. Init Trainer\n", - "#####################\n", - "\n", - "# these 2 flags are explained in the later sections...but for short explanation:\n", - "# - progress_bar_refresh_rate: limits refresh rate of tqdm progress bar so Colab doesn't freak out\n", - "# - max_epochs: only run 2 epochs instead of default of 1000\n", - "trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=2)\n", - "\n", - "#####################\n", - "# 3. Train\n", - "#####################\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3meDako-Qa_6" - }, - "source": [ - "Our model is training just like that, using the Lightning defaults. The beauty of Lightning is that everything is easily configurable.\n", - "In our next videos were going to show you all the ways you can control your Trainer to do things like controlling your training, validation and test loops, running on GPUs and TPUs, checkpointing, early stopping, and a lot more.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z_Wry2MckQkI" - }, - "source": [ - "# Training loop and eval loop Flags" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0MkI1xB2vsLj" - }, - "source": [ - "\n", - "To really scale up your networks, you can use accelerators like GPUs. GPUs or Graphical Processing Units, parallelize matrix multiplications which enable speed ups of at least 100x over training on CPUs.\n", - "\n", - "Let's say you have a machine with 8 GPUs on it. You can set this flag to 1, 4, or 8 GPUs and lightning will automatically distribute your training for you.\n", - "\n", - "```\n", - "trainer = pl.Trainer(gpus=1)\n", - "```\n", - "\n", - "---------\n", - "\n", - "Lightning makes your code hardware agnostic... This means, you can switch between CPUs, GPUs without code changes.\n", - "\n", - "However, it requires forming good PyTorch habits:\n", - "\n", - "1. First, remove the .cuda() or .to() calls in your code.\n", - "2. Second, when you initialize a new tensor, set the device=self.device in the call since every lightningModule knows what gpu index or TPU core it is on.\n", - "\n", - "You can also use type_as and or you can register the tensor as a buffer in your module’s __init__ method with register_buffer().\n", - "\n", - "```\n", - "# before lightning\n", - "def forward(self, x):\n", - " z = torch.Tensor(2, 3)\n", - " z = z.cuda(0)\n", - "\n", - "# with lightning\n", - "def forward(self, x):\n", - " z = torch.Tensor(2, 3)\n", - " z = z.type_as(x, device=self.device)\n", - "```\n", - "\n", - "\n", - "```\n", - "class LitModel(LightningModule):\n", - "\n", - " def __init__(self):\n", - " ...\n", - " self.register_buffer(\"sigma\", torch.eye(3))\n", - " # you can now access self.sigma anywhere in your module\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hw6jJhhjvlSL" - }, - "source": [ - "Lightning Trainer automates all the engineering boilerplate like iterating over epochs and batches, training eval and test loops, CUDA and to(device) calls, calling model.train and model.eval.\n", - "\n", - "You still have full control over the loops, by using the following trainer flags:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pT5-ETH9eUg6" - }, - "source": [ - "## Calling validation steps\n", - "Sometimes, training an epoch may be pretty fast, like minutes per epoch. In this case, you might not need to validate on every epoch. Instead, you can actually validate after a few epochs.\n", - "\n", - "Use `check_val_every_n_epoch` flag to control the frequency of validation step:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Z-EMVvKheu3D" - }, - "source": [ - "# run val loop every 10 training epochs\n", - "trainer = pl.Trainer(check_val_every_n_epoch=10)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UOzZr9S2UcSO" - }, - "source": [ - "## val_check_interval\n", - "\n", - "In some cases where your epoch is very long, you might want to check validation within an epoch.\n", - "\n", - "You can also run validation step within your training epochs, by setting `val_check_interval` flag.\n", - "\n", - "Set `val_check_interval` to a float between [0.0 to 1.0] to check your validation set within a training epoch. For example, setting it to 0.25 will check your validation set 4 times during a training epoch.\n", - "\n", - "Default is set to 1.0" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9kbUbvrUVLrT" - }, - "source": [ - "# check validation set 4 times during a training epoch\n", - "trainer = pl.Trainer(val_check_interval=0.25)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Onm1gBsKVaw4" - }, - "source": [ - "When you have iterable data sets, or when streaming data for production use cases, it is useful to check the validation set every number of steps. \n", - "Set val_check_interval to an int:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "psn6DVb5Vi85" - }, - "source": [ - "# check validation set every 1000 training batches\n", - "# use this when using iterableDataset and your dataset has no length\n", - "# (ie: production cases with streaming data)\n", - "trainer = pl.Trainer(val_check_interval=1000)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QkoYonrWkb7-" - }, - "source": [ - "## num_sanity_val_steps \n", - "\n", - "You may have run into an issue, where you have a bug in your validation loop, but won't catch it until your training loop ends.\n", - "\n", - "and if your training loop takes hours or days, you will waste valuable compute.\n", - "\n", - "Instead, lightning automatically runs through 2 steps of validation in the beginning to catch these kinds of bugs up front.\n", - "\n", - "\n", - "The `num_sanity_val_steps` flag can help you run n batches of validation before starting the training routine.\n", - "\n", - "You can set it to 0 to turn it off" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zOcT-ugSkiKW" - }, - "source": [ - "# turn it off\n", - "trainer = pl.Trainer(num_sanity_val_steps=0)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zS0ob1ZmTw56" - }, - "source": [ - "Set it to -1 to check all validation data before training" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "rzqvjA4UT263" - }, - "source": [ - "# check all validation data\n", - "trainer = pl.Trainer(num_sanity_val_steps=-1)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uMB41wq4T3Z2" - }, - "source": [ - "Or use any arbitrary number of validation steps" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lGP78aQzT7VS" - }, - "source": [ - "trainer = pl.Trainer(num_sanity_val_steps=10)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "H-xaYRtd1rb-" - }, - "source": [ - "## Limit train, validation, and test batches\n", - "\n", - "You can set limits on how much of training, validation and test dataset you want your model to check. This is useful if you have really large validation or tests sets, for debugging or testing something that happens at the end of an epoch.\n", - "\n", - "Set the flag to int to specify the number of batches to run\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "XiK5cFKL1rcA" - }, - "source": [ - "# run for only 10 batches\n", - "trainer = pl.Trainer(limit_test_batches=10)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Y4LK0g65RrBm" - }, - "source": [ - "For example, some metrics need to be computed on the entire validation results, such as AUC ROC. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8MmeRs2DR3dD" - }, - "source": [ - "trainer = pl.Trainer(limit_val_batches=10)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xmigcNa1A2Vy" - }, - "source": [ - "You can use a float to limit the batches be percentage of the set on every epoch" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "W7uGJt8nA4tv" - }, - "source": [ - "# run through only 25% of the test set each epoch\n", - "trainer = pl.Trainer(limit_test_batches=0.25)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YRI8THtUN7_e" - }, - "source": [ - "# Training on GPUs\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R8FFkX_FwlfE" - }, - "source": [ - "To run on 1 GPU set the flag to 1" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Nnzkf3KaOE27" - }, - "source": [ - "trainer = pl.Trainer(gpus=1)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cxBg47s5PB1P" - }, - "source": [ - "to run on 2 or 4 GPUs, set the flag to 2 or 4." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "cSEM4ihLrohT" - }, - "source": [ - "trainer = pl.Trainer(gpus=2)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZE6ZgwtNudro" - }, - "source": [ - "You can also select which GPU devices to run on, using a list of indices like [1, 4] \n", - "\n", - "or a string containing a comma separated list of GPU ids like '1,2'\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gQkJtq0urrjq" - }, - "source": [ - "# list: train on GPUs 1, 4 (by bus ordering)\n", - "# trainer = Trainer(gpus='1, 4') # equivalent\n", - "trainer = pl.Trainer(gpus=[1, 4])\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "XghDPad4us74" - }, - "source": [ - "trainer = pl.Trainer(gpus=list(range(4)))\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6FVkKHpSPMTW" - }, - "source": [ - "You can use all the GPUs you have available by setting `gpus=-1`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "r6cKQijYrtPe" - }, - "source": [ - "# trainer = Trainer(gpus='-1') - equivalent\n", - "trainer = pl.Trainer(gpus=-1)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2C-fNLm3UGCV" - }, - "source": [ - "Lightning uses the PCI bus_id as the index for ordering GPUs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_V75s7EhOFhE" - }, - "source": [ - "### `auto_select_gpus`\n", - "\n", - "You can save on GPUs by running in “exclusive mode”, meaning only one process at a time can access them. If your not sure which GPUs you should use when running exclusive mode, Lightning can automatically find unoccupied GPUs for you. \n", - "\n", - "Simply specify the number of gpus as an integer `gpus=k`, and set the trainer flag `auto_select_gpus=True`. Lightning will automatically help you find k gpus that are not occupied by other processes." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "_Sd3XFsAOIwd" - }, - "source": [ - "# enable auto selection (will find two available gpus on system)\n", - "trainer = pl.Trainer(gpus=2, auto_select_gpus=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a5JGSBMQhJNp" - }, - "source": [ - "## analyzing GPU usage\n", - "\n", - "### log_gpu_memory\n", - "\n", - "This is useful to analyze the memory usage of your GPUs.\n", - "\n", - "To get the GPU memory usage for every GPU on the master node, set the flag to log_gpu_memory=all.\n", - "\n", - "Under the hood, lightning uses the nvidia-smi command which may slow your training down.\n", - "\n", - "Your logs can become overwhelmed if you log the usage from many GPUs at once. In this case, you can also set the flag to min_max which will log only the min and max usage across all the GPUs of the master node.\n", - "\n", - "Note that lightning is not logging the usage across all nodes for performance reasons." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "idus3ZGahOki" - }, - "source": [ - "# log all the GPUs (on master node only)\n", - "trainer = Trainer(log_gpu_memory='all')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-mevgiy_hkip" - }, - "source": [ - "To avoid the performance decrease you can also set `log_gpu_memory=min_max` to only log the min and max memory on the master node.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "SlvLJnWyhs7J" - }, - "source": [ - "# log only the min and max memory on the master node\n", - "trainer = Trainer(log_gpu_memory='min_max')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K82FLLIJVQG3" - }, - "source": [ - "\n", - "But what if you want to train on multiple machines and not just one?" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YViQ6PXesAue" - }, - "source": [ - "# Training on multiple GPUs" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WacbBQUivxQq" - }, - "source": [ - "Lightning makes your models hardware agnostic, and you can run on GPUs with a flip of a flag. Lightning also supports training on multiple GPUs across many machines.\n", - "\n", - "You can do this by setting the num_nodes flag.\n", - "\n", - "The world size, or the total number of GPUs you are using, will be gpus*num_nodes.\n", - "\n", - "If i set gpus=8 and num_nodes=32 then I will be training on 256 GPUs." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "5iKckmDvr8zZ" - }, - "source": [ - "trainer = pl.Trainer(gpus=8, num_nodes=32)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GgcSbDjjlSTh" - }, - "source": [ - "## distributed backends\n", - "\n", - "Under the hood, Lightning uses distributed data parallel (or DDP) by default to distribute training across GPUs.\n", - "\n", - "This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment variables.\n", - "\n", - "Under the hood it's as if you had called your script like this:\n", - "\n", - "1. Each GPU across each node gets its own process.\n", - "2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset.\n", - "3. Each process inits the model. (Make sure to set the random seed so that each model initializes with the same weights.)\n", - "4. Each process performs a full forward and backward pass in parallel.\n", - "5. The gradients are synced and averaged across all processes.\n", - "6. Each process updates its optimizer.\n", - "If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "n_Brr7F5wdtj" - }, - "source": [ - "# ddp = DistributedDataParallel\n", - "# trainer = pl.Trainer(gpus=2, num_nodes=2) equivalent\n", - "trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "edxHyttC5J3e" - }, - "source": [ - "DDP is the fastest and recommended way to distribute your training, but you can pass in other backends to `distributed_backend` trainer flag, when DDP is not supported.\n", - "\n", - "DDP isn't available in\n", - "* Jupyter Notebook, Google COLAB, Kaggle, etc.\n", - "* If You have a nested script without a root package\n", - "* or if Your script needs to invoke .fit or .test multiple times" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZDh96mavxHxf" - }, - "source": [ - "### DDP_SPAWN\n", - "\n", - "In these cases, you can use `ddp_spawn` instead. `ddp_spawn` is exactly like DDP except that it uses `.spawn()` to start the training processes." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "JM5TKtgLxo37" - }, - "source": [ - "trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp_spawn')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sebhVE3qrhKK" - }, - "source": [ - "We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):\n", - "\n", - "* Since .spawn() trains the model in subprocesses, the model on the main process does not get updated.\n", - "\n", - "* Dataloader(num_workers=N), where N is large, bottlenecks training with DDP… ie: it will be VERY slow or won’t work at all. This is a PyTorch limitation.\n", - "\n", - "* Forces everything to be picklable.\n", - "\n", - "DDP is MUCH faster than DDP_spawn. To be able to use DDP we recommend you: \n", - "\n", - "1. Install a top-level module for your project using setup.py\n", - "\n", - "```\n", - "# setup.py\n", - "#!/usr/bin/env python\n", - "\n", - "from setuptools import setup, find_packages\n", - "\n", - "setup(name='src',\n", - " version='0.0.1',\n", - " description='Describe Your Cool Project',\n", - " author='',\n", - " author_email='',\n", - " url='https://github.com/YourSeed', # REPLACE WITH YOUR OWN GITHUB PROJECT LINK\n", - " install_requires=[\n", - " 'pytorch-lightning'\n", - " ],\n", - " packages=find_packages()\n", - " )\n", - "\n", - "```\n", - "\n", - "2. Setup your project like so:\n", - "\n", - "```\n", - "/project\n", - " /src\n", - " some_file.py\n", - " /or_a_folder\n", - " setup.py\n", - "```\n", - "3. Install as a root-level package\n", - "```\n", - "cd /project\n", - "pip install -e .\n", - "```\n", - "4. You can then call your scripts anywhere\n", - "```\n", - "cd /project/src\n", - "\n", - "python some_file.py --distributed_backend 'ddp' --gpus 8\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cmB3I_oyw7a8" - }, - "source": [ - "### DP\n", - "\n", - "If you're using windows, DDP is not supported. You can use `dp` for DataParallel instead: DataParallel uses multithreading, instead of multiprocessing. It splits a batch across k GPUs. That is, if you have a batch of 32 and use DP with 2 gpus, each GPU will process 16 samples, after which the root node will aggregate the results.\n", - "\n", - "DP use is discouraged by PyTorch and Lightning. Use DDP which is more stable and at least 3x faster.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "OO-J0ISvlVCg" - }, - "source": [ - "# dp = DataParallel\n", - "trainer = pl.Trainer(gpus=2, distributed_backend='dp')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Y7E2eHZKwUn9" - }, - "source": [ - "### DDP2\n", - "\n", - "In certain cases, it’s advantageous to use ***all*** batches on the same machine, instead of a subset. For instance, in self-supervised learning, a common performance boost comes from increasing the number of negative samples.\n", - "\n", - "In this case, we can use DDP2 which behaves like DP in a machine and DDP across nodes. DDP2 does the following:\n", - "\n", - "* Copies a subset of the data to each node.\n", - "* Inits a model on each node.\n", - "* Runs a forward and backward pass using DP.\n", - "* Syncs gradients across nodes.\n", - "* Applies the optimizer updates.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Y4xweqL3xHER" - }, - "source": [ - "# ddp2 = DistributedDataParallel + dp\n", - "trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lhKNCnveeeq5" - }, - "source": [ - "- The second mode is ddp_spawn. This works like ddp, but instead of calling your script multiple times, lightning will use multiprocessing spawn to start a subprocess per GPU. \n", - "\n", - "However, you should be careful of mixing this mode with num_workers > 0 in your dataloaders because it will bottleneck your training. This is a current known limitation of PyTorch which is why we recommend using our ddp implementation instead.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HUf9ANyQkFFO" - }, - "source": [ - "\n", - "### mocking ddp\n", - "\n", - "Testing or debugging DDP can be hard, so we have a distributed backend that simulates ddp on cpus to make it easier. Set `num_processes` to a number greater than 1 when using distributed_backend=\"ddp_cpu\" to mimic distributed training on a machine without GPUs. Note that while this is useful for debugging, it will not provide any speedup, since single-process Torch already makes efficient use of multiple CPUs." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ZSal5Da9kHOf" - }, - "source": [ - "# Simulate DDP for debugging on your GPU-less laptop\n", - "trainer = Trainer(distributed_backend=\"ddp_cpu\", num_processes=2)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Br_btCy5lgES" - }, - "source": [ - "# Training on TPUS\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DXkBNITdv44d" - }, - "source": [ - "Another option for accelerating your training is using TPUs.\n", - "A TPU is a Tensor processing unit, designed specifically for deep learning. Each TPU has 8 cores where each core is optimized for 128x128 matrix multiplies. Google estimates that 8 TPU cores are about as fast as 4 V100 GPUs!\n", - "\n", - "A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores! You can request a full pod from Google cloud or a “slice” which gives you some subset of those 2048 cores.\n", - "\n", - "At this moment, TPUs are available on Google Cloud (GCP), Google Colab and Kaggle Environments.\n", - "\n", - "Lightning supports training on TPUs without any code adjustments to your model. Just like when using GPUs, Lightning automatically inserts the correct samplers - no need to do this yourself!\n", - "\n", - "Under the hood, lightning uses the XLA framework developed jointly by the facebook and google XLA teams. And we want to recognize their efforts in advancing TPU adoption of PyTorch.\n", - "\n", - "## tpu_cores\n", - "To train on TPUs, set the tpu_cores flag.\n", - "\n", - "When using colab or kaggle, the allowed values are 1 or 8 cores. When using google cloud, any value above 8 is allowed.\n", - "\n", - "Your effective batch size is the batch size passed into a dataloader times the total number of tpu cores." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "itP9y70gmD9M" - }, - "source": [ - "# int: train on a single core\n", - "trainer = pl.Trainer(tpu_cores=1)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NJKnzPb3mKEg" - }, - "source": [ - "# int: train on all cores few cores\n", - "trainer = pl.Trainer(tpu_cores=8)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8a4exfWUmOHq" - }, - "source": [ - "You can also choose which TPU core to train on, by passing a list [1-8]. This is not an officially supported use case but we are working with the XLA team to improve this user experience.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "S6OrjE_bmT-_" - }, - "source": [ - "# list: train on a single selected core\n", - "trainer = pl.Trainer(tpu_cores=[2])\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Afqx3sFUmfWD" - }, - "source": [ - "To train on more than 8 cores (ie: a POD), submit this script using the xla_dist script.\n", - "\n", - "\n", - "\n", - "```\n", - "python -m torch_xla.distributed.xla_dist\n", - "--tpu=$TPU_POD_NAME\n", - "--conda-env=torch-xla-nightly\n", - "--env=XLA_USE_BF16=1\n", - "-- python your_trainer_file.py\n", - "```\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ncPvbUVQqKOh" - }, - "source": [ - "# Advanced distributed training\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4MP7bEgnv7qK" - }, - "source": [ - "\n", - "Lightning supports distributed training across multiple GPUs and TPUs out of the box by setting trainer flags, but it also allows you to control the way sampling is done if you need to." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wdHiTfAMepKH" - }, - "source": [ - "## replace_sampler_ddp\n", - "In PyTorch, you must use torch.nn.DistributedSampler for multi-node or GPU training. The sampler makes sure each GPU sees the appropriate part of your data.\n", - "\n", - "```\n", - "# without lightning\n", - "def train_dataloader(self):\n", - " dataset = MNIST(...)\n", - " sampler = None\n", - "\n", - " if self.on_tpu:\n", - " sampler = DistributedSampler(dataset)\n", - "\n", - " return DataLoader(dataset, sampler=sampler)\n", - "```\n", - "Lightning adds the correct samplers when needed, so no need to explicitly add samplers. By default it will add `shuffle=True` for train sampler and `shuffle=False` for val/test sampler.\n", - "\n", - "If you want to customize this behaviour, you can set `replace_sampler_ddp=False` and add your own distributed sampler.\n", - "\n", - "(note: For iterable datasets, we don’t do this automatically.)\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ZfmcB_e_7HbE" - }, - "source": [ - "sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)\n", - "dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)\n", - "\n", - "trainer = pl.Trainer(gpus=2, num_nodes=2, replace_sampler_ddp=False)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-IOhk1n0lL3_" - }, - "source": [ - "## prepare_data_per_node\n", - "\n", - "When doing multi NODE training, if your nodes share the same file system, then you don't want to download data more than once to avoid possible collisions. \n", - "\n", - "Lightning automatically calls the prepare_data hook on the root GPU of the master node (ie: only a single GPU).\n", - "\n", - "In some cases where your nodes don't share the same file system, you need to download the data on each node. In this case you can set this flag to true and lightning will download the data on the root GPU of each node.\n", - "\n", - "This flag is defaulted to True." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WFBMUR48lM04" - }, - "source": [ - "trainer = pl.Trainer(gpus=2, num_nodes=2, prepare_data_per_node=False)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FKBwXqo4q-Vp" - }, - "source": [ - "## sync_batchnorm\n", - "\n", - "Batch norm is computed per GPU/TPU. This flag enables synchronization between batchnorm layers across all GPUs.\n", - "It is recommended if you have small batch sizes.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "GhaCLTEZrAQi" - }, - "source": [ - "trainer = Trainer(gpus=4, sync_batchnorm=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XuFA7VTFMY9-" - }, - "source": [ - "# Debugging flags\n", - "\n", - "Lightning offers a couple of flags to make debugging your models easier:\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AKoS3fdml4Jx" - }, - "source": [ - "## Fast Dev Run\n", - "\n", - "To help you save time debugging, your first run should use the fast_dev_run flag.\n", - "\n", - "This won't generate logs or save checkpoints but will touch every line of your code to make sure that it is working as intended.\n", - "\n", - "Think about this flag like a compiler. You make changes to your code, and run Trainer with this flag to verify that your changes are bug free.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "L5vuG7GSmhzK" - }, - "source": [ - "trainer = pl.Trainer(fast_dev_run=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HRP1qQR5nT4p" - }, - "source": [ - "## overfit_batches\n", - "\n", - "Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. If the training dataloaders have shuffle=True, Lightning will automatically disable it.\n", - "\n", - "Useful for quickly debugging or trying to overfit on purpose." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "NTM-dqGMnXms" - }, - "source": [ - "# use only 1% of the train set (and use the train set for val and test)\n", - "trainer = pl.Trainer(overfit_batches=0.01)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "c0LV0gC3nl1X" - }, - "source": [ - "# overfit on 10 of the same batches\n", - "trainer = pl.Trainer(overfit_batches=10)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lt3UHU6WgtS_" - }, - "source": [ - "Or a float to represent percentage of data to run" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "K3yUqADhgnkf" - }, - "source": [ - "# run through only 25% of the test set each epoch\n", - "trainer = pl.Trainer(limit_test_batches=0.25)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ODN66NeVg_2o" - }, - "source": [ - "In the case of multiple test dataloaders, the limit applies to each dataloader individually.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8aQx5SLeMz1R" - }, - "source": [ - "# accumulate_grad_batches\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g8GczZXFwKC7" - }, - "source": [ - "The batch size controls the accuracy of the estimate of the gradients. Small batch size use less memory, but decrease accuracy. When training large models, such as NLP transformers, it is useful to accumulate gradients before calling backwards(). It allows for bigger batch sizes than what can actually fit on a GPU/TPU in a single step.\n", - "\n", - "Use accumulate_grad_batches to accumulate gradients every k batches or as set up in the dict. Trainer also calls optimizer.step() for the last indivisible step number.\n", - "\n", - "For example, set accumulate_grad_batches to 4 to accumulate every 4 batches. In this case the effective batch size is batch_size*4, so if your batch size is 32, effectively it will be 128." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "2jB6-Z_yPhhf" - }, - "source": [ - "# accumulate every 4 batches (effective batch size is batch*4)\n", - "trainer = pl.Trainer(accumulate_grad_batches=4)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_Yi-bdTOgINC" - }, - "source": [ - "You can also pass a dictionary to specify different accumulation per epoch. We can set it to `{5: 3, 10: 20}` to have no accumulation for epochs 1 to 4, accumulate 3 batches for epoch 5 to 10, and 20 batches after that." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "X3xsoZ3YPgBv" - }, - "source": [ - "# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that\n", - "trainer = pl.Trainer(accumulate_grad_batches={5: 3, 10: 20})\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "myzH8mV4M1_9" - }, - "source": [ - "# 16 bit precision\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "v9EaFAonwOk6" - }, - "source": [ - "Most deep learning frameworks like PyTorch, train with 32-bit floating point arithmetic. \n", - "\n", - "But many models can still achieve full accuracy using half the precision.\n", - "\n", - "In 2017, NVIDIA researchers successfully used a combination of 32 and 16 bit precision (also known as mixed precision) and achieved the same accuracy as 32 bit precision training.\n", - "\n", - "The main two advantages are:\n", - "\n", - "- a reduction in memory requirements which enables larger batch sizes and models.\n", - "- and a speed up in compute. On ampere, turing and volta architectures 16 bit precision models can train at least 3 times faster.\n", - "\n", - "As of PyTorch 1.6, NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, torch.cuda.amp. \n", - "\n", - "This package supersedes the apex package developed by NVIDIA." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TjNypZPHnxvJ" - }, - "source": [ - "## precision\n", - "\n", - "Use precision flag to switch between full precision (32) to half precision (16). Can be used on CPU, GPU or TPUs.\n", - "\n", - "When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit.\n", - "\n", - "If used on TPU will use torch.bfloat16 but tensor printing will still show torch.float32" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "kBZKMVx1nw-D" - }, - "source": [ - "# 16-bit precision\n", - "trainer = pl.Trainer(gpus=1, precision=16)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VJGj3Jh7oQXU" - }, - "source": [ - "In earlier version of Lightning, we use NVIDIA Apex for 16-bit precision. Apex was the first library to attempt 16-bit and the automatic mixed precision library (amp), has since been merged into core PyTorch as of 1.6.\n", - "\n", - "If you insist in using Apex, you can set the amp_backend flag to 'apex' and install Apex on your own." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BDV1trAUPc9h" - }, - "source": [ - "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HK5c_aVfNV4e" - }, - "source": [ - "## amp_level\n", - "Apex includes 4 optimization levels:\n", - "O0 (FP32 training)\n", - "O1 (Conservative Mixed Precision): only some whitelist ops are done in FP16.\n", - "O2 (Fast Mixed Precision): this is the standard mixed precision training. It maintains FP32 master weights and optimizer.step acts directly on the FP32 master weights.\n", - "O3 (FP16 training): full FP16. Passing keep_batchnorm_fp32=True can speed things up as cudnn batchnorm is faster anyway.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "FshMFPowNbWt" - }, - "source": [ - "# default used by the Trainer\n", - "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex', amp_level='O2')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y8KEr1YvNgkC" - }, - "source": [ - "# `auto_scale_batch_size`\n", - "\n", - " \n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7F1pKFIuwSFl" - }, - "source": [ - "Lightning can help you improve your model by using auto_scale_batch_size flag, which tries to find the largest batch size that fits into memory, before you start your training.\n", - "Larger batch size often yields better estimates of gradients, but may also result in longer training time. \n", - "\n", - "Set it to True to initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9_jE-iyyheIv" - }, - "source": [ - "trainer = pl.Trainer(auto_scale_batch_size=True)\n", - "\n", - "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yaHsJvwFhNJt" - }, - "source": [ - "You can set the value to `power`. `power` scaling starts from a batch size of 1 and keeps doubling the batch size until an out-of-memory (OOM) error is encountered.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Qx0FbQrphgw1" - }, - "source": [ - "trainer = pl.Trainer(auto_scale_batch_size='power')\n", - "\n", - "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8bwgVF9zhZ75" - }, - "source": [ - "You can also set it to `binsearch`, that continues to finetune the batch size by performing a binary search.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QObXNs3yNrg9" - }, - "source": [ - "# run batch size scaling, result overrides hparams.batch_size\n", - "trainer = pl.Trainer(auto_scale_batch_size='binsearch')\n", - "\n", - "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5OWdhSsZjqW7" - }, - "source": [ - "This feature expects that a batch_size field in the hparams of your model, i.e., model.hparams.batch_size should exist and will be overridden by the results of this algorithm. \n", - "\n", - "Additionally, your train_dataloader() method should depend on this field for this feature to work.\n", - "\n", - "The algorithm in short works by:\n", - "1. Dumping the current state of the model and trainer\n", - "\n", - "2. Iteratively until convergence or maximum number of tries max_trials (default 25) has been reached:\n", - "* Call fit() method of trainer. This evaluates steps_per_trial (default 3) number of training steps. Each training step can trigger an OOM error if the tensors (training batch, weights, gradients etc.) allocated during the steps have a too large memory footprint.\n", - " * If an OOM error is encountered, decrease the batch size\n", - " * Else increase it.\n", - "* How much the batch size is increased/decreased is determined by the chosen strategy.\n", - "\n", - "3. The found batch size is saved to model.hparams.batch_size\n", - "\n", - "4. Restore the initial state of model and trainer\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q4CvxfZmOWBd" - }, - "source": [ - "# `auto_lr_find`\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j85e8usNwdBV" - }, - "source": [ - "Selecting a good learning rate for your deep learning training is essential for both better performance and faster convergence.\n", - "\n", - "Even optimizers such as Adam that are self-adjusting the learning rate can benefit from more optimal choices.\n", - "\n", - "To reduce the amount of guesswork concerning choosing a good initial learning rate, you can use Lightning auto learning rate finder.\n", - "\n", - "The learning rate finder does a small run where the learning rate is increased after each processed batch and the corresponding loss is logged. The result of this is a lr vs. loss plot that can be used as guidance for choosing an optimal initial lr.\n", - "\n", - "\n", - "warning: For the moment, this feature only works with models having a single optimizer. LR support for DDP is not implemented yet, it is coming soon.\n", - "\n", - "\n", - "***auto_lr_find=***\n", - "\n", - "In the most basic use case, this feature can be enabled during trainer construction with Trainer(auto_lr_find=True).\n", - "When .fit(model) is called, the LR finder will automatically run before any training is done. The lr that is found and used will be written to the console and logged together with all other hyperparameters of the model." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "iuhve9RBOfFh" - }, - "source": [ - "# default used by the Trainer (no learning rate finder)\n", - "trainer = pl.Trainer(mnist_model, auto_lr_find=False)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BL-gjXNCPDXk" - }, - "source": [ - "This flag sets your learning rate which can be accessed via self.lr or self.learning_rate.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "wEb-vIMmPJQf" - }, - "source": [ - "class LitModel(LightningModule):\n", - "\n", - " def __init__(self, learning_rate):\n", - " self.learning_rate = learning_rate\n", - "\n", - " def configure_optimizers(self):\n", - " return Adam(self.parameters(), lr=(self.lr or self.learning_rate))\n", - "\n", - "# finds learning rate automatically\n", - "# sets hparams.lr or hparams.learning_rate to that learning rate\n", - "trainer = pl.Trainer(mnist_model, auto_lr_find=True)\n", - "\n", - "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RweqvpnVPPSh" - }, - "source": [ - "To use an arbitrary value set it as auto_lr_find\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4LKI39IfPLJv" - }, - "source": [ - "trainer = pl.Trainer(mnist_model, auto_lr_find='my_value')\n", - "\n", - "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9VAhPRKbPX-m" - }, - "source": [ - "Under the hood, when you call tune it runs the learning rate finder.\n", - "\n", - "If you want to inspect the results of the learning rate finder before doing any actual training or just play around with the parameters of the algorithm, this can be done by invoking the lr_find method of the trainer. A typical example of this would look like\n", - "\n", - "\n", - "```\n", - "trainer = pl.Trainer(auto_lr_find=True)\n", - "\n", - "# Run learning rate finder\n", - "lr_finder = trainer.lr_find(model)\n", - "\n", - "# Results can be found in\n", - "lr_finder.results\n", - "\n", - "# Plot with\n", - "fig = lr_finder.plot(suggest=True)\n", - "fig.show()\n", - "\n", - "# Pick point based on plot, or get suggestion\n", - "new_lr = lr_finder.suggestion()\n", - "\n", - "# update hparams of the model\n", - "model.hparams.lr = new_lr\n", - "\n", - "# Fit model\n", - "trainer.fit(model)\n", - "```\n", - "\n", - "The figure produced by lr_finder.plot() should look something like the figure below. It is recommended to not pick the learning rate that achieves the lowest loss, but instead something in the middle of the sharpest downward slope (red point). This is the point returned py lr_finder.suggestion().\n", - "\n", - "![image.png]()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tn1RV-jfOjt1" - }, - "source": [ - "# `benchmark`\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rsmTl5zfwjM3" - }, - "source": [ - "You can try to speed your system by setting `benchmark=True`, which enables cudnn.benchmark. This flag is likely to increase the speed of your system if your input sizes don’t change. This flag makes cudnn auto-tuner look for the optimal set of algorithms for the given hardware configuration. This usually leads to faster runtime.\n", - "But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears, possibly leading to worse runtime performances." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "dWr-OCBgQCeb" - }, - "source": [ - "trainer = pl.Trainer(gpus=1, benchmark=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qwAvSKYGa24K" - }, - "source": [ - "# `deterministic`\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tl5mfmafwmat" - }, - "source": [ - "PyTorch does not guarantee reproducible results, even when using identical seeds. To guarentee reproducible results, you can remove most of the randomness from your process by setting the `deterministic` flag to True.\n", - "\n", - "Note that it might make your system slower." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Mhv5LZ3HbNCK" - }, - "source": [ - "trainer = pl.Trainer(gpus=1, deterministic=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u_5eJSvTf60f" - }, - "source": [ - "# Exploding and vanishing gradients" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B6drjh4pq6Jv" - }, - "source": [ - "## track_grad_norm\n", - "\n", - "You can debug your grad norm to identify exploding or vanishing gradients using the `track_grad_norm` flag.\n", - "\n", - "Set value to 2 to track the 2-norm. or p to any p-norm." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "2taHUir8rflR" - }, - "source": [ - "# track the 2-norm\n", - "trainer = pl.Trainer(track_grad_norm=2)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3vHKxmruk62f" - }, - "source": [ - "May be set to ‘inf’ infinity-norm." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "g7TbD6SxlAjP" - }, - "source": [ - "trainer = pl.Trainer(track_grad_norm='inf')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TcMlRe7ywpe6" - }, - "source": [ - "## Gradient clipping\n", - "\n", - "\n", - "Exploding gradients refer to the problem that the gradients get too large and overflow in training, making the model unstable. Gradient clipping will ‘clip’ the gradients or cap them to a Threshold value to prevent the gradients from getting too large. To avoid this, we can set `gradient_clip_val` (default is set to 0.0).\n", - "\n", - "[when to use it, what are relevant values]" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "jF9JwmbOgOWF" - }, - "source": [ - "trainer = pl.Trainer(gradient_clip_val=0.1)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ggb4MkkQrr1h" - }, - "source": [ - "# truncated_bptt_steps\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s1Iu6PyAw9_r" - }, - "source": [ - "If you have a large recurrent model, you can use truncated_bptt_steps flag to split up the backprop over portions of the sequence. This flag will automatically truncate your batches and the trainer will apply Truncated Backprop to it.\n", - "\n", - "Make sure your batches have a sequence dimension.\n", - "\n", - "Lightning takes care of splitting your batch along the time-dimension.\n", - "```\n", - "# we use the second as the time dimension\n", - "# (batch, time, ...)\n", - "sub_batch = batch[0, 0:t, ...]\n", - "Using this feature requires updating your LightningModule’s pytorch_lightning.core.LightningModule.training_step() to include a hiddens arg with the hidden\n", - "\n", - "# Truncated back-propagation through time\n", - "def training_step(self, batch, batch_idx, hiddens):\n", - " # hiddens are the hiddens from the previous truncated backprop step\n", - " out, hiddens = self.lstm(data, hiddens)\n", - "\n", - " return {\n", - " \"loss\": ...,\n", - " \"hiddens\": hiddens # remember to detach() this\n", - " }\n", - "```" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WiTF1VMtruMU" - }, - "source": [ - "# backprop every 5 steps in a batch\n", - "trainer = pl.Trainer(truncated_bptt_steps=5)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8XI_kEWkS-nT" - }, - "source": [ - "To modify how the batch is split, override pytorch_lightning.core.LightningModule.tbptt_split_batch():\n", - "\n", - "```\n", - "class LitMNIST(LightningModule):\n", - " def tbptt_split_batch(self, batch, split_size):\n", - " # do your own splitting on the batch\n", - " return splits\n", - "```\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oLbEmbmupwQ8" - }, - "source": [ - "# reload_dataloaders_every_epoch\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CLdNGVv9xD_L" - }, - "source": [ - "Set to True to reload dataloaders every epoch (instead of loading just once in the beginning of training).\n", - "\n", - "```\n", - "# if False (default)\n", - "train_loader = model.train_dataloader()\n", - "for epoch in epochs:\n", - " for batch in train_loader:\n", - " ...\n", - "\n", - "# if True\n", - "for epoch in epochs:\n", - " train_loader = model.train_dataloader()\n", - " for batch in train_loader:\n", - "\n", - "```" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "10AXthXxp311" - }, - "source": [ - "trainer = pl.Trainer(reload_dataloaders_every_epoch=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "f513EYl0bmmL" - }, - "source": [ - "# Callbacks\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2pt7iGh4xNs5" - }, - "source": [ - "\n", - "Lightning Callbacks are self-contained programs that can be reused across projects.\n", - "Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run. Lightning includes some a few built-in callbacks that can be used with flags like early stopping and Model Checkpointing, but you can also create your own callbacks to add any functionality to your models.\n", - "\n", - "The callback API includes hooks that allow you to add logic at every point of your training:\n", - "setup, teardown, on_epoch_start, on_epoch_end, on_batch_start, on_batch_end, on_init_start, on_keyboard_interrupt etc. \n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1t84gvDNsUuh" - }, - "source": [ - "## callbacks\n", - "\n", - "Use **callbacks=** to pass a list of user defined callbacks. These callbacks DO NOT replace the built-in callbacks (loggers or EarlyStopping). \n", - "\n", - "In this example, we create a dummy callback that prints a message when training starts and ends, using on_train_start and on_train_end hooks." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "oIXZYabub3f0" - }, - "source": [ - "from pytorch_lightning.callbacks import Callback\n", - "\n", - "class PrintCallback(Callback):\n", - " def on_train_start(self, trainer, pl_module):\n", - " print(\"Training is started!\")\n", - " def on_train_end(self, trainer, pl_module):\n", - " print(\"Training is done.\")\n", - "\n", - "# a list of callbacks\n", - "callbacks = [PrintCallback()]\n", - "trainer = pl.Trainer(callbacks=callbacks)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cNF74CLYfJJu" - }, - "source": [ - "# Model checkpointing\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2blgquBrxLtS" - }, - "source": [ - "Checkpoints capture the exact value of all parameters used by a model.\n", - "\n", - "Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.\n", - "\n", - "Lightning automates saving and loading checkpoints so you restore a training session, saving all the required parameters including: \n", - "* 16-bit scaling factor (apex)\n", - "* Current epoch\n", - "* Global step\n", - "* Model state_dict\n", - "* State of all optimizers\n", - "* State of all learningRate schedulers\n", - "* State of all callbacks\n", - "* The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)\n", - "\n", - "By default Lightning will save a checkpoint in the working directory, which will be updated every epoch.\n", - "\n", - "### Automatic saving\n", - "By default Lightning will save a checkpoint in the end of the first epoch in the working directory, which will be updated every epoch." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "XGu0JULrg9l7" - }, - "source": [ - "# default used by the Trainer\n", - "trainer = pl.Trainer(default_root_path=os.getcwd())\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3s9OjkGuhq1W" - }, - "source": [ - "To change the checkpoint path pass in **default_root_dir=**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DgdxkrIQhvfw" - }, - "source": [ - "trainer = pl.Trainer(default_root_dir='/your/path/to/save/checkpoints')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qyvj_bkWrJiE" - }, - "source": [ - "\n", - "You can also have Lightning update your checkpoint based on a specific metric that you are logging (using self.log), by passing the key to `monitor=`. For example, if we want to save checkpoint based on the validation loss, logged as `val_loss`, you can pass:\n", - "\n", - "\n", - "```\n", - "checkpoint_callback = ModelCheckpoint(\n", - " filepath=os.getcwd(),\n", - " save_top_k=1,\n", - " verbose=True,\n", - " monitor='val_loss',\n", - " mode='min',\n", - " prefix=''\n", - ")\n", - "```\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "YzYMivw1rO1O" - }, - "source": [ - "from pytorch_lightning.callbacks import ModelCheckpoint\n", - "\n", - "trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_loss')])\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5hYs_FV8iDMn" - }, - "source": [ - "You can modify the behavior of checkpointing by creating your own callback, and passing it to the trainer. \n", - "You can control\n", - "* filepath- where logs are saved\n", - "* save_top_k- save k top models\n", - "* verbose\n", - "* monitor- the metric to monitor\n", - "* mode\n", - "* prefix\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Tb1K2VYDiNTu" - }, - "source": [ - "from pytorch_lightning.callbacks import ModelCheckpoint\n", - "\n", - "# DEFAULTS used by the Trainer\n", - "checkpoint_callback = ModelCheckpoint(\n", - " filepath=os.getcwd(),\n", - " save_top_k=3,\n", - " verbose=True,\n", - " monitor='val_loss',\n", - " mode='min',\n", - " prefix='',\n", - ")\n", - "\n", - "trainer = Trainer(callbacks=[checkpoint_callback])\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YKhZ6xRojJcl" - }, - "source": [ - "You can disable checkpointing it by passing\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Yt8zd2ZFjOXX" - }, - "source": [ - "trainer = Trainer(checkpoint_callback=False)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HcLy8asCjrj9" - }, - "source": [ - "### Manual saving\n", - "\n", - "You can manually save checkpoints and restore your model from the checkpointed state.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "kZSkMJf0jR4x" - }, - "source": [ - "trainer.fit(model)\n", - "trainer.save_checkpoint(\"example.ckpt\")\n", - "new_model = LitAutoEncoder.load_from_checkpoint(checkpoint_path=\"example.ckpt\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X2d9cjVPj7CP" - }, - "source": [ - "### Checkpoint Loading\n", - "To load a model along with its weights, biases and module_arguments use following method:\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BpAFfg5zkFmH" - }, - "source": [ - "model = LitAutoEncoder.load_from_checkpoint(PATH)\n", - "\n", - "print(model.learning_rate)\n", - "# prints the learning_rate you used in this checkpoint\n", - "\n", - "model.eval()\n", - "y_hat = model(x)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jTQ3mxSJkhFN" - }, - "source": [ - "But if you don’t want to use the values saved in the checkpoint, pass in your own here" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IoMcOh9-kfUP" - }, - "source": [ - "class LitAutoEncoder(LightningModule):\n", - "\n", - " def __init__(self, in_dim, out_dim):\n", - " super().__init__()\n", - " self.save_hyperparameters()\n", - " self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ITPVY8mNknut" - }, - "source": [ - "you can restore the model like this\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "H7XeRJzVkuY8" - }, - "source": [ - "# if you train and save the model like this it will use these values when loading\n", - "# the weights. But you can overwrite this\n", - "LitAutoEncoder(in_dim=32, out_dim=10)\n", - "\n", - "# uses in_dim=32, out_dim=10\n", - "model = LitAutoEncoder.load_from_checkpoint(PATH)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "14WwGpnVk0a4" - }, - "source": [ - "# uses in_dim=128, out_dim=10\n", - "model = LitAutoEncoder.load_from_checkpoint(PATH, in_dim=128, out_dim=10)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bY5s6wP_k1CU" - }, - "source": [ - "\n", - "\n", - "## Restoring Training State (resume_from_checkpoint)\n", - "If your training was cut short for some reason, you can resume exactly from where you left off using the `resume_from_checkpoint` flag, which will automatically restore model, epoch, step, LR schedulers, apex, etc..." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9zfhHtyrk3rO" - }, - "source": [ - "model = LitAutoEncoder()\n", - "trainer = pl.Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')\n", - "\n", - "# automatically restores model, epoch, step, LR schedulers, apex, etc...\n", - "trainer.fit(model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xkKdvALFsmT2" - }, - "source": [ - "## weights_save_path\n", - "You can specify a directory for saving weights file using `weights_save_path`.\n", - "\n", - "(If you are using a custom checkpoint callback, the checkpoint callback will override this flag)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9OwHHFcCsrgT" - }, - "source": [ - "# save to your custom path\n", - "trainer = pl.Trainer(weights_save_path='my/path')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "PbNtlJ9Wsscf" - }, - "source": [ - "# if checkpoint callback used, then overrides the weights path\n", - "# **NOTE: this saves weights to some/path NOT my/path\n", - "checkpoint = ModelCheckpoint(filepath='some/path')\n", - "trainer = pl.Trainer(\n", - " callbacks=[checkpoint],\n", - " weights_save_path='my/path'\n", - ")\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uDdxCuyHdWQt" - }, - "source": [ - "# Early stopping\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fqAy3ihRxTfR" - }, - "source": [ - "The EarlyStopping callback can be used to monitor a validation metric and stop the training when no improvement is observed, to help you avoid overfitting.\n", - "\n", - "To enable Early Stopping you can init the EarlyStopping callback, and pass it to `callbacks=` trainer flag. The callback will look for a logged metric to early stop on.\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lFx976CheH93" - }, - "source": [ - "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", - "\n", - "trainer = pl.Trainer(callbacks=[EarlyStopping('val_loss')])\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MwpJfTvjeOwF" - }, - "source": [ - "You can customize the callback using the following params:\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "V6I9h6HteK2U" - }, - "source": [ - "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", - "\n", - "early_stop_callback = EarlyStopping(\n", - " monitor='val_accuracy',\n", - " min_delta=0.00,\n", - " patience=3,\n", - " verbose=False,\n", - " mode='max'\n", - ")\n", - "trainer = pl.Trainer(callbacks=[early_stop_callback])\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7TAIerPYe_Q1" - }, - "source": [ - "The EarlyStopping callback runs at the end of every validation epoch, which, under the default configuration, happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example check_val_every_n_epoch and val_check_interval. It must be noted that the patience parameter counts the number of validation epochs with no improvement, and not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VoKrX2ENh9Fg" - }, - "source": [ - "# Logging" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-CQTPKd7iKLm" - }, - "source": [ - "Lightning has built in integration with various loggers such as TensorBoard, wandb, commet, etc.\n", - "\n", - "\n", - "You can pass any metrics you want to log during training to `self.log`, such as loss or accuracy. Similarly, pass in to self.log any metric you want to log during validation step.\n", - "\n", - "These values will be passed in to the logger of your choise. simply pass in any supported logger to logger trainer flag.\n", - "\n", - "\n", - "\n", - "Use the as`logger=` trainer flag to pass in a Logger, or iterable collection of Loggers, for experiment tracking.\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ty5VPS3AiS8L" - }, - "source": [ - "from pytorch_lightning.loggers import TensorBoardLogger\n", - "\n", - "# default logger used by trainer\n", - "logger = TensorBoardLogger(\n", - " save_dir=os.getcwd(),\n", - " version=1,\n", - " name='lightning_logs'\n", - ")\n", - "trainer = pl.Trainer(logger=logger)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jc5oWNpoiuuc" - }, - "source": [ - "Lightning supports the use of multiple loggers, just pass a list to the Trainer.\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BlYwMRRyivp_" - }, - "source": [ - "from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger\n", - "logger1 = TensorBoardLogger('tb_logs', name='my_model')\n", - "logger2 = TestTubeLogger('tb_logs', name='my_model')\n", - "trainer = pl.Trainer(logger=[logger1, logger2])" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a7EyspQPh7iQ" - }, - "source": [ - "## flush_logs_every_n_steps\n", - "\n", - "Use this flag to determine when logging to disc should happen." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Em_XvsmyiBbk" - }, - "source": [ - "trainer = pl.Trainer(flush_logs_every_n_steps=100)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_vDeKE98qsl1" - }, - "source": [ - "## log_every_n_steps\n", - "How often to add logging rows (does not write to disk)\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "HkqD7D_0w1Tt" - }, - "source": [ - "trainer = pl.Trainer(log_every_n_steps=1000)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9uw0gfe422CT" - }, - "source": [ - "# info logging" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dQXpt0aatDGo" - }, - "source": [ - "### default_root_dir\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "Default path for logs and weights when no logger or pytorch_lightning.callbacks.ModelCheckpoint callback passed. On certain clusters you might want to separate where logs and checkpoints are stored. If you don’t then use this argument for convenience. Paths can be local paths or remote paths such as s3://bucket/path or ‘hdfs://path/’. Credentials will need to be set up to use remote filepaths." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CMmID2Bts5W3" - }, - "source": [ - "## weights_summary\n", - "Prints a summary of the weights when training begins. Default is set to `top`- print summary of top level modules.\n", - "\n", - "Options: ‘full’, ‘top’, None." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "KTl6EdwDs6j2" - }, - "source": [ - "\n", - "# print full summary of all modules and submodules\n", - "trainer = pl.Trainer(weights_summary='full')\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "R57cSLl9w9ma" - }, - "source": [ - "# don't print a summary\n", - "trainer = Trainer(weights_summary=None)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bSc2hU5AotAP" - }, - "source": [ - "# progress bar" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GgvbyDsBxcH6" - }, - "source": [ - "## process_position\n", - "\n", - "Orders the progress bar. Useful when running multiple trainers on the same node.\n", - "\n", - "(This argument is ignored if a custom callback is passed to callbacks)\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "6ekz8Es8owDn" - }, - "source": [ - "# default used by the Trainer\n", - "trainer = pl.Trainer(process_position=0)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "itivQFgEphBU" - }, - "source": [ - "## progress_bar_refresh_rate\n", - "\n", - "How often to refresh the progress bar (in steps). In notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates, so raise it to 50 or more." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "GKe6eVxmplL5" - }, - "source": [ - "# default used by the Trainer\n", - "trainer = pl.Trainer(progress_bar_refresh_rate=1)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "8rDHJOJbxNtf" - }, - "source": [ - "# disable progress bar\n", - "trainer = Trainer(progress_bar_refresh_rate=0)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NCNvYLwjpWne" - }, - "source": [ - "# profiler" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "pRknrG_zpY6M" - }, - "source": [ - "# to profile standard training events\n", - "trainer = pl.Trainer(profiler=True)\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ji6aWpU73kMM" - }, - "source": [ - "You can also use Lightning AdvancedProfiler if you want more detailed information about time spent in each function call recorded during a given action. The output is quite verbose and you should only use this if you want very detailed reports.\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "layG55pt316C" - }, - "source": [ - "from pytorch_lightning.profiler import AdvancedProfiler\n", - "\n", - "trainer = Trainer(profiler=AdvancedProfiler())\n", - "\n", - "trainer.fit(model, train_loader, val_loader)" - ], - "execution_count": null, - "outputs": [] - } - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "goRmGIRI5cfC" + }, + "source": [ + "# Introduction to Lightning Flags ⚡🚩\n", + "\n", + "In this notebook, we'll go over the flags available in the `Trainer` object. Note that not everything will work in the Colab environment (multi-gpu, etc). This notebook accompanies the Trainer videos we'll be putting out.\n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jKj5lgdr5j48" + }, + "source": [ + "--- \n", + "### Setup \n", + "First thing first, we need to install Lightning. Simply ```pip install pytorch-lightning```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UGjilEHk4vb7" + }, + "outputs": [], + "source": [ + "! pip install pytorch-lightning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zaVUShmQ5n8Y" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from argparse import ArgumentParser\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.data import random_split\n", + "from torchvision.datasets import MNIST\n", + "from torchvision import transforms\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.metrics.functional import accuracy\n", + "\n", + "from torchvision.datasets.mnist import MNIST\n", + "from torchvision import transforms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6tgkS8IYZwY_" + }, + "outputs": [], + "source": [ + "# ------------\n", + "# data\n", + "# ------------\n", + "pl.seed_everything(1234)\n", + "batch_size = 32\n", + "\n", + "# Init DataLoader from MNIST Dataset\n", + "\n", + "dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n", + "mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())\n", + "mnist_train, mnist_val = random_split(dataset, [55000, 5000])\n", + "\n", + "train_loader = DataLoader(mnist_train, batch_size=batch_size)\n", + "val_loader = DataLoader(mnist_val, batch_size=batch_size)\n", + "test_loader = DataLoader(mnist_test, batch_size=batch_size)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gEulmrbxwaYL" + }, + "source": [ + "### Simple AutoEncoder Model\n", + "\n", + "Were gonna define a simple Lightning model so we can play with all the settings of the Lightning Trainer.\n", + "\n", + "LightningModule is simply pure Pytorch reorganized into hooks, that represents all the steps in the training process.\n", + "\n", + "You can use LightningModule hooks to control every part of your model, but for the purpose of this video we will use a very simple MNIST classifier, a model that takes 28*28 grayscale images of hand written images, and can predict the digit between 0-9.\n", + "\n", + "The LightningModule can encompass a single model, like an image classifier, or a deep learning system composed of multiple models, like this auto encoder that contains an encoder and a decoder.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x-34xKCI40yW" + }, + "outputs": [], + "source": [ + "class LitAutoEncoder(pl.LightningModule):\n", + "\n", + " def __init__(self, batch_size=32, lr=1e-3):\n", + " super().__init__()\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(28 * 28, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 3)\n", + " )\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(3, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 28 * 28)\n", + " )\n", + " self.batch_size=batch_size\n", + " self.learning_rate=lr\n", + "\n", + " def forward(self, x):\n", + " # in lightning, forward defines the prediction/inference actions\n", + " embedding = self.encoder(x)\n", + " return embedding\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " x = x.view(x.size(0), -1)\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " loss = F.mse_loss(x_hat, x)\n", + " self.log('train_loss', loss)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " x = x.view(x.size(0), -1)\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " loss = F.mse_loss(x_hat, x)\n", + " self.log('val_loss', loss)\n", + " \n", + " def test_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " x = x.view(x.size(0), -1)\n", + " z = self.encoder(x)\n", + " x_hat = self.decoder(z)\n", + " loss = F.mse_loss(x_hat, x)\n", + " self.log('test_loss', loss)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VbxcRCrxiYly" + }, + "source": [ + "You'll notice the LightningModule doesn't have epoch and batch loops, we're not calling model.train() and model.eval(), and no mentions of CUDA or hardware. That's because it is all automated by the Lightning Trainer. All the engineering boilerplate is automated by the trainer: \n", + "\n", + "* Training loops\n", + "* Evaluation and test loops\n", + "* Calling model.train(), model.eval(), no_grad at the right time\n", + "* CUDA or to_device calls\n", + "\n", + "It also allows you to train your models on different hardware like GPUs and TPUs without changing your code!\n", + "\n", + "\n", + "### To use the lightning trainer simply:\n", + "\n", + "1. init your LightningModule and datasets\n", + "\n", + "2. init lightning trainer\n", + "\n", + "3. call trainer.fit\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HOk9c4_35FKg" + }, + "outputs": [], + "source": [ + "#####################\n", + "# 1. Init Model\n", + "#####################\n", + "\n", + "model = LitAutoEncoder()\n", + "\n", + "#####################\n", + "# 2. Init Trainer\n", + "#####################\n", + "\n", + "# these 2 flags are explained in the later sections...but for short explanation:\n", + "# - progress_bar_refresh_rate: limits refresh rate of tqdm progress bar so Colab doesn't freak out\n", + "# - max_epochs: only run 2 epochs instead of default of 1000\n", + "trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=2)\n", + "\n", + "#####################\n", + "# 3. Train\n", + "#####################\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3meDako-Qa_6" + }, + "source": [ + "Our model is training just like that, using the Lightning defaults. The beauty of Lightning is that everything is easily configurable.\n", + "In our next videos were going to show you all the ways you can control your Trainer to do things like controlling your training, validation and test loops, running on GPUs and TPUs, checkpointing, early stopping, and a lot more.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z_Wry2MckQkI" + }, + "source": [ + "# Training loop and eval loop Flags" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0MkI1xB2vsLj" + }, + "source": [ + "\n", + "To really scale up your networks, you can use accelerators like GPUs. GPUs or Graphical Processing Units, parallelize matrix multiplications which enable speed ups of at least 100x over training on CPUs.\n", + "\n", + "Let's say you have a machine with 8 GPUs on it. You can set this flag to 1, 4, or 8 GPUs and lightning will automatically distribute your training for you.\n", + "\n", + "```\n", + "trainer = pl.Trainer(gpus=1)\n", + "```\n", + "\n", + "---------\n", + "\n", + "Lightning makes your code hardware agnostic... This means, you can switch between CPUs, GPUs without code changes.\n", + "\n", + "However, it requires forming good PyTorch habits:\n", + "\n", + "1. First, remove the .cuda() or .to() calls in your code.\n", + "2. Second, when you initialize a new tensor, set the device=self.device in the call since every lightningModule knows what gpu index or TPU core it is on.\n", + "\n", + "You can also use type_as and or you can register the tensor as a buffer in your module’s __init__ method with register_buffer().\n", + "\n", + "```\n", + "# before lightning\n", + "def forward(self, x):\n", + " z = torch.Tensor(2, 3)\n", + " z = z.cuda(0)\n", + "\n", + "# with lightning\n", + "def forward(self, x):\n", + " z = torch.Tensor(2, 3)\n", + " z = z.type_as(x, device=self.device)\n", + "```\n", + "\n", + "\n", + "```\n", + "class LitModel(LightningModule):\n", + "\n", + " def __init__(self):\n", + " ...\n", + " self.register_buffer(\"sigma\", torch.eye(3))\n", + " # you can now access self.sigma anywhere in your module\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hw6jJhhjvlSL" + }, + "source": [ + "Lightning Trainer automates all the engineering boilerplate like iterating over epochs and batches, training eval and test loops, CUDA and to(device) calls, calling model.train and model.eval.\n", + "\n", + "You still have full control over the loops, by using the following trainer flags:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pT5-ETH9eUg6" + }, + "source": [ + "## Calling validation steps\n", + "Sometimes, training an epoch may be pretty fast, like minutes per epoch. In this case, you might not need to validate on every epoch. Instead, you can actually validate after a few epochs.\n", + "\n", + "Use `check_val_every_n_epoch` flag to control the frequency of validation step:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z-EMVvKheu3D" + }, + "outputs": [], + "source": [ + "# run val loop every 10 training epochs\n", + "trainer = pl.Trainer(check_val_every_n_epoch=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UOzZr9S2UcSO" + }, + "source": [ + "## val_check_interval\n", + "\n", + "In some cases where your epoch is very long, you might want to check validation within an epoch.\n", + "\n", + "You can also run validation step within your training epochs, by setting `val_check_interval` flag.\n", + "\n", + "Set `val_check_interval` to a float between [0.0 to 1.0] to check your validation set within a training epoch. For example, setting it to 0.25 will check your validation set 4 times during a training epoch.\n", + "\n", + "Default is set to 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9kbUbvrUVLrT" + }, + "outputs": [], + "source": [ + "# check validation set 4 times during a training epoch\n", + "trainer = pl.Trainer(val_check_interval=0.25)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Onm1gBsKVaw4" + }, + "source": [ + "When you have iterable data sets, or when streaming data for production use cases, it is useful to check the validation set every number of steps. \n", + "Set val_check_interval to an int:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "psn6DVb5Vi85" + }, + "outputs": [], + "source": [ + "# check validation set every 1000 training batches\n", + "# use this when using iterableDataset and your dataset has no length\n", + "# (ie: production cases with streaming data)\n", + "trainer = pl.Trainer(val_check_interval=1000)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QkoYonrWkb7-" + }, + "source": [ + "## num_sanity_val_steps \n", + "\n", + "You may have run into an issue, where you have a bug in your validation loop, but won't catch it until your training loop ends.\n", + "\n", + "and if your training loop takes hours or days, you will waste valuable compute.\n", + "\n", + "Instead, lightning automatically runs through 2 steps of validation in the beginning to catch these kinds of bugs up front.\n", + "\n", + "\n", + "The `num_sanity_val_steps` flag can help you run n batches of validation before starting the training routine.\n", + "\n", + "You can set it to 0 to turn it off" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zOcT-ugSkiKW" + }, + "outputs": [], + "source": [ + "# turn it off\n", + "trainer = pl.Trainer(num_sanity_val_steps=0)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zS0ob1ZmTw56" + }, + "source": [ + "Set it to -1 to check all validation data before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rzqvjA4UT263" + }, + "outputs": [], + "source": [ + "# check all validation data\n", + "trainer = pl.Trainer(num_sanity_val_steps=-1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uMB41wq4T3Z2" + }, + "source": [ + "Or use any arbitrary number of validation steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lGP78aQzT7VS" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(num_sanity_val_steps=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H-xaYRtd1rb-" + }, + "source": [ + "## Limit train, validation, and test batches\n", + "\n", + "You can set limits on how much of training, validation and test dataset you want your model to check. This is useful if you have really large validation or tests sets, for debugging or testing something that happens at the end of an epoch.\n", + "\n", + "Set the flag to int to specify the number of batches to run\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XiK5cFKL1rcA" + }, + "outputs": [], + "source": [ + "# run for only 10 batches\n", + "trainer = pl.Trainer(limit_test_batches=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y4LK0g65RrBm" + }, + "source": [ + "For example, some metrics need to be computed on the entire validation results, such as AUC ROC. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8MmeRs2DR3dD" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(limit_val_batches=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xmigcNa1A2Vy" + }, + "source": [ + "You can use a float to limit the batches be percentage of the set on every epoch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W7uGJt8nA4tv" + }, + "outputs": [], + "source": [ + "# run through only 25% of the test set each epoch\n", + "trainer = pl.Trainer(limit_test_batches=0.25)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YRI8THtUN7_e" + }, + "source": [ + "# Training on GPUs\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R8FFkX_FwlfE" + }, + "source": [ + "To run on 1 GPU set the flag to 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nnzkf3KaOE27" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cxBg47s5PB1P" + }, + "source": [ + "to run on 2 or 4 GPUs, set the flag to 2 or 4." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cSEM4ihLrohT" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=2)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZE6ZgwtNudro" + }, + "source": [ + "You can also select which GPU devices to run on, using a list of indices like [1, 4] \n", + "\n", + "or a string containing a comma separated list of GPU ids like '1,2'\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gQkJtq0urrjq" + }, + "outputs": [], + "source": [ + "# list: train on GPUs 1, 4 (by bus ordering)\n", + "# trainer = Trainer(gpus='1, 4') # equivalent\n", + "trainer = pl.Trainer(gpus=[1, 4])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XghDPad4us74" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=list(range(4)))\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6FVkKHpSPMTW" + }, + "source": [ + "You can use all the GPUs you have available by setting `gpus=-1`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r6cKQijYrtPe" + }, + "outputs": [], + "source": [ + "# trainer = Trainer(gpus='-1') - equivalent\n", + "trainer = pl.Trainer(gpus=-1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2C-fNLm3UGCV" + }, + "source": [ + "Lightning uses the PCI bus_id as the index for ordering GPUs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_V75s7EhOFhE" + }, + "source": [ + "### `auto_select_gpus`\n", + "\n", + "You can save on GPUs by running in “exclusive mode”, meaning only one process at a time can access them. If your not sure which GPUs you should use when running exclusive mode, Lightning can automatically find unoccupied GPUs for you. \n", + "\n", + "Simply specify the number of gpus as an integer `gpus=k`, and set the trainer flag `auto_select_gpus=True`. Lightning will automatically help you find k gpus that are not occupied by other processes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Sd3XFsAOIwd" + }, + "outputs": [], + "source": [ + "# enable auto selection (will find two available gpus on system)\n", + "trainer = pl.Trainer(gpus=2, auto_select_gpus=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a5JGSBMQhJNp" + }, + "source": [ + "## analyzing GPU usage\n", + "\n", + "### log_gpu_memory\n", + "\n", + "This is useful to analyze the memory usage of your GPUs.\n", + "\n", + "To get the GPU memory usage for every GPU on the master node, set the flag to log_gpu_memory=all.\n", + "\n", + "Under the hood, lightning uses the nvidia-smi command which may slow your training down.\n", + "\n", + "Your logs can become overwhelmed if you log the usage from many GPUs at once. In this case, you can also set the flag to min_max which will log only the min and max usage across all the GPUs of the master node.\n", + "\n", + "Note that lightning is not logging the usage across all nodes for performance reasons." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "idus3ZGahOki" + }, + "outputs": [], + "source": [ + "# log all the GPUs (on master node only)\n", + "trainer = Trainer(log_gpu_memory='all')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-mevgiy_hkip" + }, + "source": [ + "To avoid the performance decrease you can also set `log_gpu_memory=min_max` to only log the min and max memory on the master node.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SlvLJnWyhs7J" + }, + "outputs": [], + "source": [ + "# log only the min and max memory on the master node\n", + "trainer = Trainer(log_gpu_memory='min_max')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K82FLLIJVQG3" + }, + "source": [ + "\n", + "But what if you want to train on multiple machines and not just one?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YViQ6PXesAue" + }, + "source": [ + "# Training on multiple GPUs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WacbBQUivxQq" + }, + "source": [ + "Lightning makes your models hardware agnostic, and you can run on GPUs with a flip of a flag. Lightning also supports training on multiple GPUs across many machines.\n", + "\n", + "You can do this by setting the num_nodes flag.\n", + "\n", + "The world size, or the total number of GPUs you are using, will be gpus*num_nodes.\n", + "\n", + "If i set gpus=8 and num_nodes=32 then I will be training on 256 GPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5iKckmDvr8zZ" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=8, num_nodes=32)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GgcSbDjjlSTh" + }, + "source": [ + "## distributed backends\n", + "\n", + "Under the hood, Lightning uses distributed data parallel (or DDP) by default to distribute training across GPUs.\n", + "\n", + "This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment variables.\n", + "\n", + "Under the hood it's as if you had called your script like this:\n", + "\n", + "1. Each GPU across each node gets its own process.\n", + "2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset.\n", + "3. Each process inits the model. (Make sure to set the random seed so that each model initializes with the same weights.)\n", + "4. Each process performs a full forward and backward pass in parallel.\n", + "5. The gradients are synced and averaged across all processes.\n", + "6. Each process updates its optimizer.\n", + "If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n_Brr7F5wdtj" + }, + "outputs": [], + "source": [ + "# ddp = DistributedDataParallel\n", + "# trainer = pl.Trainer(gpus=2, num_nodes=2) equivalent\n", + "trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "edxHyttC5J3e" + }, + "source": [ + "DDP is the fastest and recommended way to distribute your training, but you can pass in other backends to `distributed_backend` trainer flag, when DDP is not supported.\n", + "\n", + "DDP isn't available in\n", + "* Jupyter Notebook, Google COLAB, Kaggle, etc.\n", + "* If You have a nested script without a root package\n", + "* or if Your script needs to invoke .fit or .test multiple times" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZDh96mavxHxf" + }, + "source": [ + "### DDP_SPAWN\n", + "\n", + "In these cases, you can use `ddp_spawn` instead. `ddp_spawn` is exactly like DDP except that it uses `.spawn()` to start the training processes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JM5TKtgLxo37" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp_spawn')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sebhVE3qrhKK" + }, + "source": [ + "We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):\n", + "\n", + "* Since .spawn() trains the model in subprocesses, the model on the main process does not get updated.\n", + "\n", + "* Dataloader(num_workers=N), where N is large, bottlenecks training with DDP… ie: it will be VERY slow or won’t work at all. This is a PyTorch limitation.\n", + "\n", + "* Forces everything to be picklable.\n", + "\n", + "DDP is MUCH faster than DDP_spawn. To be able to use DDP we recommend you: \n", + "\n", + "1. Install a top-level module for your project using setup.py\n", + "\n", + "```\n", + "# setup.py\n", + "#!/usr/bin/env python\n", + "\n", + "from setuptools import setup, find_packages\n", + "\n", + "setup(name='src',\n", + " version='0.0.1',\n", + " description='Describe Your Cool Project',\n", + " author='',\n", + " author_email='',\n", + " url='https://github.com/YourSeed', # REPLACE WITH YOUR OWN GITHUB PROJECT LINK\n", + " install_requires=[\n", + " 'pytorch-lightning'\n", + " ],\n", + " packages=find_packages()\n", + " )\n", + "\n", + "```\n", + "\n", + "2. Setup your project like so:\n", + "\n", + "```\n", + "/project\n", + " /src\n", + " some_file.py\n", + " /or_a_folder\n", + " setup.py\n", + "```\n", + "3. Install as a root-level package\n", + "```\n", + "cd /project\n", + "pip install -e .\n", + "```\n", + "4. You can then call your scripts anywhere\n", + "```\n", + "cd /project/src\n", + "\n", + "python some_file.py --distributed_backend 'ddp' --gpus 8\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cmB3I_oyw7a8" + }, + "source": [ + "### DP\n", + "\n", + "If you're using windows, DDP is not supported. You can use `dp` for DataParallel instead: DataParallel uses multithreading, instead of multiprocessing. It splits a batch across k GPUs. That is, if you have a batch of 32 and use DP with 2 gpus, each GPU will process 16 samples, after which the root node will aggregate the results.\n", + "\n", + "DP use is discouraged by PyTorch and Lightning. Use DDP which is more stable and at least 3x faster.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OO-J0ISvlVCg" + }, + "outputs": [], + "source": [ + "# dp = DataParallel\n", + "trainer = pl.Trainer(gpus=2, distributed_backend='dp')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y7E2eHZKwUn9" + }, + "source": [ + "### DDP2\n", + "\n", + "In certain cases, it’s advantageous to use ***all*** batches on the same machine, instead of a subset. For instance, in self-supervised learning, a common performance boost comes from increasing the number of negative samples.\n", + "\n", + "In this case, we can use DDP2 which behaves like DP in a machine and DDP across nodes. DDP2 does the following:\n", + "\n", + "* Copies a subset of the data to each node.\n", + "* Inits a model on each node.\n", + "* Runs a forward and backward pass using DP.\n", + "* Syncs gradients across nodes.\n", + "* Applies the optimizer updates.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Y4xweqL3xHER" + }, + "outputs": [], + "source": [ + "# ddp2 = DistributedDataParallel + dp\n", + "trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lhKNCnveeeq5" + }, + "source": [ + "- The second mode is ddp_spawn. This works like ddp, but instead of calling your script multiple times, lightning will use multiprocessing spawn to start a subprocess per GPU. \n", + "\n", + "However, you should be careful of mixing this mode with num_workers > 0 in your dataloaders because it will bottleneck your training. This is a current known limitation of PyTorch which is why we recommend using our ddp implementation instead.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HUf9ANyQkFFO" + }, + "source": [ + "\n", + "### mocking ddp\n", + "\n", + "Testing or debugging DDP can be hard, so we have a distributed backend that simulates ddp on cpus to make it easier. Set `num_processes` to a number greater than 1 when using distributed_backend=\"ddp_cpu\" to mimic distributed training on a machine without GPUs. Note that while this is useful for debugging, it will not provide any speedup, since single-process Torch already makes efficient use of multiple CPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZSal5Da9kHOf" + }, + "outputs": [], + "source": [ + "# Simulate DDP for debugging on your GPU-less laptop\n", + "trainer = Trainer(distributed_backend=\"ddp_cpu\", num_processes=2)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Br_btCy5lgES" + }, + "source": [ + "# Training on TPUS\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DXkBNITdv44d" + }, + "source": [ + "Another option for accelerating your training is using TPUs.\n", + "A TPU is a Tensor processing unit, designed specifically for deep learning. Each TPU has 8 cores where each core is optimized for 128x128 matrix multiplies. Google estimates that 8 TPU cores are about as fast as 4 V100 GPUs!\n", + "\n", + "A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores! You can request a full pod from Google cloud or a “slice” which gives you some subset of those 2048 cores.\n", + "\n", + "At this moment, TPUs are available on Google Cloud (GCP), Google Colab and Kaggle Environments.\n", + "\n", + "Lightning supports training on TPUs without any code adjustments to your model. Just like when using GPUs, Lightning automatically inserts the correct samplers - no need to do this yourself!\n", + "\n", + "Under the hood, lightning uses the XLA framework developed jointly by the facebook and google XLA teams. And we want to recognize their efforts in advancing TPU adoption of PyTorch.\n", + "\n", + "## tpu_cores\n", + "To train on TPUs, set the tpu_cores flag.\n", + "\n", + "When using colab or kaggle, the allowed values are 1 or 8 cores. When using google cloud, any value above 8 is allowed.\n", + "\n", + "Your effective batch size is the batch size passed into a dataloader times the total number of tpu cores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "itP9y70gmD9M" + }, + "outputs": [], + "source": [ + "# int: train on a single core\n", + "trainer = pl.Trainer(tpu_cores=1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NJKnzPb3mKEg" + }, + "outputs": [], + "source": [ + "# int: train on all cores few cores\n", + "trainer = pl.Trainer(tpu_cores=8)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8a4exfWUmOHq" + }, + "source": [ + "You can also choose which TPU core to train on, by passing a list [1-8]. This is not an officially supported use case but we are working with the XLA team to improve this user experience.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "S6OrjE_bmT-_" + }, + "outputs": [], + "source": [ + "# list: train on a single selected core\n", + "trainer = pl.Trainer(tpu_cores=[2])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Afqx3sFUmfWD" + }, + "source": [ + "To train on more than 8 cores (ie: a POD), submit this script using the xla_dist script.\n", + "\n", + "\n", + "\n", + "```\n", + "python -m torch_xla.distributed.xla_dist\n", + "--tpu=$TPU_POD_NAME\n", + "--conda-env=torch-xla-nightly\n", + "--env=XLA_USE_BF16=1\n", + "-- python your_trainer_file.py\n", + "```\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ncPvbUVQqKOh" + }, + "source": [ + "# Advanced distributed training\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4MP7bEgnv7qK" + }, + "source": [ + "\n", + "Lightning supports distributed training across multiple GPUs and TPUs out of the box by setting trainer flags, but it also allows you to control the way sampling is done if you need to." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wdHiTfAMepKH" + }, + "source": [ + "## replace_sampler_ddp\n", + "In PyTorch, you must use torch.nn.DistributedSampler for multi-node or GPU training. The sampler makes sure each GPU sees the appropriate part of your data.\n", + "\n", + "```\n", + "# without lightning\n", + "def train_dataloader(self):\n", + " dataset = MNIST(...)\n", + " sampler = None\n", + "\n", + " if self.on_tpu:\n", + " sampler = DistributedSampler(dataset)\n", + "\n", + " return DataLoader(dataset, sampler=sampler)\n", + "```\n", + "Lightning adds the correct samplers when needed, so no need to explicitly add samplers. By default it will add `shuffle=True` for train sampler and `shuffle=False` for val/test sampler.\n", + "\n", + "If you want to customize this behaviour, you can set `replace_sampler_ddp=False` and add your own distributed sampler.\n", + "\n", + "(note: For iterable datasets, we don’t do this automatically.)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZfmcB_e_7HbE" + }, + "outputs": [], + "source": [ + "sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)\n", + "dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)\n", + "\n", + "trainer = pl.Trainer(gpus=2, num_nodes=2, replace_sampler_ddp=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-IOhk1n0lL3_" + }, + "source": [ + "## prepare_data_per_node\n", + "\n", + "When doing multi NODE training, if your nodes share the same file system, then you don't want to download data more than once to avoid possible collisions. \n", + "\n", + "Lightning automatically calls the prepare_data hook on the root GPU of the master node (ie: only a single GPU).\n", + "\n", + "In some cases where your nodes don't share the same file system, you need to download the data on each node. In this case you can set this flag to true and lightning will download the data on the root GPU of each node.\n", + "\n", + "This flag is defaulted to True." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WFBMUR48lM04" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=2, num_nodes=2, prepare_data_per_node=False)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FKBwXqo4q-Vp" + }, + "source": [ + "## sync_batchnorm\n", + "\n", + "Batch norm is computed per GPU/TPU. This flag enables synchronization between batchnorm layers across all GPUs.\n", + "It is recommended if you have small batch sizes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GhaCLTEZrAQi" + }, + "outputs": [], + "source": [ + "trainer = Trainer(gpus=4, sync_batchnorm=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XuFA7VTFMY9-" + }, + "source": [ + "# Debugging flags\n", + "\n", + "Lightning offers a couple of flags to make debugging your models easier:\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AKoS3fdml4Jx" + }, + "source": [ + "## Fast Dev Run\n", + "\n", + "To help you save time debugging, your first run should use the fast_dev_run flag.\n", + "\n", + "This won't generate logs or save checkpoints but will touch every line of your code to make sure that it is working as intended.\n", + "\n", + "Think about this flag like a compiler. You make changes to your code, and run Trainer with this flag to verify that your changes are bug free.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L5vuG7GSmhzK" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(fast_dev_run=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HRP1qQR5nT4p" + }, + "source": [ + "## overfit_batches\n", + "\n", + "Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. If the training dataloaders have shuffle=True, Lightning will automatically disable it.\n", + "\n", + "Useful for quickly debugging or trying to overfit on purpose." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NTM-dqGMnXms" + }, + "outputs": [], + "source": [ + "# use only 1% of the train set (and use the train set for val and test)\n", + "trainer = pl.Trainer(overfit_batches=0.01)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c0LV0gC3nl1X" + }, + "outputs": [], + "source": [ + "# overfit on 10 of the same batches\n", + "trainer = pl.Trainer(overfit_batches=10)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lt3UHU6WgtS_" + }, + "source": [ + "Or a float to represent percentage of data to run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "K3yUqADhgnkf" + }, + "outputs": [], + "source": [ + "# run through only 25% of the test set each epoch\n", + "trainer = pl.Trainer(limit_test_batches=0.25)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ODN66NeVg_2o" + }, + "source": [ + "In the case of multiple test dataloaders, the limit applies to each dataloader individually.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8aQx5SLeMz1R" + }, + "source": [ + "# accumulate_grad_batches\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g8GczZXFwKC7" + }, + "source": [ + "The batch size controls the accuracy of the estimate of the gradients. Small batch size use less memory, but decrease accuracy. When training large models, such as NLP transformers, it is useful to accumulate gradients before calling backwards(). It allows for bigger batch sizes than what can actually fit on a GPU/TPU in a single step.\n", + "\n", + "Use accumulate_grad_batches to accumulate gradients every k batches or as set up in the dict. Trainer also calls optimizer.step() for the last indivisible step number.\n", + "\n", + "For example, set accumulate_grad_batches to 4 to accumulate every 4 batches. In this case the effective batch size is batch_size*4, so if your batch size is 32, effectively it will be 128." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2jB6-Z_yPhhf" + }, + "outputs": [], + "source": [ + "# accumulate every 4 batches (effective batch size is batch*4)\n", + "trainer = pl.Trainer(accumulate_grad_batches=4)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_Yi-bdTOgINC" + }, + "source": [ + "You can also pass a dictionary to specify different accumulation per epoch. We can set it to `{5: 3, 10: 20}` to have no accumulation for epochs 1 to 4, accumulate 3 batches for epoch 5 to 10, and 20 batches after that." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3xsoZ3YPgBv" + }, + "outputs": [], + "source": [ + "# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that\n", + "trainer = pl.Trainer(accumulate_grad_batches={5: 3, 10: 20})\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "myzH8mV4M1_9" + }, + "source": [ + "# 16 bit precision\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v9EaFAonwOk6" + }, + "source": [ + "Most deep learning frameworks like PyTorch, train with 32-bit floating point arithmetic. \n", + "\n", + "But many models can still achieve full accuracy using half the precision.\n", + "\n", + "In 2017, NVIDIA researchers successfully used a combination of 32 and 16 bit precision (also known as mixed precision) and achieved the same accuracy as 32 bit precision training.\n", + "\n", + "The main two advantages are:\n", + "\n", + "- a reduction in memory requirements which enables larger batch sizes and models.\n", + "- and a speed up in compute. On ampere, turing and volta architectures 16 bit precision models can train at least 3 times faster.\n", + "\n", + "As of PyTorch 1.6, NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, torch.cuda.amp. \n", + "\n", + "This package supersedes the apex package developed by NVIDIA." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjNypZPHnxvJ" + }, + "source": [ + "## precision\n", + "\n", + "Use precision flag to switch between full precision (32) to half precision (16). Can be used on CPU, GPU or TPUs.\n", + "\n", + "When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit.\n", + "\n", + "If used on TPU will use torch.bfloat16 but tensor printing will still show torch.float32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kBZKMVx1nw-D" + }, + "outputs": [], + "source": [ + "# 16-bit precision\n", + "trainer = pl.Trainer(gpus=1, precision=16)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VJGj3Jh7oQXU" + }, + "source": [ + "In earlier version of Lightning, we use NVIDIA Apex for 16-bit precision. Apex was the first library to attempt 16-bit and the automatic mixed precision library (amp), has since been merged into core PyTorch as of 1.6.\n", + "\n", + "If you insist in using Apex, you can set the amp_backend flag to 'apex' and install Apex on your own." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BDV1trAUPc9h" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HK5c_aVfNV4e" + }, + "source": [ + "## amp_level\n", + "Apex includes 4 optimization levels:\n", + "O0 (FP32 training)\n", + "O1 (Conservative Mixed Precision): only some whitelist ops are done in FP16.\n", + "O2 (Fast Mixed Precision): this is the standard mixed precision training. It maintains FP32 master weights and optimizer.step acts directly on the FP32 master weights.\n", + "O3 (FP16 training): full FP16. Passing keep_batchnorm_fp32=True can speed things up as cudnn batchnorm is faster anyway.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FshMFPowNbWt" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex', amp_level='O2')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y8KEr1YvNgkC" + }, + "source": [ + "# `auto_scale_batch_size`\n", + "\n", + " \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7F1pKFIuwSFl" + }, + "source": [ + "Lightning can help you improve your model by using auto_scale_batch_size flag, which tries to find the largest batch size that fits into memory, before you start your training.\n", + "Larger batch size often yields better estimates of gradients, but may also result in longer training time. \n", + "\n", + "Set it to True to initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9_jE-iyyheIv" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(auto_scale_batch_size=True)\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yaHsJvwFhNJt" + }, + "source": [ + "You can set the value to `power`. `power` scaling starts from a batch size of 1 and keeps doubling the batch size until an out-of-memory (OOM) error is encountered.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Qx0FbQrphgw1" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(auto_scale_batch_size='power')\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8bwgVF9zhZ75" + }, + "source": [ + "You can also set it to `binsearch`, that continues to finetune the batch size by performing a binary search.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QObXNs3yNrg9" + }, + "outputs": [], + "source": [ + "# run batch size scaling, result overrides hparams.batch_size\n", + "trainer = pl.Trainer(auto_scale_batch_size='binsearch')\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5OWdhSsZjqW7" + }, + "source": [ + "This feature expects that a batch_size field in the hparams of your model, i.e., model.hparams.batch_size should exist and will be overridden by the results of this algorithm. \n", + "\n", + "Additionally, your train_dataloader() method should depend on this field for this feature to work.\n", + "\n", + "The algorithm in short works by:\n", + "1. Dumping the current state of the model and trainer\n", + "\n", + "2. Iteratively until convergence or maximum number of tries max_trials (default 25) has been reached:\n", + "* Call fit() method of trainer. This evaluates steps_per_trial (default 3) number of training steps. Each training step can trigger an OOM error if the tensors (training batch, weights, gradients etc.) allocated during the steps have a too large memory footprint.\n", + " * If an OOM error is encountered, decrease the batch size\n", + " * Else increase it.\n", + "* How much the batch size is increased/decreased is determined by the chosen strategy.\n", + "\n", + "3. The found batch size is saved to model.hparams.batch_size\n", + "\n", + "4. Restore the initial state of model and trainer\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "q4CvxfZmOWBd" + }, + "source": [ + "# `auto_lr_find`\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j85e8usNwdBV" + }, + "source": [ + "Selecting a good learning rate for your deep learning training is essential for both better performance and faster convergence.\n", + "\n", + "Even optimizers such as Adam that are self-adjusting the learning rate can benefit from more optimal choices.\n", + "\n", + "To reduce the amount of guesswork concerning choosing a good initial learning rate, you can use Lightning auto learning rate finder.\n", + "\n", + "The learning rate finder does a small run where the learning rate is increased after each processed batch and the corresponding loss is logged. The result of this is a lr vs. loss plot that can be used as guidance for choosing an optimal initial lr.\n", + "\n", + "\n", + "warning: For the moment, this feature only works with models having a single optimizer. LR support for DDP is not implemented yet, it is coming soon.\n", + "\n", + "\n", + "***auto_lr_find=***\n", + "\n", + "In the most basic use case, this feature can be enabled during trainer construction with Trainer(auto_lr_find=True).\n", + "When .fit(model) is called, the LR finder will automatically run before any training is done. The lr that is found and used will be written to the console and logged together with all other hyperparameters of the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iuhve9RBOfFh" + }, + "outputs": [], + "source": [ + "# default used by the Trainer (no learning rate finder)\n", + "trainer = pl.Trainer(mnist_model, auto_lr_find=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BL-gjXNCPDXk" + }, + "source": [ + "This flag sets your learning rate which can be accessed via self.lr or self.learning_rate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wEb-vIMmPJQf" + }, + "outputs": [], + "source": [ + "class LitModel(LightningModule):\n", + "\n", + " def __init__(self, learning_rate):\n", + " self.learning_rate = learning_rate\n", + "\n", + " def configure_optimizers(self):\n", + " return Adam(self.parameters(), lr=(self.lr or self.learning_rate))\n", + "\n", + "# finds learning rate automatically\n", + "# sets hparams.lr or hparams.learning_rate to that learning rate\n", + "trainer = pl.Trainer(mnist_model, auto_lr_find=True)\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RweqvpnVPPSh" + }, + "source": [ + "To use an arbitrary value set it as auto_lr_find\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4LKI39IfPLJv" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(mnist_model, auto_lr_find='my_value')\n", + "\n", + "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9VAhPRKbPX-m" + }, + "source": [ + "Under the hood, when you call tune it runs the learning rate finder.\n", + "\n", + "If you want to inspect the results of the learning rate finder before doing any actual training or just play around with the parameters of the algorithm, this can be done by invoking the lr_find method of the trainer. A typical example of this would look like\n", + "\n", + "\n", + "```\n", + "trainer = pl.Trainer(auto_lr_find=True)\n", + "\n", + "# Run learning rate finder\n", + "lr_finder = trainer.lr_find(model)\n", + "\n", + "# Results can be found in\n", + "lr_finder.results\n", + "\n", + "# Plot with\n", + "fig = lr_finder.plot(suggest=True)\n", + "fig.show()\n", + "\n", + "# Pick point based on plot, or get suggestion\n", + "new_lr = lr_finder.suggestion()\n", + "\n", + "# update hparams of the model\n", + "model.hparams.lr = new_lr\n", + "\n", + "# Fit model\n", + "trainer.fit(model)\n", + "```\n", + "\n", + "The figure produced by lr_finder.plot() should look something like the figure below. It is recommended to not pick the learning rate that achieves the lowest loss, but instead something in the middle of the sharpest downward slope (red point). This is the point returned py lr_finder.suggestion().\n", + "\n", + "![image.png]()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tn1RV-jfOjt1" + }, + "source": [ + "# `benchmark`\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rsmTl5zfwjM3" + }, + "source": [ + "You can try to speed your system by setting `benchmark=True`, which enables cudnn.benchmark. This flag is likely to increase the speed of your system if your input sizes don’t change. This flag makes cudnn auto-tuner look for the optimal set of algorithms for the given hardware configuration. This usually leads to faster runtime.\n", + "But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears, possibly leading to worse runtime performances." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dWr-OCBgQCeb" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1, benchmark=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qwAvSKYGa24K" + }, + "source": [ + "# `deterministic`\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tl5mfmafwmat" + }, + "source": [ + "PyTorch does not guarantee reproducible results, even when using identical seeds. To guarentee reproducible results, you can remove most of the randomness from your process by setting the `deterministic` flag to True.\n", + "\n", + "Note that it might make your system slower." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mhv5LZ3HbNCK" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gpus=1, deterministic=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u_5eJSvTf60f" + }, + "source": [ + "# Exploding and vanishing gradients" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B6drjh4pq6Jv" + }, + "source": [ + "## track_grad_norm\n", + "\n", + "You can debug your grad norm to identify exploding or vanishing gradients using the `track_grad_norm` flag.\n", + "\n", + "Set value to 2 to track the 2-norm. or p to any p-norm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2taHUir8rflR" + }, + "outputs": [], + "source": [ + "# track the 2-norm\n", + "trainer = pl.Trainer(track_grad_norm=2)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3vHKxmruk62f" + }, + "source": [ + "May be set to ‘inf’ infinity-norm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "g7TbD6SxlAjP" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(track_grad_norm='inf')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TcMlRe7ywpe6" + }, + "source": [ + "## Gradient clipping\n", + "\n", + "\n", + "Exploding gradients refer to the problem that the gradients get too large and overflow in training, making the model unstable. Gradient clipping will ‘clip’ the gradients or cap them to a Threshold value to prevent the gradients from getting too large. To avoid this, we can set `gradient_clip_val` (default is set to 0.0).\n", + "\n", + "[when to use it, what are relevant values]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jF9JwmbOgOWF" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(gradient_clip_val=0.1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ggb4MkkQrr1h" + }, + "source": [ + "# truncated_bptt_steps\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s1Iu6PyAw9_r" + }, + "source": [ + "If you have a large recurrent model, you can use truncated_bptt_steps flag to split up the backprop over portions of the sequence. This flag will automatically truncate your batches and the trainer will apply Truncated Backprop to it.\n", + "\n", + "Make sure your batches have a sequence dimension.\n", + "\n", + "Lightning takes care of splitting your batch along the time-dimension.\n", + "```\n", + "# we use the second as the time dimension\n", + "# (batch, time, ...)\n", + "sub_batch = batch[0, 0:t, ...]\n", + "Using this feature requires updating your LightningModule’s pytorch_lightning.core.LightningModule.training_step() to include a hiddens arg with the hidden\n", + "\n", + "# Truncated back-propagation through time\n", + "def training_step(self, batch, batch_idx, hiddens):\n", + " # hiddens are the hiddens from the previous truncated backprop step\n", + " out, hiddens = self.lstm(data, hiddens)\n", + "\n", + " return {\n", + " \"loss\": ...,\n", + " \"hiddens\": hiddens # remember to detach() this\n", + " }\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WiTF1VMtruMU" + }, + "outputs": [], + "source": [ + "# backprop every 5 steps in a batch\n", + "trainer = pl.Trainer(truncated_bptt_steps=5)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8XI_kEWkS-nT" + }, + "source": [ + "To modify how the batch is split, override pytorch_lightning.core.LightningModule.tbptt_split_batch():\n", + "\n", + "```\n", + "class LitMNIST(LightningModule):\n", + " def tbptt_split_batch(self, batch, split_size):\n", + " # do your own splitting on the batch\n", + " return splits\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oLbEmbmupwQ8" + }, + "source": [ + "# reload_dataloaders_every_epoch\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CLdNGVv9xD_L" + }, + "source": [ + "Set to True to reload dataloaders every epoch (instead of loading just once in the beginning of training).\n", + "\n", + "```\n", + "# if False (default)\n", + "train_loader = model.train_dataloader()\n", + "for epoch in epochs:\n", + " for batch in train_loader:\n", + " ...\n", + "\n", + "# if True\n", + "for epoch in epochs:\n", + " train_loader = model.train_dataloader()\n", + " for batch in train_loader:\n", + "\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "10AXthXxp311" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(reload_dataloaders_every_epoch=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f513EYl0bmmL" + }, + "source": [ + "# Callbacks\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2pt7iGh4xNs5" + }, + "source": [ + "\n", + "Lightning Callbacks are self-contained programs that can be reused across projects.\n", + "Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run. Lightning includes some a few built-in callbacks that can be used with flags like early stopping and Model Checkpointing, but you can also create your own callbacks to add any functionality to your models.\n", + "\n", + "The callback API includes hooks that allow you to add logic at every point of your training:\n", + "setup, teardown, on_epoch_start, on_epoch_end, on_batch_start, on_batch_end, on_init_start, on_keyboard_interrupt etc. \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1t84gvDNsUuh" + }, + "source": [ + "## callbacks\n", + "\n", + "Use **callbacks=** to pass a list of user defined callbacks. These callbacks DO NOT replace the built-in callbacks (loggers or EarlyStopping). \n", + "\n", + "In this example, we create a dummy callback that prints a message when training starts and ends, using on_train_start and on_train_end hooks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oIXZYabub3f0" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import Callback\n", + "\n", + "class PrintCallback(Callback):\n", + " def on_train_start(self, trainer, pl_module):\n", + " print(\"Training is started!\")\n", + " def on_train_end(self, trainer, pl_module):\n", + " print(\"Training is done.\")\n", + "\n", + "# a list of callbacks\n", + "callbacks = [PrintCallback()]\n", + "trainer = pl.Trainer(callbacks=callbacks)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cNF74CLYfJJu" + }, + "source": [ + "# Model checkpointing\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2blgquBrxLtS" + }, + "source": [ + "Checkpoints capture the exact value of all parameters used by a model.\n", + "\n", + "Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.\n", + "\n", + "Lightning automates saving and loading checkpoints so you restore a training session, saving all the required parameters including: \n", + "* 16-bit scaling factor (apex)\n", + "* Current epoch\n", + "* Global step\n", + "* Model state_dict\n", + "* State of all optimizers\n", + "* State of all learningRate schedulers\n", + "* State of all callbacks\n", + "* The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)\n", + "\n", + "By default Lightning will save a checkpoint in the working directory, which will be updated every epoch.\n", + "\n", + "### Automatic saving\n", + "By default Lightning will save a checkpoint in the end of the first epoch in the working directory, which will be updated every epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGu0JULrg9l7" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(default_root_path=os.getcwd())\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3s9OjkGuhq1W" + }, + "source": [ + "To change the checkpoint path pass in **default_root_dir=**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DgdxkrIQhvfw" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(default_root_dir='/your/path/to/save/checkpoints')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qyvj_bkWrJiE" + }, + "source": [ + "\n", + "You can also have Lightning update your checkpoint based on a specific metric that you are logging (using self.log), by passing the key to `monitor=`. For example, if we want to save checkpoint based on the validation loss, logged as `val_loss`, you can pass:\n", + "\n", + "\n", + "```\n", + "checkpoint_callback = ModelCheckpoint(\n", + " filepath=os.getcwd(),\n", + " save_top_k=1,\n", + " verbose=True,\n", + " monitor='val_loss',\n", + " mode='min',\n", + " prefix=''\n", + ")\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YzYMivw1rO1O" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "\n", + "trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_loss')])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5hYs_FV8iDMn" + }, + "source": [ + "You can modify the behavior of checkpointing by creating your own callback, and passing it to the trainer. \n", + "You can control\n", + "* filepath- where logs are saved\n", + "* save_top_k- save k top models\n", + "* verbose\n", + "* monitor- the metric to monitor\n", + "* mode\n", + "* prefix\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Tb1K2VYDiNTu" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "\n", + "# DEFAULTS used by the Trainer\n", + "checkpoint_callback = ModelCheckpoint(\n", + " filepath=os.getcwd(),\n", + " save_top_k=3,\n", + " verbose=True,\n", + " monitor='val_loss',\n", + " mode='min',\n", + " prefix='',\n", + ")\n", + "\n", + "trainer = Trainer(callbacks=[checkpoint_callback])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YKhZ6xRojJcl" + }, + "source": [ + "You can disable checkpointing it by passing\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yt8zd2ZFjOXX" + }, + "outputs": [], + "source": [ + "trainer = Trainer(checkpoint_callback=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HcLy8asCjrj9" + }, + "source": [ + "### Manual saving\n", + "\n", + "You can manually save checkpoints and restore your model from the checkpointed state.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kZSkMJf0jR4x" + }, + "outputs": [], + "source": [ + "trainer.fit(model)\n", + "trainer.save_checkpoint(\"example.ckpt\")\n", + "new_model = LitAutoEncoder.load_from_checkpoint(checkpoint_path=\"example.ckpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X2d9cjVPj7CP" + }, + "source": [ + "### Checkpoint Loading\n", + "To load a model along with its weights, biases and module_arguments use following method:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BpAFfg5zkFmH" + }, + "outputs": [], + "source": [ + "model = LitAutoEncoder.load_from_checkpoint(PATH)\n", + "\n", + "print(model.learning_rate)\n", + "# prints the learning_rate you used in this checkpoint\n", + "\n", + "model.eval()\n", + "y_hat = model(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jTQ3mxSJkhFN" + }, + "source": [ + "But if you don’t want to use the values saved in the checkpoint, pass in your own here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IoMcOh9-kfUP" + }, + "outputs": [], + "source": [ + "class LitAutoEncoder(LightningModule):\n", + "\n", + " def __init__(self, in_dim, out_dim):\n", + " super().__init__()\n", + " self.save_hyperparameters()\n", + " self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ITPVY8mNknut" + }, + "source": [ + "you can restore the model like this\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "H7XeRJzVkuY8" + }, + "outputs": [], + "source": [ + "# if you train and save the model like this it will use these values when loading\n", + "# the weights. But you can overwrite this\n", + "LitAutoEncoder(in_dim=32, out_dim=10)\n", + "\n", + "# uses in_dim=32, out_dim=10\n", + "model = LitAutoEncoder.load_from_checkpoint(PATH)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "14WwGpnVk0a4" + }, + "outputs": [], + "source": [ + "# uses in_dim=128, out_dim=10\n", + "model = LitAutoEncoder.load_from_checkpoint(PATH, in_dim=128, out_dim=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bY5s6wP_k1CU" + }, + "source": [ + "\n", + "\n", + "## Restoring Training State (resume_from_checkpoint)\n", + "If your training was cut short for some reason, you can resume exactly from where you left off using the `resume_from_checkpoint` flag, which will automatically restore model, epoch, step, LR schedulers, apex, etc..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9zfhHtyrk3rO" + }, + "outputs": [], + "source": [ + "model = LitAutoEncoder()\n", + "trainer = pl.Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')\n", + "\n", + "# automatically restores model, epoch, step, LR schedulers, apex, etc...\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xkKdvALFsmT2" + }, + "source": [ + "## weights_save_path\n", + "You can specify a directory for saving weights file using `weights_save_path`.\n", + "\n", + "(If you are using a custom checkpoint callback, the checkpoint callback will override this flag)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9OwHHFcCsrgT" + }, + "outputs": [], + "source": [ + "# save to your custom path\n", + "trainer = pl.Trainer(weights_save_path='my/path')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PbNtlJ9Wsscf" + }, + "outputs": [], + "source": [ + "# if checkpoint callback used, then overrides the weights path\n", + "# **NOTE: this saves weights to some/path NOT my/path\n", + "checkpoint = ModelCheckpoint(filepath='some/path')\n", + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint],\n", + " weights_save_path='my/path'\n", + ")\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uDdxCuyHdWQt" + }, + "source": [ + "# Early stopping\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fqAy3ihRxTfR" + }, + "source": [ + "The EarlyStopping callback can be used to monitor a validation metric and stop the training when no improvement is observed, to help you avoid overfitting.\n", + "\n", + "To enable Early Stopping you can init the EarlyStopping callback, and pass it to `callbacks=` trainer flag. The callback will look for a logged metric to early stop on.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lFx976CheH93" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "\n", + "trainer = pl.Trainer(callbacks=[EarlyStopping('val_loss')])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MwpJfTvjeOwF" + }, + "source": [ + "You can customize the callback using the following params:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V6I9h6HteK2U" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "\n", + "early_stop_callback = EarlyStopping(\n", + " monitor='val_accuracy',\n", + " min_delta=0.00,\n", + " patience=3,\n", + " verbose=False,\n", + " mode='max'\n", + ")\n", + "trainer = pl.Trainer(callbacks=[early_stop_callback])\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7TAIerPYe_Q1" + }, + "source": [ + "The EarlyStopping callback runs at the end of every validation epoch, which, under the default configuration, happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example check_val_every_n_epoch and val_check_interval. It must be noted that the patience parameter counts the number of validation epochs with no improvement, and not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VoKrX2ENh9Fg" + }, + "source": [ + "# Logging" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-CQTPKd7iKLm" + }, + "source": [ + "Lightning has built in integration with various loggers such as TensorBoard, wandb, commet, etc.\n", + "\n", + "\n", + "You can pass any metrics you want to log during training to `self.log`, such as loss or accuracy. Similarly, pass in to self.log any metric you want to log during validation step.\n", + "\n", + "These values will be passed in to the logger of your choise. simply pass in any supported logger to logger trainer flag.\n", + "\n", + "\n", + "\n", + "Use the as`logger=` trainer flag to pass in a Logger, or iterable collection of Loggers, for experiment tracking.\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ty5VPS3AiS8L" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "\n", + "# default logger used by trainer\n", + "logger = TensorBoardLogger(\n", + " save_dir=os.getcwd(),\n", + " version=1,\n", + " name='lightning_logs'\n", + ")\n", + "trainer = pl.Trainer(logger=logger)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jc5oWNpoiuuc" + }, + "source": [ + "Lightning supports the use of multiple loggers, just pass a list to the Trainer.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BlYwMRRyivp_" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger\n", + "logger1 = TensorBoardLogger('tb_logs', name='my_model')\n", + "logger2 = TestTubeLogger('tb_logs', name='my_model')\n", + "trainer = pl.Trainer(logger=[logger1, logger2])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a7EyspQPh7iQ" + }, + "source": [ + "## flush_logs_every_n_steps\n", + "\n", + "Use this flag to determine when logging to disc should happen." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Em_XvsmyiBbk" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(flush_logs_every_n_steps=100)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_vDeKE98qsl1" + }, + "source": [ + "## log_every_n_steps\n", + "How often to add logging rows (does not write to disk)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HkqD7D_0w1Tt" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(log_every_n_steps=1000)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9uw0gfe422CT" + }, + "source": [ + "# info logging" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dQXpt0aatDGo" + }, + "source": [ + "### default_root_dir\n", + "\n", + "---\n", + "\n", + "\n", + "\n", + "Default path for logs and weights when no logger or pytorch_lightning.callbacks.ModelCheckpoint callback passed. On certain clusters you might want to separate where logs and checkpoints are stored. If you don’t then use this argument for convenience. Paths can be local paths or remote paths such as s3://bucket/path or ‘hdfs://path/’. Credentials will need to be set up to use remote filepaths." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CMmID2Bts5W3" + }, + "source": [ + "## weights_summary\n", + "Prints a summary of the weights when training begins. Default is set to `top`- print summary of top level modules.\n", + "\n", + "Options: ‘full’, ‘top’, None." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KTl6EdwDs6j2" + }, + "outputs": [], + "source": [ + "\n", + "# print full summary of all modules and submodules\n", + "trainer = pl.Trainer(weights_summary='full')\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R57cSLl9w9ma" + }, + "outputs": [], + "source": [ + "# don't print a summary\n", + "trainer = Trainer(weights_summary=None)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bSc2hU5AotAP" + }, + "source": [ + "# progress bar" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GgvbyDsBxcH6" + }, + "source": [ + "## process_position\n", + "\n", + "Orders the progress bar. Useful when running multiple trainers on the same node.\n", + "\n", + "(This argument is ignored if a custom callback is passed to callbacks)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6ekz8Es8owDn" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(process_position=0)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "itivQFgEphBU" + }, + "source": [ + "## progress_bar_refresh_rate\n", + "\n", + "How often to refresh the progress bar (in steps). In notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates, so raise it to 50 or more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GKe6eVxmplL5" + }, + "outputs": [], + "source": [ + "# default used by the Trainer\n", + "trainer = pl.Trainer(progress_bar_refresh_rate=1)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8rDHJOJbxNtf" + }, + "outputs": [], + "source": [ + "# disable progress bar\n", + "trainer = Trainer(progress_bar_refresh_rate=0)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NCNvYLwjpWne" + }, + "source": [ + "# profiler" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pRknrG_zpY6M" + }, + "outputs": [], + "source": [ + "# to profile standard training events\n", + "trainer = pl.Trainer(profiler=True)\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ji6aWpU73kMM" + }, + "source": [ + "You can also use Lightning AdvancedProfiler if you want more detailed information about time spent in each function call recorded during a given action. The output is quite verbose and you should only use this if you want very detailed reports.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "layG55pt316C" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.profiler import AdvancedProfiler\n", + "\n", + "trainer = Trainer(profiler=AdvancedProfiler())\n", + "\n", + "trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "05-trainer-flags-overview.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md index 4dcf06a74bf92..18ae204396290 100644 --- a/pl_examples/basic_examples/README.md +++ b/pl_examples/basic_examples/README.md @@ -14,7 +14,15 @@ python mnist.py python mnist.py --gpus 2 --distributed_backend 'dp' ``` ---- +--- +#### MNIST with DALI +The MNIST example above using [NVIDIA DALI](https://developer.nvidia.com/DALI). +Requires NVIDIA DALI to be installed based on your CUDA version, see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html). +```bash +python mnist_dali.py +``` + +--- #### Image classifier Generic image classifier with an arbitrary backbone (ie: a simple system) ```bash diff --git a/pl_examples/basic_examples/mnist_dali.py b/pl_examples/basic_examples/mnist_dali.py new file mode 100644 index 0000000000000..649198053a01b --- /dev/null +++ b/pl_examples/basic_examples/mnist_dali.py @@ -0,0 +1,204 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC +from argparse import ArgumentParser +from random import shuffle +from warnings import warn + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data import random_split + +import pytorch_lightning as pl + +try: + from torchvision.datasets.mnist import MNIST + from torchvision import transforms +except Exception: + from tests.base.datasets import MNIST + +try: + import nvidia.dali.ops as ops + import nvidia.dali.types as types + from nvidia.dali.pipeline import Pipeline + from nvidia.dali.plugin.pytorch import DALIClassificationIterator +except (ImportError, ModuleNotFoundError): + warn('NVIDIA DALI is not available') + ops, types, Pipeline, DALIClassificationIterator = ..., ..., ABC, ABC + + +class ExternalMNISTInputIterator(object): + """ + This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches + """ + + def __init__(self, mnist_ds, batch_size): + self.batch_size = batch_size + self.mnist_ds = mnist_ds + self.indices = list(range(len(self.mnist_ds))) + shuffle(self.indices) + + def __iter__(self): + self.i = 0 + self.n = len(self.mnist_ds) + return self + + def __next__(self): + batch = [] + labels = [] + for _ in range(self.batch_size): + index = self.indices[self.i] + img, label = self.mnist_ds[index] + batch.append(img.numpy()) + labels.append(np.array([label], dtype=np.uint8)) + self.i = (self.i + 1) % self.n + return (batch, labels) + + +class ExternalSourcePipeline(Pipeline): + """ + This DALI pipeline class just contains the MNIST iterator + """ + + def __init__(self, batch_size, eii, num_threads, device_id): + super(ExternalSourcePipeline, self).__init__(batch_size, num_threads, device_id, seed=12) + self.source = ops.ExternalSource(source=eii, num_outputs=2) + self.build() + + def define_graph(self): + images, labels = self.source() + return images, labels + + +class DALIClassificationLoader(DALIClassificationIterator): + """ + This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it + """ + + def __init__( + self, + pipelines, + size=-1, + reader_name=None, + auto_reset=False, + fill_last_batch=True, + dynamic_shape=False, + last_batch_padded=False, + ): + super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded) + + def __len__(self): + batch_count = self._size // (self._num_gpus * self.batch_size) + last_batch = 1 if self._fill_last_batch else 0 + return batch_count + last_batch + + +class LitClassifier(pl.LightningModule): + def __init__(self, hidden_dim=128, learning_rate=1e-3): + super().__init__() + self.save_hyperparameters() + + self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) + self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + def split_batch(self, batch): + return batch[0]["data"], batch[0]["label"].squeeze().long() + + def training_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('valid_loss', loss) + + def test_step(self, batch, batch_idx): + x, y = self.split_batch(batch) + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--hidden_dim', type=int, default=128) + parser.add_argument('--learning_rate', type=float, default=0.0001) + return parser + + +def cli_main(): + pl.seed_everything(1234) + + # ------------ + # args + # ------------ + parser = ArgumentParser() + parser.add_argument('--batch_size', default=32, type=int) + parser = pl.Trainer.add_argparse_args(parser) + parser = LitClassifier.add_model_specific_args(parser) + args = parser.parse_args() + + # ------------ + # data + # ------------ + dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor()) + mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor()) + mnist_train, mnist_val = random_split(dataset, [55000, 5000]) + + eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size) + eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size) + eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size) + + pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0) + train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False) + + pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0) + val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False) + + pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0) + test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False) + + # ------------ + # model + # ------------ + model = LitClassifier(args.hidden_dim, args.learning_rate) + + # ------------ + # training + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model, train_loader, val_loader) + + # ------------ + # testing + # ------------ + trainer.test(test_dataloaders=test_loader) + + +if __name__ == "__main__": + cli_main() diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py index 7fe5d4ed604dc..60f10a637e583 100644 --- a/pl_examples/test_examples.py +++ b/pl_examples/test_examples.py @@ -1,6 +1,15 @@ +import platform from unittest import mock -import torch + import pytest +import torch + +try: + from nvidia.dali import ops, types, pipeline, plugin +except (ImportError, ModuleNotFoundError): + DALI_AVAILABLE = False +else: + DALI_AVAILABLE = True dp_16_args = """ --max_epochs 1 \ @@ -28,7 +37,7 @@ --precision 16 \ """ - +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [dp_16_args]) # def test_examples_dp_mnist(cli_args): @@ -38,6 +47,7 @@ # cli_main() +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [dp_16_args]) # def test_examples_dp_image_classifier(cli_args): @@ -45,8 +55,9 @@ # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # cli_main() -# -# + + +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [dp_16_args]) # def test_examples_dp_autoencoder(cli_args): @@ -56,6 +67,7 @@ # cli_main() +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [ddp_args]) # def test_examples_ddp_mnist(cli_args): @@ -63,8 +75,9 @@ # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # cli_main() -# -# + + +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [ddp_args]) # def test_examples_ddp_image_classifier(cli_args): @@ -72,8 +85,9 @@ # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # cli_main() -# -# + + +# TODO # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") # @pytest.mark.parametrize('cli_args', [ddp_args]) # def test_examples_ddp_autoencoder(cli_args): @@ -92,3 +106,14 @@ def test_examples_cpu(cli_args): for cli_cmd in [mnist_cli, ic_cli, ae_cli]: with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): cli_cmd() + + +@pytest.mark.skipif(not DALI_AVAILABLE, reason="Nvidia DALI required") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.') +@pytest.mark.parametrize('cli_args', [cpu_args]) +def test_examples_mnist_dali(cli_args): + from pl_examples.basic_examples.mnist_dali import cli_main + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + cli_main() diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index e69addf234a36..408c979d1ff2e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -14,7 +14,7 @@ import os import math from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Union import torch @@ -30,6 +30,12 @@ except ImportError: amp = None +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + class ReduceOp: + SUM = None + EPSILON = 1e-6 EPSILON_FP16 = 1e-5 @@ -209,6 +215,22 @@ def init_ddp_connection( torch_backend, rank=global_rank, world_size=world_size ) + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + """ + Function to reduce a tensor from several distributed processes to one aggregated tensor. + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + Return: + reduced value + """ + raise NotImplementedError() + def __getstate__(self): return { 'trainer': self.trainer, diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index b9f01b5ddc167..b127fdd40c934 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -18,17 +18,18 @@ import sys from os.path import abspath from time import sleep -from typing import Optional, List +from typing import Any, Optional, List, Union import numpy as np from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import find_free_network_port from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything from torch.nn.parallel import DistributedDataParallel @@ -298,3 +299,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py index 2aad005a07847..c80e8a4ec355c 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,10 +20,11 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.distributed.dist import LightningDistributed @@ -199,3 +200,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index f1813361c5eec..64e326b7ee0fc 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -21,11 +21,11 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn -from pytorch_lightning.utilities.distributed import find_free_network_port +from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available from pytorch_lightning.distributed.dist import LightningDistributed try: @@ -229,3 +229,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py index 6b27e7da330ea..a90d7750eaeea 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,11 +20,12 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available try: from hydra.utils import to_absolute_path, get_original_cwd @@ -198,3 +199,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py index 8a6326d3d5cb8..1ea4461c3c3cc 100644 --- a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,11 +20,11 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything try: @@ -123,7 +123,7 @@ def ddp_train(self, process_idx, model): self.set_world_ranks(process_idx) # toggle prog bar - if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None: + if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None: self.trainer.progress_bar_callback.disable() # set warning rank @@ -205,3 +205,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index b204494773362..2e0bac46c4c20 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -13,7 +13,7 @@ # limitations under the License import os import re -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.multiprocessing as mp @@ -22,11 +22,12 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.distributed.dist import LightningDistributed @@ -254,3 +255,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py index 8a9e6ac77e574..e54ad905de80e 100644 --- a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,11 +20,12 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available try: @@ -201,3 +202,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 91a5400999f6e..e5314a983f9db 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import ExitStack -from typing import Optional +from typing import Any, Optional, Union import torch from torch.optim.lr_scheduler import _LRScheduler -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only @@ -161,3 +161,41 @@ def barrier(self, name: Optional[str] = None): def broadcast(self, obj, src=0): obj = hvd.broadcast_object(obj, src) return obj + + def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): + if group is not None: + raise ValueError( + "Horovod does not support allgather using a subcommunicator at this time. " + "Unset `group`." + ) + + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + hvd.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + if group is not None: + raise ValueError( + "Horovod does not support allreduce using a subcommunicator at this time. " + "Unset `group`." + ) + + if reduce_op is None or reduce_op == "sum": + reduce_op = hvd.Sum + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + reduce_op = hvd.Average + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + hvd.join() + return hvd.allreduce(tensor, op=reduce_op) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 78b9c7025711a..05eb8ee86be63 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -258,6 +258,8 @@ def log( raise MisconfigurationException( f"Logged key: {name} should not contain information about dataloader_idx.") + accelerator = self.trainer.accelerator_backend + self._results.log( name, value, @@ -272,6 +274,7 @@ def log( sync_dist, sync_dist_op, sync_dist_group, + accelerator.sync_tensor, self._current_dataloader_idx, ) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 059c724aa75a9..0eca72095e0e0 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -124,15 +124,17 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + sync_fn: Callable = None, dataloader_idx: Optional[int] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() - # sync across ddp + # sync across workers when using distributed training + sync_fn = sync_fn or sync_ddp_if_available if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): - value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op) + value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) if 'meta' not in self: self.__setitem__('meta', {}) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index c7255c6e4497e..0f01fb9813407 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -50,6 +50,9 @@ class Accuracy(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None Example: @@ -67,11 +70,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index b716817427230..1a568bab37209 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -24,7 +24,7 @@ from torch import nn from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available +from pytorch_lightning.utilities.distributed import gather_all_tensors from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum @@ -53,21 +53,26 @@ class Metric(nn.Module, ABC): Forward only calls ``update()`` and returns None if this is set to False. default: True dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False + before returning the value at the step. process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None """ def __init__( self, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__() self.dist_sync_on_step = dist_sync_on_step self.compute_on_step = compute_on_step self.process_group = process_group + self.dist_sync_fn = dist_sync_fn self._to_sync = True self.update = self._wrap_update(self.update) @@ -174,12 +179,12 @@ def forward(self, *args, **kwargs): return self._forward_cache - def _sync_dist(self): + def _sync_dist(self, dist_sync_fn=gather_all_tensors): input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} output_dict = apply_to_collection( input_dict, torch.Tensor, - gather_all_tensors_if_available, + dist_sync_fn, group=self.process_group, ) @@ -208,12 +213,15 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - if ( - self._to_sync - and torch.distributed.is_available() # noqa: W503 - and torch.distributed.is_initialized() # noqa: W503 - ): - self._sync_dist() + dist_sync_fn = self.dist_sync_fn + if (dist_sync_fn is None + and torch.distributed.is_available() + and torch.distributed.is_initialized()): + # User provided a bool, so we assume DDP if available + dist_sync_fn = gather_all_tensors + + if self._to_sync and dist_sync_fn is not None: + self._sync_dist(dist_sync_fn) self._computed = compute(*args, **kwargs) self.reset() diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 79fc8b4c4e183..f59ce0b67de62 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn @@ -74,11 +74,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') if multioutput not in allowed_multioutput: diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 89cb56d431ad4..ba6d2c6d79a08 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -49,11 +49,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 87c1fddf2674c..6da6d55d5dd1c 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.functional.mean_squared_error import ( @@ -50,11 +50,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 256fac20365af..696ad01ca829d 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.functional.mean_squared_log_error import ( @@ -50,11 +50,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index 3ce4b523545c3..e9c33cea70b8a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -192,7 +192,7 @@ def _on_validation_start_log(): @staticmethod def _on_validation_end_log(): """Called when the validation loop ends.""" - return {"on_step": [False], "on_epoch": [False, True]} + return None @staticmethod def _on_test_start_log(): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 2a9d68807e694..2980b037c95f7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import defaultdict -from copy import deepcopy +import os +from collections import defaultdict, ChainMap from enum import Enum -from typing import Union, Tuple, Any, Mapping - +from typing import Union, Tuple, Any, Dict, Optional, List +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result @@ -92,73 +91,70 @@ def __init__(self, fx_name): self._internals_reduced = {} self._internal_type = None self.has_reduced = False + self._latest_ref = {} - def get_reduced_metrics(self): - return self._internals_reduced - - def add_dataloader_idx(self): - return len(self._internals) > 1 + @property + def has_several_dataloaders(self) -> bool: + return self.num_dataloaders > 1 @property - def num_dataloaders(self): - return len(self._internals) - - def get_latest_from_dict(self, dl_idx): - num_opt_idx = len(self._internals[dl_idx]) - 1 - assert num_opt_idx >= 0 - num_opt_idx = str(num_opt_idx) - num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1 - batch_indexes = [*self._internals[dl_idx][num_opt_idx].keys()] - # sort them by increasing order - batch_indexes.sort(key=float) - assert num_batch_idx >= 0 - return self._internals[dl_idx][num_opt_idx][batch_indexes[-1]][-1] + def num_dataloaders(self) -> int: + _inter = self._internals_reduced if self.has_reduced else self._internals + return len(_inter) def check_dataloader_idx(self, result: Result) -> bool: - add_dataloader_idx = False - try: - if len(result.keys()) > 1: - random_key = [*result.keys()][-1] - add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None - return add_dataloader_idx - return add_dataloader_idx - except Exception: - return add_dataloader_idx - - def get_lastest_from_func_name(self, func_name, *args, latest=True, **kwargs): + random_key = [*result.keys()][-1] + add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None + return add_dataloader_idx + + def get_lastest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict: results = {} - if latest: - for dl_idx in range(self.num_dataloaders): - dl_idx = str(dl_idx) - if self._internal_type == ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP: - latest_result = self._internals[dl_idx][-1] - else: - latest_result = self.get_latest_from_dict(dl_idx) - add_dataloader_idx = self.check_dataloader_idx(latest_result) - func = getattr(latest_result, func_name) - results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) - return results - raise NotImplementedError + add_dataloader_idx = self.check_dataloader_idx(latest_result) + func = getattr(latest_result, func_name) + results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) + return results - def get_batch_pbar_metrics(self, latest=True, *args, **kwargs): - return self.get_lastest_from_func_name("get_batch_pbar_metrics", *args, latest=latest, **kwargs) + def run_lastest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + """ + This function used cache_ref and cache_result to optimize loading metrics - def get_batch_log_metrics(self, latest=True, *args, **kwargs): - return self.get_lastest_from_func_name("get_batch_log_metrics", *args, latest=latest, **kwargs) + Context: As we update the logger_connector metrics on every `self.log` call, + and it can be pretty time consuming, especially when logging outside batch loop. + + HookResultStore keeps track of its latest added result object, + and cache its pbar and log metrics if already called on, + """ + results = [] + for dl_idx in range(self.num_dataloaders): + dl_idx = str(dl_idx) + latest_result = self._latest_ref[dl_idx] + result = self.get_lastest_from_func_name(latest_result, func_name, *args, **kwargs) + results.append(result) + return results + + def get_batch_pbar_metrics(self, *args, **kwargs): + return self.run_lastest_batch_metrics_with_func_name("get_batch_pbar_metrics", + *args, + **kwargs) + + def get_batch_log_metrics(self, *args, **kwargs): + return self.run_lastest_batch_metrics_with_func_name("get_batch_log_metrics", + *args, + **kwargs) def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if isinstance(opt_metric, Result): func = getattr(opt_metric, func_name) metrics_to_log = func( *args, - add_dataloader_idx=self.add_dataloader_idx, + add_dataloader_idx=self.has_several_dataloaders, **kwargs) - results.update(metrics_to_log) + results.append(metrics_to_log) else: raise Exception("The provided opt_metric should be a Result Object. Something is wrong") - def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping: - results = {} + def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + results = [] for dl_idx in range(self.num_dataloaders): dl_idx = str(dl_idx) opt_metrics = self._internals_reduced[dl_idx] @@ -169,13 +165,13 @@ def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> Mapping: self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) return results - def get_epoch_pbar_metrics(self, *args, **kwargs) -> Mapping: + def get_epoch_pbar_metrics(self, *args, **kwargs) -> List[Dict]: return self.get_epoch_from_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self, *args, **kwargs) -> Mapping: + def get_epoch_log_metrics(self, *args, **kwargs) -> List[Dict]: return self.get_epoch_from_func_name("get_epoch_log_metrics") - def get_forked_metrics(self, *args, **kwargs) -> Mapping: + def get_forked_metrics(self, *args, **kwargs) -> List[Dict]: return self.get_epoch_from_func_name("get_forked_metrics") @staticmethod @@ -211,6 +207,8 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: self._append_to_structure(self._internals[primary_key], opt_idx, batch_idx, result) + self._latest_ref[primary_key] = result + # [dataloader_idx] is a list else: self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP @@ -218,6 +216,8 @@ def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: self._internals[primary_key] = [] self._internals[primary_key].append(result) + self._latest_ref[primary_key] = result + def auto_reduce_results_on_epoch_end(self) -> None: """ This function is called to reduce `self._internals` Result object. @@ -271,7 +271,7 @@ def auto_reduce_results_on_epoch_end(self) -> None: self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs # free memory - del self._internals[dl_idx] + del self._internals[dl_idx][opt_idx] else: # no need to reduce as called only once if len(epoch_metrics) == 1: @@ -301,13 +301,9 @@ def __repr__(self): class EpochResultStore: """ This class is defined for internal usage. - It holds all metrics logged using the self.log function using `HookResultStore` object. - The internal datastructure is as follow: - self._internals = {"fx_name_0": HookResultStore(), ..., "fx_name_n": HookResultStore()} - Pseudo Code Example: ``` model._current_fx_name = 'something' @@ -315,7 +311,6 @@ class EpochResultStore: model.log('a', ...) epoch_result_store.cache_result() ``` - """ def __init__(self, trainer, stage): self.trainer = trainer @@ -365,7 +360,7 @@ def current_model_info(self): model_ref = self.trainer.get_model() # extract hook information fx_name = model_ref._current_hook_fx_name - if fx_name == '': + if fx_name is None: fx_name = model_ref._current_fx_name dataloader_idx = model_ref._current_dataloader_idx return fx_name, dataloader_idx @@ -398,7 +393,7 @@ def cache_result(self) -> None: Result.attach_batch_size(self._batch_size, hook_result) self._internals[fx_name].append( - deepcopy(hook_result), + hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info) @@ -456,18 +451,22 @@ def update_logger_connector(self, fx_name: str = None) -> None: logger_connector.callback_metrics.update(callback_metrics) logger_connector.callback_metrics.pop("epoch", None) - def run_batch_from_func_name(self, func_name) -> Mapping: - results = {} + def run_batch_from_func_name(self, func_name) -> Dict: + results = [] for fx_name, hook_result in self._internals.items(): func = getattr(hook_result, func_name) - results.update(func(latest=True, include_forked_originals=False)) - return results + results.append(func(include_forked_originals=False)) + return dict(ChainMap(*sum(results, []))) - def get_latest_batch_log_metrics(self) -> Mapping: - return self.run_batch_from_func_name("get_batch_log_metrics") + def get_latest_batch_log_metrics(self) -> Dict: + batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics") + batch_log_metrics.update(self.legacy_batch_log_metrics) + return batch_log_metrics - def get_latest_batch_pbar_metrics(self) -> Mapping: - return self.run_batch_from_func_name("get_batch_pbar_metrics") + def get_latest_batch_pbar_metrics(self) -> Dict: + batch_pbar_metrics = self.run_batch_from_func_name("get_batch_pbar_metrics") + batch_pbar_metrics.update(self.legacy_batch_pbar_metrics) + return batch_pbar_metrics @property def has_reduced(self) -> bool: @@ -495,27 +494,24 @@ def has_batch_loop_finished(self, has_batch_loop_finished): self._has_batch_loop_finished = has_batch_loop_finished self.update_logger_connector() - def run_epoch_by_func_name(self, func_name) -> Mapping: + def run_epoch_by_func_name(self, func_name) -> Dict: if not self.has_reduced: self.auto_reduce_results_on_epoch_end() - results = {} + results = [] for fx_name, hook_result in self._internals.items(): func = getattr(hook_result, func_name) - results.update(func()) - return results + results.append(func()) + return dict(ChainMap(*sum(results, []))) - def get_epoch_pbar_metrics(self) -> Mapping: + def get_epoch_pbar_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self) -> Mapping: + def get_epoch_log_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_epoch_log_metrics") - def get_forked_metrics(self) -> Mapping: + def get_forked_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_forked_metrics") - def get_reduced_metrics(self) -> Mapping: - return self.run_epoch_by_func_name("get_reduced_metrics") - def reset(self): self._internals = {} self._dataloader_idx: Union[int, None] = None @@ -523,6 +519,96 @@ def reset(self): self._opt_idx: Union[int, None] = None self._batch_size: Union[int, None] = None self._has_batch_loop_finished = False + self.legacy_batch_log_metrics = {} + self.legacy_batch_pbar_metrics = {} + + def __call__( + self, + fx_name: Optional[Union[str, int]] = None, + dl_idx: Optional[Union[str, int]] = None, + opt_idx: Optional[Union[str, int]] = None, + batch_idx: Optional[Union[str, int]] = None, + split_idx: Optional[Union[str, int]] = None, + reduced: bool = False, + ): + """ + This function is an helper to access stored data + + It access data from the HookResultStore. Please, + check its data structure for better understanding + + Data can be accessed with the following chains: + + IF REDUCED: + * IF accessing a fx_name defined in batch training loop: + fx_name -> dl_idx -> opt_idx -> batch_idx -> split_idx + * ELSE fx_name -> dl_idx -> batch_idx + ELSE: + * IF accessing a fx_name defined in batch training loop: + fx_name -> dl_idx -> opt_idx + * ELSE fx_name -> dl_idx + + Note: + As soon as a param is None, it breaks the chain and returns associated stored data. + + Example:: + + result: Result = self(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True) + result['train_loss_epoch'] # aggregated train_loss over one epoch. + + Args: + + fx_name: Hook name from ModelHooks or Callback. Example: `training_step` + + dl_idx: Dataloader idx in short. It starts from 0 to num_dataloaders - 1 + + opt_idx: Optimizer idx in short. It starts from 0 to num_optimizers - 1 + + batch_idx: Index of batch idx seen during batch training or evaluation. + Works only with reduced=False + + split_idx: Index of split idx in training loop when ttbt is used. + + reduced: Data are being aggregated on on_epoch_end. + Indicates if we want to access aggregated Result or not. + """ + + hook_result = self[str(fx_name)] + + dl_idx = str(dl_idx) if dl_idx is not None else None + opt_idx = str(opt_idx) if opt_idx is not None else None + batch_idx = str(batch_idx) if batch_idx is not None else None + split_idx = int(split_idx) if split_idx is not None else None + + internal_type = hook_result._internal_type + + if reduced: + result = hook_result._internals_reduced + else: + result = hook_result._internals + + if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + if not reduced: + if dl_idx is not None: + result = result[dl_idx] + if opt_idx is not None: + result = result[opt_idx] + if batch_idx is not None: + result = result[batch_idx] + if split_idx is not None: + result = result[split_idx] + else: + if dl_idx is not None: + result = result[dl_idx] + if opt_idx is not None: + result = result[opt_idx] + else: + if dl_idx is not None: + result = result[dl_idx] + if batch_idx and not reduced: + result = result[batch_idx] + + return result def __repr__(self): return f"{self.__class__.__name__}(stage={self._stage}, internals={self._internals})" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0a1cb836eda6d..946064660f818 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -44,25 +44,14 @@ def __init__(self, trainer): self._callback_hook_validator = CallbackHookNameValidator() self._current_stage = None - def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]: - """ Function to access cached_results using str or bool. Bool is used only for testing""" - stage_or_testing = str(stage_or_testing) - stages = self._stages - if stage_or_testing in self._stages: - return self._cached_results[stage_or_testing] - if stage_or_testing in LOOKUP_TABLE: - # Acces using trainer.testing - stage = LOOKUP_TABLE[stage_or_testing] - return self._cached_results[stage] - raise MisconfigurationException( - f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}" - f" or {LOOKUP_TABLE.keys()}" - ) + @property + def cached_results(self) -> Union[EpochResultStore, None]: + return self._cached_results[self._current_stage] def set_stage(self, stage_or_testing: str, reset:bool = False) -> None: self._current_stage = self._determine_stage(stage_or_testing) if reset: - self.cached_results(stage_or_testing).reset() + self.cached_results.reset() def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name, @@ -75,17 +64,17 @@ def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataload model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None # track batch_size - self.cached_results(testing)._batch_size = Result.extract_batch_size(batch) + self.cached_results._batch_size = Result.extract_batch_size(batch) - def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None: - self._cached_results["train"]._split_idx = split_idx - self._cached_results["train"]._opt_idx = opt_idx - self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch) + def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None: + self.cached_results._split_idx = split_idx + self.cached_results._opt_idx = opt_idx + self.cached_results._batch_size = Result.extract_batch_size(split_batch) def on_train_batch_end(self) -> None: - self._cached_results["train"]._split_idx = None - self._cached_results["train"]._opt_idx = None - self._cached_results["train"]._batch_size = None + self.cached_results._split_idx = None + self.cached_results._opt_idx = None + self.cached_results._batch_size = None def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str: stage_or_testing = str(stage_or_testing) @@ -112,6 +101,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps): self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps self.trainer.log_every_n_steps = log_every_n_steps + @property + def should_flush_logs(self): + should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 + return should_flush or self.trainer.should_stop + + @property + def should_update_logs(self): + should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + return should_log_every_n_steps or self.trainer.should_stop + def configure_logger(self, logger): if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) @@ -130,6 +129,53 @@ def configure_logger(self, logger): else: self.trainer.logger = logger + def cache_training_step_metrics(self, opt_closure_result): + """ + This function is responsible to update + logger_connector internals metrics holder based for depreceated logging + """ + using_results_obj = isinstance(opt_closure_result.training_step_output, Result) + + # temporary dict to collect metrics + logged_metrics_tmp = {} + pbar_metrics_tmp = {} + callback_metrics_tmp = {} + + if using_results_obj: + batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics( + include_forked_originals=False + ) + logged_metrics_tmp.update(batch_log_metrics) + + batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics( + include_forked_originals=False + ) + pbar_metrics_tmp.update(batch_pbar_metrics) + + forked_metrics = opt_closure_result.training_step_output.get_forked_metrics() + callback_metrics_tmp.update(forked_metrics) + callback_metrics_tmp.update(logged_metrics_tmp) + + else: + batch_log_metrics = opt_closure_result.training_step_output.log_metrics + logged_metrics_tmp.update(batch_log_metrics) + + callback_metrics = opt_closure_result.training_step_output.callback_metrics + callback_metrics_tmp.update(callback_metrics) + + batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end + pbar_metrics_tmp.update(batch_pbar_metrics) + + # track progress bar metrics + if len(pbar_metrics_tmp) > 0: + self.add_progress_bar_metrics(pbar_metrics_tmp) + + self.callback_metrics.update(callback_metrics_tmp) + + # save legacy log metrics + self.logged_metrics.update(logged_metrics_tmp) + self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp) + def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, @@ -396,8 +442,9 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod if num_loaders == 1: self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics) - def on_train_epoch_end(self, epoch_output): - pass + def on_train_epoch_end(self): + # inform cached logger connector epoch finished + self.cached_results.has_batch_loop_finished = True def log_train_epoch_end_metrics(self, epoch_output, @@ -441,12 +488,10 @@ def log_train_epoch_end_metrics(self, # ------------------ if is_1_0_result: # lightning module hook - epoch_end_log_result = self.training_epoch_end(model, epoch_output, num_optimizers) + self.training_epoch_end(model, epoch_output, num_optimizers) # log/aggregate metrics automatically epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output) - epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics()) - epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics()) # TODO: deprecate 1.0 else: @@ -459,6 +504,14 @@ def log_train_epoch_end_metrics(self, ) epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out + # it will perform reduction over epoch and return log metrics + cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics() + cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics() + + # update + epoch_log_metrics.update(cached_epoch_log_metrics) + epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics) + # -------------------------- # track results # -------------------------- @@ -475,15 +528,16 @@ def log_train_epoch_end_metrics(self, self.add_progress_bar_metrics(epoch_progress_bar_metrics) self.callback_metrics.update(epoch_progress_bar_metrics) + # reset epoch loop result for next epoch + self.cached_results.reset() + def training_epoch_end(self, model, epoch_output, num_optimizers): if not is_overridden('training_epoch_end', model=model): - return Result() + return # run training_epoch_end # refresh the result for custom logging at the epoch level model._current_fx_name = 'training_epoch_end' - model._results = Result() - epoch_output = self.__prepare_epoch_end_inputs(epoch_output) if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization: @@ -492,15 +546,11 @@ def training_epoch_end(self, model, epoch_output, num_optimizers): # lightningmodule hook epoch_output = model.training_epoch_end(epoch_output) - model._current_fx_name = '' - if epoch_output is not None: raise MisconfigurationException('training_epoch_end expects a return of None. ' 'HINT: remove the return statement in training_epoch_end') - - # user can ALSO log at the end of an epoch - new_epoch_end_logs = model._results - return new_epoch_end_logs + # capture logging + self.trainer.logger_connector.cache_logged_metrics() def __run_legacy_training_epoch_end( self, @@ -527,8 +577,12 @@ def __run_legacy_training_epoch_end( # run training_epoch_end # a list with a result per optimizer index + model._current_fx_name = 'training_epoch_end' epoch_output = model.training_epoch_end(epoch_output) + # capture logging + self.trainer.logger_connector.cache_logged_metrics() + if isinstance(epoch_output, Result): epoch_log_metrics = epoch_output.epoch_log_metrics epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics @@ -563,7 +617,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output): # reduce across training steps opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) - # with manual opt need 1+ metrics because meta is always there + # with manual opt need 1 + metrics because meta is always there if opt_outputs.minimize is not None: opt_outputs.minimize = opt_outputs.minimize.mean() epoch_log_metrics.update(opt_outputs.epoch_log_metrics) @@ -623,12 +677,9 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): def log_train_step_metrics(self, batch_output): # when metrics should be logged - should_log_metrics = ( - (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop - ) - if should_log_metrics or self.trainer.fast_dev_run: + if self.should_update_logs or self.trainer.fast_dev_run: # logs user requested information to logger - metrics = batch_output.batch_log_metrics + metrics = self.cached_results.get_latest_batch_log_metrics() grad_norm_dic = batch_output.grad_norm_dic if len(metrics) > 0 or len(grad_norm_dic) > 0: self.log_metrics(metrics, grad_norm_dic) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 89a242dbfd886..6ebab1ade0f1d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -358,6 +358,9 @@ def __log_result_step_metrics(self, output, batch_idx): step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False) step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False) + cached_batch_log_metrics = \ + self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics() + if len(step_log_metrics) > 0: # make the metrics appear as a different line in the same graph metrics_by_epoch = {} diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 49cf232f76ac7..2d4e2c0d9e4bd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -838,7 +838,25 @@ def call_setup_hook(self, model): self.setup(stage_name) model.setup(stage_name) + def _reset_result_and_set_hook_fx_name(self, hook_name): + model_ref = self.get_model() + if model_ref is not None: + # used to track current hook name called + model_ref._results = Result() + model_ref._current_hook_fx_name = hook_name + + def _cache_logged_metrics(self): + model_ref = self.get_model() + if model_ref is not None: + # capture logging for this hook + self.logger_connector.cache_logged_metrics() + def call_hook(self, hook_name, *args, **kwargs): + # temporary. Don't modify evaluation behaviour + if self.logger_connector._current_stage == "train": + # set hook_name to model + reset Result obj + self._reset_result_and_set_hook_fx_name(hook_name) + # always profile hooks with self.profiler.profile(hook_name): @@ -860,4 +878,8 @@ def call_hook(self, hook_name, *args, **kwargs): accelerator_hook = getattr(self.accelerator_backend, hook_name) output = accelerator_hook(*args, **kwargs) - return output + # temporary. Don't modify evaluation behaviour + if self.logger_connector._current_stage == "train": + # capture logging + self._cache_logged_metrics() + return output diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3845b7eb728ac..2f66f5b1a600e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -251,12 +251,15 @@ def on_train_epoch_start(self, epoch): self.trainer.call_hook("on_train_epoch_start") def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): + # hook + self.trainer.call_hook('on_batch_end') + self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) + # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) - # hook - self.trainer.call_hook("on_batch_end") - self.trainer.call_hook("on_train_batch_end", epoch_end_outputs, batch, batch_idx, dataloader_idx) + # reset batch logger internals + self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: @@ -305,13 +308,16 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss): def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging - model = self.trainer.get_model() - model._results = Result() - model._current_fx_name = "training_step" + model_ref = self.trainer.get_model() with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) + + # manually capture logged metrics + model_ref._current_fx_name = 'training_step' training_step_output = self.trainer.accelerator_backend.training_step(args) + self.trainer.logger_connector.cache_logged_metrics() + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( @@ -484,35 +490,6 @@ def _track_gradient_norm(self): grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict - def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics): - # track callback metrics - callback_metrics = opt_closure_result.training_step_output.callback_metrics - - # decide which metrics to log (results vs dict return) - using_results_obj = isinstance(opt_closure_result.training_step_output, Result) - if using_results_obj: - metrics_to_log = opt_closure_result.training_step_output.get_batch_log_metrics( - include_forked_originals=False - ) - step_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics( - include_forked_originals=False - ) - forked_metrics = opt_closure_result.training_step_output.get_forked_metrics() - callback_metrics.update(forked_metrics) - else: - metrics_to_log = opt_closure_result.training_step_output.log_metrics - step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end - - # track batch log metrics - batch_log_metrics.append(metrics_to_log) - - # track progress bar metrics - if len(step_pbar_metrics) > 0: - self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics) - self.trainer.logger_connector.callback_metrics.update(step_pbar_metrics) - - batch_callback_metrics.append(callback_metrics) - def process_hiddens(self, opt_closure_result): hiddens = opt_closure_result.hiddens if isinstance(opt_closure_result.training_step_output, Result): @@ -578,6 +555,8 @@ def run_training_epoch(self): should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) + # reset stage to train + self.trainer.logger_connector.set_stage("train") # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) @@ -586,7 +565,6 @@ def run_training_epoch(self): # update LR schedulers monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - monitor_metrics.update(batch_output.batch_log_metrics) self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True @@ -612,19 +590,19 @@ def run_training_epoch(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() + # epoch end hook + self.run_on_epoch_end_hook(epoch_output) + # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( - epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers + epoch_output, + self.checkpoint_accumulator, + self.early_stopping_accumulator, + self.num_optimizers ) - # hook - self.trainer.logger_connector.on_train_epoch_end(epoch_output) - # when no val loop is present or fast-dev-run still need to call checkpoints - self.check_checkpoint_callback(not (should_check_val or is_overridden("validation_step", model))) - - # epoch end hook - self.run_on_epoch_end_hook(epoch_output) + self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model))) # increment the global step once # progress global step according to grads progress @@ -634,12 +612,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} - # track all metrics for callbacks - batch_callback_metrics = [] - - # track metrics to log - batch_log_metrics = [] - # bookkeeping using_results_obj = False self.trainer.hiddens = None @@ -683,8 +655,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx): self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) batch_outputs = self._process_closure_result( - batch_callback_metrics=batch_callback_metrics, - batch_log_metrics=batch_log_metrics, batch_outputs=batch_outputs, opt_idx=opt_idx, ) @@ -711,15 +681,18 @@ def train_step_and_backward_closure(): self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) else: - self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) + self._curr_step_result = self.training_step( + split_batch, + batch_idx, + opt_idx, + self.trainer.hiddens + ) if self._curr_step_result is None: # user decided to skip optimization continue batch_outputs = self._process_closure_result( - batch_callback_metrics=batch_callback_metrics, - batch_log_metrics=batch_log_metrics, batch_outputs=batch_outputs, opt_idx=opt_idx, ) @@ -737,19 +710,9 @@ def train_step_and_backward_closure(): # update running loss + reset accumulated loss self.update_running_loss() - # collapse all metrics into one dict - batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} - - # track all metrics for callbacks - self.trainer.logger_connector.callback_metrics.update(batch_log_metrics) - self.trainer.logger_connector.callback_metrics.update( - {k: v for d in batch_callback_metrics for k, v in d.items() if v is not None} - ) - result = AttributeDict( signal=0, grad_norm_dic=grad_norm_dic, - batch_log_metrics=batch_log_metrics, training_step_output_for_epoch_end=batch_outputs, ) return result @@ -762,14 +725,14 @@ def block_ddp_sync_behaviour(self): yield def _process_closure_result( - self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int + self, batch_outputs: list, opt_idx: int ) -> list: opt_closure_result = self._curr_step_result if opt_closure_result is not None: - # log metrics - self.log_training_step_metrics(opt_closure_result, batch_callback_metrics, batch_log_metrics) + # cache metrics + self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) # track hiddens self.trainer.hiddens = self.process_hiddens(opt_closure_result) @@ -842,8 +805,10 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): - self.trainer.call_hook("on_epoch_end") - self.trainer.call_hook("on_train_epoch_end", epoch_output) + self.trainer.call_hook('on_epoch_end') + self.trainer.call_hook('on_train_epoch_end', epoch_output) + + self.trainer.logger_connector.on_train_epoch_end() def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() @@ -898,10 +863,8 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def save_loggers_on_train_batch_end(self): # when loggers should save to disk - should_save_log = ( - self.trainer.global_step + 1 - ) % self.trainer.flush_logs_every_n_steps == 0 or self.trainer.should_stop - if should_save_log or self.trainer.fast_dev_run: + should_flush_logs = self.trainer.logger_connector.should_flush_logs + if should_flush_logs or self.trainer.fast_dev_run: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() @@ -955,7 +918,7 @@ def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally - self.trainer.logger_connector.on_batch_start(split_idx, opt_idx, split_batch) + self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch) def update_running_loss(self): accumulated_loss = self.accumulated_loss.mean() diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a29fd3e5a1059..98d322ce0a3a2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -73,7 +73,7 @@ def find_free_network_port() -> int: return port -def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): +def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes @@ -85,26 +85,41 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional Return: gathered_result: list with size equal to the process group where gathered_result[i] corresponds to result tensor from process i - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if group is None: - group = torch.distributed.group.WORLD + if group is None: + group = torch.distributed.group.WORLD - world_size = torch.distributed.get_world_size(group) + world_size = torch.distributed.get_world_size(group) - gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] - # sync and broadcast all - torch.distributed.barrier(group=group) - torch.distributed.all_gather(gathered_result, result, group) + # sync and broadcast all + torch.distributed.barrier(group=group) + torch.distributed.all_gather(gathered_result, result, group) - result = gathered_result - return result + return gathered_result def sync_ddp_if_available( result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None +) -> torch.Tensor: + """ + Function to reduce a tensor across worker processes during distributed training + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + Return: + reduced value + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + +def sync_ddp( + result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -118,24 +133,22 @@ def sync_ddp_if_available( Return: reduced value """ + divide_by_world_size = False - if torch.distributed.is_available() and torch.distributed.is_initialized(): - divide_by_world_size = False - - if group is None: - group = torch.distributed.group.WORLD + if group is None: + group = torch.distributed.group.WORLD - if reduce_op is None: - reduce_op = torch.distributed.ReduceOp.SUM - elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): - reduce_op = torch.distributed.ReduceOp.SUM - divide_by_world_size = True + if reduce_op is None: + reduce_op = torch.distributed.ReduceOp.SUM + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + reduce_op = torch.distributed.ReduceOp.SUM + divide_by_world_size = True - # sync all processes before reduction - torch.distributed.barrier(group=group) - torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) - if divide_by_world_size: - result = result / torch.distributed.get_world_size(group) + if divide_by_world_size: + result = result / torch.distributed.get_world_size(group) return result diff --git a/requirements/examples.txt b/requirements/examples.txt index e930579b8b369..0afa62f9ffa95 100644 --- a/requirements/examples.txt +++ b/requirements/examples.txt @@ -1,2 +1,2 @@ torchvision>=0.4.1,<0.9.0 -gym>=0.17.0 \ No newline at end of file +gym>=0.17.0 diff --git a/requirements/extra.txt b/requirements/extra.txt index dbd5f7515109e..be21317a1d826 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,8 +1,7 @@ # extended list of package dependencies to reach full functionality matplotlib>=3.1.1 -# no need to install with [pytorch] as pytorch is already installed and torchvision is required only for Horovod examples -horovod>=0.20.1 # v0.20.0 has problem with building the wheel/installation +horovod>=0.20.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.0 # scipy>=0.13.3 scikit-learn>=0.22.2 diff --git a/tests/README.md b/tests/README.md index 7fd3c90c0241e..8ef006c4d879a 100644 --- a/tests/README.md +++ b/tests/README.md @@ -30,7 +30,7 @@ To test models that require GPU make sure to run the above command on a GPU mach The GPU machine must have: 1. At least 2 GPUs. 2. [NVIDIA-apex](https://github.com/NVIDIA/apex#linux) installed. -3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL pip install horovod` +3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_OPERATIONS=NCCL pip install horovod` ## Running Coverage diff --git a/tests/base/develop_utils.py b/tests/base/develop_utils.py index ba0d20c2c8389..9c88ba1b7e4d3 100644 --- a/tests/base/develop_utils.py +++ b/tests/base/develop_utils.py @@ -32,7 +32,7 @@ def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1): f"lightning {diffs} was slower than PT (threshold {max_diff})" -def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.6): +def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.55): # assert speeds diffs = np.asarray(pl_times) - np.asarray(pt_times) # norm by vanila time diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 886e0db4e7854..bccc5262a5bda 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -307,7 +307,7 @@ def on_test_model_train(self): trainer.fit(model) - assert model.called == [ + expected = [ 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -341,10 +341,12 @@ def on_test_model_train(self): 'on_fit_end', ] + assert model.called == expected + model2 = HookedModel() trainer.test(model2) - assert model2.called == [ + expected = [ 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -356,3 +358,5 @@ def on_test_model_train(self): 'on_test_model_train', 'on_fit_end', ] + + assert model2.called == expected diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index d09d9387ea485..d0ae17d8fee5d 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -17,19 +17,25 @@ import shlex import subprocess import sys -from unittest.mock import patch +import numpy as np import pytest import torch +from sklearn.metrics import accuracy_score + import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator +from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult +from pytorch_lightning.metrics.classification.accuracy import Accuracy from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE from tests.base import EvalModelTemplate from tests.base.models import BasicGAN try: + import horovod from horovod.common.util import nccl_built except ImportError: HOROVOD_AVAILABLE = False @@ -235,6 +241,111 @@ def get_optimizer_params(optimizer): assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0]) assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1]) + +@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_result_reduce_horovod(tmpdir): + """Make sure result logging works with Horovod. + + This test mirrors tests/core/test_results.py::_ddp_test_fn + """ + tutils.reset_seed() + tutils.set_random_master_port() + + def hvd_test_fn(): + path_here = os.path.abspath(os.path.dirname(__file__)) + path_root = os.path.abspath(os.path.join(path_here, '..', '..')) + sys.path.insert(0, os.path.abspath(path_root)) + + from tests.base.boring_model import BoringModel + + import horovod.torch as hvd + + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + self.training_step_called = True + + tensor = torch.tensor([1.0]) + self.log("test_tensor", tensor, sync_dist=True, sync_dist_op='sum', + on_step=True, on_epoch=True) + + res = self._results + + # Check that `tensor` is summed across all ranks automatically + assert res["test_tensor"].item() == hvd.size(), \ + "Result-Log does not work properly with Horovod and Tensors" + + def training_epoch_end(self, outputs) -> None: + assert len(outputs) == 0 + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + horovod.run(hvd_test_fn, np=2) + + +@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_accuracy_metric_horovod(): + num_batches = 10 + batch_size = 16 + threshold = 0.5 + + def sk_metric(preds, target): + sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_target = target.view(-1).numpy() + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + + preds = torch.rand(num_batches, batch_size) + target = torch.randint(high=2, size=(num_batches, batch_size)) + + def _compute_batch(): + import horovod.torch as hvd + + trainer = Trainer( + fast_dev_run=True, + distributed_backend='horovod', + ) + + accelerator_backend = trainer.accelerator_connector.select_accelerator() + assert isinstance(accelerator_backend, HorovodAccelerator) + + metric = Accuracy(compute_on_step=True, + dist_sync_on_step=True, + dist_sync_fn=accelerator_backend.gather_all_tensors, + threshold=threshold) + + for i in range(hvd.rank(), num_batches, hvd.size()): + batch_result = metric(preds[i], target[i]) + if hvd.rank() == 0: + dist_preds = torch.stack([preds[i + r] for r in range(hvd.size())]) + dist_target = torch.stack([target[i + r] for r in range(hvd.size())]) + sk_batch_result = sk_metric(dist_preds, dist_target) + assert np.allclose(batch_result.numpy(), sk_batch_result) + + # check on all batches on all ranks + result = metric.compute() + assert isinstance(result, torch.Tensor) + + total_preds = torch.stack([preds[i] for i in range(num_batches)]) + total_target = torch.stack([target[i] for i in range(num_batches)]) + sk_result = sk_metric(total_preds, total_target) + + assert np.allclose(result.numpy(), sk_result) + + horovod.run(_compute_batch, np=2) + # @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") # def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): # hparams = EvalModelTemplate.get_default_hparams() diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py index 31b60e1d0be8b..6329480e10a11 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py @@ -14,6 +14,7 @@ """ Tests to ensure that the training loop works with a dict """ +import os from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import Trainer from tests.base.deterministic_model import DeterministicModel @@ -125,6 +126,9 @@ def test_validation_step_dict_return(tmpdir): Test that val step can return a dict with all the expected keys and they end up in the correct place """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -166,6 +170,8 @@ def test_val_step_step_end_no_return(tmpdir): """ Test that val step + val step end work (with no return in val step end) """ + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -197,6 +203,9 @@ def test_val_step_step_end(tmpdir): """ Test that val step + val step end work """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -241,6 +250,9 @@ def test_no_val_step_end(tmpdir): """ Test that val step + val epoch end """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return @@ -284,6 +296,9 @@ def test_full_val_loop(tmpdir): """ Test that val step + val step end + val epoch end """ + + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_dict_return model.validation_step = model.validation_step_dict_return diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py index 7e8588ce9f6b2..8d1aaf1b3c548 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py @@ -44,9 +44,10 @@ def test_training_step_dict(tmpdir): break out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) + assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 12.0 - assert out.batch_log_metrics['log_acc2'] == 7.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 12.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 7.0 train_step_out = out.training_step_output_for_epoch_end assert len(train_step_out) == 1 @@ -92,8 +93,8 @@ def training_step_with_step_end(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 14.0 - assert out.batch_log_metrics['log_acc2'] == 9.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 14.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 9.0 train_step_end_out = out.training_step_output_for_epoch_end pbar_metrics = train_step_end_out['progress_bar'] @@ -133,8 +134,8 @@ def test_full_training_loop_dict(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 14.0 - assert out.batch_log_metrics['log_acc2'] == 9.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 14.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 9.0 # get the output of the first optimizer train_step_end_out = out.training_step_output_for_epoch_end @@ -220,8 +221,8 @@ def test_train_step_epoch_end(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert out.batch_log_metrics['log_acc1'] == 12.0 - assert out.batch_log_metrics['log_acc2'] == 7.0 + assert trainer.logger_connector.logged_metrics['log_acc1'] == 12.0 + assert trainer.logger_connector.logged_metrics['log_acc2'] == 7.0 # outputs are for 1 optimizer and no tbptt train_step_end_out = out.training_step_output_for_epoch_end diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py index b5eae913ca428..2a66f743a49ef 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py @@ -15,6 +15,7 @@ Tests to ensure that the training loop works with a scalar """ import torch +import os from pytorch_lightning import Trainer from tests.base.deterministic_model import DeterministicModel @@ -46,7 +47,6 @@ def test_training_step_scalar(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end @@ -84,7 +84,6 @@ def training_step_scalar_with_step_end(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end @@ -104,6 +103,8 @@ def test_full_training_loop_scalar(tmpdir): Checks train_step + training_step_end + training_epoch_end (all with scalar return from train_step) """ + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_scalar_return model.training_step_end = model.training_step_end_scalar @@ -132,7 +133,6 @@ def test_full_training_loop_scalar(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end @@ -152,6 +152,8 @@ def test_train_step_epoch_end_scalar(tmpdir): Checks train_step + training_epoch_end (NO training_step_end) (with scalar return) """ + os.environ['PL_DEV_DEBUG'] = '0' + model = DeterministicModel() model.training_step = model.training_step_scalar_return model.training_step_end = None @@ -176,7 +178,6 @@ def test_train_step_epoch_end_scalar(tmpdir): out = trainer.train_loop.run_training_batch(batch, batch_idx, 0) assert out.signal == 0 - assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict) assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict) train_step_out = out.training_step_output_for_epoch_end diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 0f27f2ca4fef4..08936f89eb9f8 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -17,15 +17,19 @@ import os import torch import pytest - +from copy import deepcopy from pytorch_lightning.trainer import Trainer from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore +from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel, RandomDataset class Helper: - def decorator_with_arguments(fx_name='', hook_fx_name=''): + def decorator_with_arguments(fx_name='', hook_fx_name=None): def decorator(func): def wrapper(self, *args, **kwargs): # Set information @@ -65,9 +69,9 @@ def training_step(self, batch, batch_idx): self.log("train_loss", loss, on_step=True, on_epoch=True) return {"loss": loss} - def val_dataloader(self): - return [torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64))] + def on_train_epoch_end(self, outputs): + # save objects as it will be reset at the end of epoch. + self.train_results = deepcopy(self.trainer.logger_connector.cached_results) model = TestModel() model.val_dataloader = None @@ -82,21 +86,31 @@ def val_dataloader(self): ) trainer.fit(model) - assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']) == 2 - assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0'][0]["train_loss"] == model.train_losses[0] - assert trainer.logger_connector.cached_results("train")['training_step']['0']['0']['1'][0]["train_loss"] == model.train_losses[1] + train_results = model.train_results - # assert reduction didn't happen yet - assert trainer.logger_connector.cached_results("train").has_reduced is False + assert len(train_results(fx_name="training_step", dl_idx="0", opt_idx="0")) == 2 + generated = train_results(fx_name="training_step", + dl_idx="0", + opt_idx="0", + batch_idx="0", + split_idx="0")["train_loss"] + assert generated == model.train_losses[0] + generated = train_results(fx_name="training_step", + dl_idx="0", + opt_idx="0", + batch_idx="1", + split_idx="0")["train_loss"] + assert generated == model.train_losses[1] - # Launch reduction - trainer.logger_connector.cached_results("train").has_batch_loop_finished = True + assert train_results.has_reduced is not True - # assert reduction did happen - assert trainer.logger_connector.cached_results("train").has_reduced is True + train_results.has_batch_loop_finished = True - assert trainer.logger_connector.cached_results("train")["training_step"]\ - ._internals_reduced["0"]["0"]['train_loss_epoch'].item() == torch.stack(model.train_losses).mean().item() + assert train_results.has_reduced is True + + generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)['train_loss_epoch'].item() + excepted = torch.stack(model.train_losses).mean().item() + assert generated == excepted def test__logger_connector__epoch_result_store__train__ttbt(tmpdir): @@ -163,6 +177,10 @@ def train_dataloader(self): sampler=None, ) + def on_train_epoch_end(self, outputs): + # save objects as it will be reset at the end of epoch. + self.train_results = deepcopy(self.trainer.logger_connector.cached_results) + model = TestModel() model.training_epoch_end = None model.example_input_array = torch.randn(5, truncated_bptt_steps) @@ -178,19 +196,22 @@ def train_dataloader(self): ) trainer.fit(model) - assert len(trainer.logger_connector.cached_results("train")['training_step']['0']['0']['0']) == len(model.train_losses) + train_results = model.train_results + + generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", batch_idx="0") + assert len(generated) == len(model.train_losses) # assert reduction didn't happen yet - assert trainer.logger_connector.cached_results("train").has_reduced is False + assert train_results.has_reduced is False # Launch reduction - trainer.logger_connector.cached_results("train").has_batch_loop_finished = True + train_results.has_batch_loop_finished = True # assert reduction did happen - assert trainer.logger_connector.cached_results("train").has_reduced is True + assert train_results.has_reduced is True - assert trainer.logger_connector.cached_results("train")['training_step']\ - ._internals_reduced['0']['0']["a_epoch"].item() == torch.stack(model.train_losses).mean().item() + generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)['a_epoch'].item() + assert generated == torch.stack(model.train_losses).mean().item() @pytest.mark.parametrize('num_dataloaders', [1, 2]) @@ -206,11 +227,11 @@ class TestModel(BoringModel): test_losses = {} @Helper.decorator_with_arguments(fx_name="test_step") - def test_step(self, batch, batch_idx, dataloader_idx=0): + def test_step(self, batch, batch_idx, dl_idx=0): output = self.layer(batch) loss = self.loss(batch, output) - primary_key = str(dataloader_idx) + primary_key = str(dl_idx) if primary_key not in self.test_losses: self.test_losses[primary_key] = [] @@ -239,11 +260,126 @@ def test_dataloader(self): ) trainer.test(model) - assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals) == num_dataloaders + test_results = trainer.logger_connector._cached_results["test"] + + generated = test_results(fx_name="test_step") + assert len(generated) == num_dataloaders + for dl_idx in range(num_dataloaders): - assert len(trainer.logger_connector.cached_results("test")["test_step"]._internals[str(dl_idx)]) == limit_test_batches - trainer.logger_connector.cached_results("test").has_batch_loop_finished = True + generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx))) + assert generated == limit_test_batches + + test_results.has_batch_loop_finished = True + for dl_idx in range(num_dataloaders): expected = torch.stack(model.test_losses[str(dl_idx)]).mean() - generated = trainer.logger_connector.cached_results("test")["test_step"]._internals_reduced[str(dl_idx)]["test_loss_epoch"] + generated = test_results(fx_name="test_step", dl_idx=str(dl_idx), reduced=True)["test_loss_epoch"] assert abs(expected.item() - generated.item()) < 1e-6 + + +def test_call_back_validator(tmpdir): + + funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) + + callbacks_func = [ + 'on_after_backward', + 'on_batch_end', + 'on_batch_start', + 'on_before_zero_grad', + 'on_epoch_end', + 'on_epoch_start', + 'on_fit_end', + 'on_fit_start', + 'on_init_end', 'on_init_start', + 'on_keyboard_interrupt', + 'on_load_checkpoint', + 'on_pretrain_routine_end', + 'on_pretrain_routine_start', + 'on_sanity_check_end', + 'on_sanity_check_start', + 'on_save_checkpoint', + 'on_test_batch_end', + 'on_test_batch_start', + 'on_test_end', + 'on_test_epoch_end', + 'on_test_epoch_start', + 'on_test_start', + 'on_train_batch_end', + 'on_train_batch_start', + 'on_train_end', + 'on_train_epoch_end', + 'on_train_epoch_start', + 'on_train_start', + 'on_validation_batch_end', + 'on_validation_batch_start', + 'on_validation_end', + 'on_validation_epoch_end', + 'on_validation_epoch_start', + 'on_validation_start', + 'setup', + 'teardown', + ] + + not_supported = [ + "on_fit_end", + "on_fit_start", + "on_init_end", + "on_init_start", + "on_keyboard_interrupt", + "on_load_checkpoint", + "on_pretrain_routine_end", + "on_pretrain_routine_start", + "on_sanity_check_end", + "on_sanity_check_start", + "on_save_checkpoint", + "on_test_end", + "on_train_end", + "on_validation_end", + "setup", + "teardown", + ] + + assert funcs_name == callbacks_func, """Detected new callback function. + Need to add its logging permission to CallbackHookNameValidator and update this test""" + + validator = CallbackHookNameValidator() + + for func_name in funcs_name: + # This summurize where and what is currently possible to log using `self.log` function. + is_stage = "train" in func_name or "test" in func_name or "validation" in func_name + is_start = "start" in func_name or "batch" in func_name + on_step = is_stage and is_start + on_epoch = True + # creating allowed condition + allowed = ( + is_stage + or "batch" in func_name + or "epoch" in func_name + or "grad" in func_name + or "backward" in func_name + ) + allowed = ( + allowed + and "pretrain" not in func_name + and func_name not in ["on_train_end", "on_test_end", "on_validation_end"] + ) + if allowed: + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, + on_step=on_step, + on_epoch=on_epoch) + if not is_start and is_stage: + with pytest.raises(MisconfigurationException, match="function supports only"): + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, + on_step=True, + on_epoch=on_epoch) + else: + assert func_name in not_supported + with pytest.raises(MisconfigurationException, match="function doesn't support"): + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, + on_step=on_step, + on_epoch=on_epoch) + + result = validator.check_logging_in_callbacks(current_hook_fx_name=None, + on_step=None, + on_epoch=None) + assert result is None diff --git a/tests/trainer/logging_tests/test_distributed_logging.py b/tests/trainer/logging_tests/test_distributed_logging.py index 5fdd021dcc0ae..a600317a024c9 100644 --- a/tests/trainer/logging_tests/test_distributed_logging.py +++ b/tests/trainer/logging_tests/test_distributed_logging.py @@ -26,8 +26,9 @@ def on_pretrain_routine_end(self) -> None: with mock.patch('pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics') as m: self.trainer.logger_connector.log_metrics({'a': 2}, {}) logged_times = m.call_count - expected = 1 if self.global_rank == 0 else 0 - assert logged_times == expected, 'actual logger called from non-global zero' + expected = int(self.trainer.is_global_zero) + msg = f'actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}' + assert logged_times == expected, msg @pytest.mark.skipif(platform.system() == "Windows", diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 414264894e639..60ff33b402e4b 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -14,15 +14,22 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ -from pytorch_lightning.core.lightning import LightningModule -from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset + import os -import torch +import collections import pytest +import itertools +import numpy as np +import torch +from torch.utils.data import Dataset + +import pytorch_lightning as pl +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import Trainer, callbacks + +from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset from tests.base.deterministic_model import DeterministicModel -from torch.utils.data import Dataset def test__training_step__log(tmpdir): @@ -324,12 +331,12 @@ def training_step(self, batch, batch_idx, hiddens): assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) - loss_val = torch.nn.functional.mse_loss( + loss = torch.nn.functional.mse_loss( pred, y_tensor.view(batch_size, truncated_bptt_steps)) - self.log('a', loss_val, on_epoch=True) + self.log('a', loss, on_epoch=True) - return {'loss': loss_val, 'hiddens': self.test_hidden} + return {'loss': loss, 'hiddens': self.test_hidden} def on_train_epoch_start(self) -> None: self.test_hidden = None @@ -398,8 +405,10 @@ def val_dataloader(self): generated = set(trainer.logger_connector.logged_metrics) expected = { + 'a_step', 'a_epoch', - 'n_step/epoch_0', 'n_epoch', + 'n_step/epoch_0', + 'n_epoch', 'epoch' } @@ -489,3 +498,187 @@ def validation_step(self, batch, batch_idx): weights_summary=None, ) trainer.fit(model, train_data, val_data) + + +def test_log_works_in_train_callback(tmpdir): + """ + Tests that log can be called within callback + """ + + os.environ['PL_DEV_DEBUG'] = '1' + + class TestCallback(callbacks.Callback): + + # helpers + count = 1 + choices = [False, True] + # used to compute expected values + callback_funcs_called = collections.defaultdict(list) + funcs_called_count = collections.defaultdict(int) + funcs_attr = {} + + def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, + on_steps=[], on_epochs=[], prob_bars=[]): + self.funcs_called_count[func_name] += 1 + for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): + # run logging + custom_func_name = f"{func_idx}_{idx}_{func_name}" + pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, + on_epoch=on_epoch, prog_bar=prog_bar) + + # catch information for verification + + # on on_train_start is outside the main loop. Won't be called + if func_name == "on_train_start": + self.callback_funcs_called[func_name].append([self.count * func_idx]) + + # Saved only values from second epoch, so we can compute its mean or latest. + if pl_module.trainer.current_epoch == 1: + self.callback_funcs_called[func_name].append([self.count * func_idx]) + + forked = on_step and on_epoch + + self.funcs_attr[custom_func_name] = { + "on_step": on_step, + "on_epoch": on_epoch, + "prog_bar": prog_bar, + "forked": forked, + "func_name": func_name} + + if on_step and on_epoch: + self.funcs_attr[f"{custom_func_name}_step"] = { + "on_step": True, + "on_epoch": False, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + self.funcs_attr[f"{custom_func_name}_epoch"] = { + "on_step": False, + "on_epoch": True, + "prog_bar": prog_bar, + "forked": False, + "func_name": func_name} + + def on_train_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_start', 2, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_epoch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_start(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_batch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_train_batch_end', 7, on_steps=self.choices, + on_epochs=self.choices, prob_bars=self.choices) + # used to make sure aggregation works fine. + # we should obtain func[value * c for c in range(1, max_epochs * limit_train_batches)]) + # with func = np.mean if on_epoch else func = np.max + self.count += 1 + + def on_epoch_end(self, trainer, pl_module): + self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.make_logging(pl_module, 'on_train_epoch_end', 9, on_steps=[False], + on_epochs=self.choices, prob_bars=self.choices) + + class TestModel(BoringModel): + + manual_loss = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.manual_loss.append(loss) + self.log('train_loss', loss) + return {"loss": loss} + + max_epochs = 2 + limit_train_batches = 2 + model = TestModel() + test_callback = TestCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=0, + limit_test_batches=0, + val_check_interval=0., + num_sanity_val_steps=0, + max_epochs=max_epochs, + callbacks=[test_callback] + ) + trainer.fit(model) + + assert test_callback.funcs_called_count["on_train_start"] == 1 + assert test_callback.funcs_called_count["on_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_train_epoch_start"] == 2 + assert test_callback.funcs_called_count["on_batch_start"] == 4 + assert test_callback.funcs_called_count["on_train_batch_start"] == 4 + assert test_callback.funcs_called_count["on_batch_end"] == 4 + assert test_callback.funcs_called_count["on_train_batch_end"] == 4 + assert test_callback.funcs_called_count["on_epoch_end"] == 2 + assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 + + # Make sure the func_name exists within callback_metrics. If not, we missed some + callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): + is_in = False + for callback_metrics_key in callback_metrics_keys: + if func_name in callback_metrics_key: + is_in = True + assert is_in, (func_name, callback_metrics_keys) + + # function used to describe expected return logic + def get_expected_output(func_attr, original_values): + if func_attr["on_epoch"] and not func_attr["on_step"]: + # Apply mean on values + expected_output = np.mean(original_values) + else: + # Keep the latest value + expected_output = np.max(original_values) + return expected_output + + # Make sure the func_name output equals the average from all logged values when on_epoch true + # pop extra keys + trainer.callback_metrics.pop("debug_epoch") + assert trainer.logged_metrics["train_loss"] == model.manual_loss[-1] + assert trainer.callback_metrics["train_loss"] == model.manual_loss[-1] + trainer.callback_metrics.pop("train_loss") + + for func_name, output_value in trainer.callback_metrics.items(): + if torch.is_tensor(output_value): + output_value = output_value.item() + # get creation attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] + + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + assert float(output_value) == float(expected_output) + + for func_name, func_attr in test_callback.funcs_attr.items(): + if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]) and not func_attr["forked"]: + assert func_name in trainer.logger_connector.progress_bar_metrics + else: + assert func_name not in trainer.logger_connector.progress_bar_metrics