Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR #4818

Merged
merged 10 commits into from
Dec 10, 2020
394 changes: 394 additions & 0 deletions notebooks/06-cifar10-baseline.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,394 @@
{
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "06_cifar10_baseline.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "qMDj0BYNECU8"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-cifar10-pytorch-lightning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ECu0zDh8UXU8"
},
"source": [
"# PyTorch Lightning CIFAR10 ~94% Baseline Tutorial ⚡\n",
"\n",
"Train a Resnet to 94% accuracy on Cifar10!\n",
"\n",
"Main takeaways:\n",
"1. Experiment with different Learning Rate schedules and frequencies in the configure_optimizers method in pl.LightningModule\n",
"2. Use an existing Resnet architecture with modifications directly with Lightning\n",
"\n",
"---\n",
"\n",
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
" - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HYpMlx7apuHq"
},
"source": [
"### Setup\n",
"Lightning is easy to install. Simply `pip install pytorch-lightning`.\n",
"Also check out [bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts/) for pre-existing data modules and models."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ziAQCrE-TYWG"
},
"source": [
"! pip install pytorch-lightning pytorch-lightning-bolts -qU"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "L-W_Gq2FORoU"
},
"source": [
"# Run this if you intend to use TPUs\n",
"# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
"# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wjov-2N_TgeS"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.optim.lr_scheduler import OneCycleLR\n",
"from torch.optim.swa_utils import AveragedModel, update_bn\n",
"import torchvision\n",
"\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.callbacks import LearningRateMonitor\n",
"from pytorch_lightning.metrics.functional import accuracy\n",
"from pl_bolts.datamodules import CIFAR10DataModule\n",
"from pl_bolts.transforms.dataset_normalizations import cifar10_normalization"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "54JMU1N-0y0g"
},
"source": [
"pl.seed_everything(7);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "FA90qwFcqIXR"
},
"source": [
"### CIFAR10 Data Module\n",
"\n",
"Import the existing data module from `bolts` and modify the train and test transforms."
]
},
{
"cell_type": "code",
"metadata": {
"id": "S9e-W8CSa8nH"
},
"source": [
"batch_size = 32\n",
"\n",
"train_transforms = torchvision.transforms.Compose([\n",
" torchvision.transforms.RandomCrop(32, padding=4),\n",
" torchvision.transforms.RandomHorizontalFlip(),\n",
" torchvision.transforms.ToTensor(),\n",
" cifar10_normalization(),\n",
"])\n",
"\n",
"test_transforms = torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" cifar10_normalization(),\n",
"])\n",
"\n",
"cifar10_dm = CIFAR10DataModule(\n",
" batch_size=batch_size,\n",
" train_transforms=train_transforms,\n",
" test_transforms=test_transforms,\n",
" val_transforms=test_transforms,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SfCsutp3qUMc"
},
"source": [
"### Resnet\n",
"Modify the pre-existing Resnet architecture from TorchVision. The pre-existing architecture is based on ImageNet images (224x224) as input. So we need to modify it for CIFAR10 images (32x32)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "GNSeJgwvhHp-"
},
"source": [
"def create_model():\n",
" model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n",
" model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" model.maxpool = nn.Identity()\n",
" return model"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HUCj5TKsqty1"
},
"source": [
"### Lightning Module\n",
"Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#configure-optimizers) method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate"
]
},
{
"cell_type": "code",
"metadata": {
"id": "03OMrBa5iGtT"
},
"source": [
"class LitResnet(pl.LightningModule):\n",
" def __init__(self, lr=0.05):\n",
" super().__init__()\n",
"\n",
" self.save_hyperparameters()\n",
" self.model = create_model()\n",
"\n",
" def forward(self, x):\n",
" out = self.model(x)\n",
" return F.log_softmax(out, dim=1)\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = F.log_softmax(self.model(x), dim=1)\n",
" loss = F.nll_loss(logits, y)\n",
" self.log('train_loss', loss)\n",
" return loss\n",
"\n",
" def evaluate(self, batch, stage=None):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
"\n",
" if stage:\n",
" self.log(f'{stage}_loss', loss, prog_bar=True)\n",
" self.log(f'{stage}_acc', acc, prog_bar=True)\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" self.evaluate(batch, 'val')\n",
"\n",
" def test_step(self, batch, batch_idx):\n",
" self.evaluate(batch, 'test')\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n",
" steps_per_epoch = 45000 // batch_size\n",
" scheduler_dict = {\n",
" 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),\n",
" 'interval': 'step',\n",
" }\n",
" return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3FFPgpAFi9KU"
},
"source": [
"model = LitResnet(lr=0.05)\n",
"model.datamodule = cifar10_dm\n",
"\n",
"trainer = pl.Trainer(\n",
" progress_bar_refresh_rate=20,\n",
" max_epochs=40,\n",
" gpus=1,\n",
" logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),\n",
" callbacks=[LearningRateMonitor(logging_interval='step')],\n",
")\n",
"\n",
"trainer.fit(model, cifar10_dm)\n",
"trainer.test(model, datamodule=cifar10_dm);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lWL_WpeVIXWQ"
},
"source": [
"### Bonus: Use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to get a boost on performance\n",
"\n",
"Use SWA from torch.optim to get a quick performance boost. Also shows a couple of cool features from Lightning:\n",
"- Use `training_epoch_end` to run code after the end of every epoch\n",
"- Use a pretrained model directly with this wrapper for SWA"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bsSwqKv0t9uY"
},
"source": [
"class SWAResnet(LitResnet):\n",
" def __init__(self, trained_model, lr=0.01):\n",
" super().__init__()\n",
"\n",
" self.save_hyperparameters('lr')\n",
" self.model = trained_model\n",
" self.swa_model = AveragedModel(self.model)\n",
"\n",
" def forward(self, x):\n",
" out = self.swa_model(x)\n",
" return F.log_softmax(out, dim=1)\n",
"\n",
" def training_epoch_end(self, training_step_outputs):\n",
" self.swa_model.update_parameters(self.model)\n",
"\n",
" def validation_step(self, batch, batch_idx, stage=None):\n",
" x, y = batch\n",
" logits = F.log_softmax(self.model(x), dim=1)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
"\n",
" self.log(f'val_loss', loss, prog_bar=True)\n",
" self.log(f'val_acc', acc, prog_bar=True)\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n",
" return optimizer\n",
"\n",
" def on_train_end(self):\n",
" update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cA6ZG7C74rjL"
},
"source": [
"swa_model = SWAResnet(model.model, lr=0.01)\n",
"swa_model.datamodule = cifar10_dm\n",
"\n",
"swa_trainer = pl.Trainer(\n",
" progress_bar_refresh_rate=20,\n",
" max_epochs=20,\n",
" gpus=1,\n",
" logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),\n",
")\n",
"\n",
"swa_trainer.fit(swa_model, cifar10_dm)\n",
"swa_trainer.test(swa_model, datamodule=cifar10_dm);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RRHMfGiDpZ2M"
},
"source": [
"# Start tensorboard.\n",
"%reload_ext tensorboard\n",
"%tensorboard --logdir lightning_logs/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "RltpFGS-s0M1"
},
"source": [
"<code style=\"color:#792ee5;\">\n",
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n",
"</code>\n",
"\n",
"Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
"\n",
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
"\n",
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
"\n",
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
"\n",
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
"\n",
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
"\n",
"### Contributions !\n",
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
"\n",
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
"* You can also contribute your own notebooks with useful examples !\n",
"\n",
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
"\n",
"<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />"
]
}
]
}
Loading