Add Dali MNIST example #3721

Merged 36 commits into master from dali_support on Nov 6, 2020.
Changes shown below are from 3 of the 36 commits.

Commits
29bdc12
add MNIST DALI example, update README.md
irustandi Sep 29, 2020
dbef9de
Merge remote-tracking branch 'upstream/master' into dali_support
irustandi Sep 29, 2020
743afb7
Fix PEP8 warnings
irustandi Sep 29, 2020
cd9c892
reformatted using black
irustandi Sep 29, 2020
221fe9b
add mnist_dali to test_examples.py
irustandi Sep 29, 2020
4b4ebe9
Add documentation as docstrings
irustandi Sep 29, 2020
4cb797e
add nvidia-pyindex and nvidia-dali-cuda100
irustandi Sep 30, 2020
3b3a5dd
replace nvidia-pyindex with --extra-index-url
irustandi Sep 30, 2020
31fa2a9
mark mnist_dali test as Linux and GPU only
irustandi Sep 30, 2020
daa9a4b
adjust CUDA docker and examples.txt, fix import error in test_example…
irustandi Sep 30, 2020
46eb905
Merge remote-tracking branch 'upstream/master' into dali_support
irustandi Sep 30, 2020
780d518
adjust the GPU check
irustandi Sep 30, 2020
b0fce24
Merge remote-tracking branch 'upstream/master' into dali_support
irustandi Oct 20, 2020
0950111
Exit when DALI is not available
irustandi Oct 22, 2020
6206e99
Merge remote-tracking branch 'upstream/master' into dali_support
irustandi Oct 22, 2020
d5e5779
remove requirements-examples.txt and DALI pip install
irustandi Oct 22, 2020
9575a04
Refactored example, moved to new logging api, added runtime check for…
Nov 4, 2020
f6a7562
Merge branch 'master' into dali_support
Nov 4, 2020
8d91128
Patch to reflect the mnist example module
Nov 4, 2020
3c6998d
add req.
Borda Nov 4, 2020
d256e6d
Apply suggestions from code review
Borda Nov 4, 2020
832b5e0
Removed requirement as it breaks CPU install, added note in README to…
Nov 4, 2020
6effea0
Merge branch 'master' into dali_support
SeanNaren Nov 5, 2020
7751cbd
add DALI to Drone
Borda Nov 5, 2020
6472e7f
test examples
Borda Nov 5, 2020
abb1d6b
Apply suggestions from code review
Borda Nov 5, 2020
a6223aa
imports
Borda Nov 5, 2020
3d3e75f
ABC
Borda Nov 5, 2020
7b468e2
Merge branch 'master' into dali_support
SeanNaren Nov 5, 2020
61da5e1
cuda
Borda Nov 5, 2020
67b8e2f
Merge branch 'dali_support' of https://github.com/irustandi/pytorch-l…
Borda Nov 5, 2020
c5cb549
cuda
Borda Nov 5, 2020
de6f433
Merge branch 'master' into dali_support
SeanNaren Nov 5, 2020
8c09298
pip DALI
Borda Nov 6, 2020
f3408b3
Merge branch 'master' into dali_support
SeanNaren Nov 6, 2020
f7afb45
Move build into init function
Nov 6, 2020
9 changes: 8 additions & 1 deletion pl_examples/basic_examples/README.md
@@ -14,7 +14,14 @@ python mnist.py
python mnist.py --gpus 2 --distributed_backend 'dp'
```

---
#### MNIST with DALI
The MNIST example above, implemented with [NVIDIA DALI](https://developer.nvidia.com/DALI) for data loading.
```bash
python mnist_dali.py
```
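Note that DALI is GPU-only and is not pulled in by the examples' requirements (it breaks CPU installs). Consistent with the approach discussed in this PR's commits, it can be installed from NVIDIA's extra package index; a minimal sketch, assuming CUDA 10.0 (adjust the package name for your CUDA version):
```bash
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100
```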

---
#### Image classifier
Generic image classifier with an arbitrary backbone (i.e., a simple system)
```bash
202 changes: 202 additions & 0 deletions pl_examples/basic_examples/mnist_dali.py
@@ -0,0 +1,202 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser

import numpy as np
from random import shuffle

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

try:
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
except ImportError:
from tests.base.datasets import MNIST

from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIClassificationIterator


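# Feeds MNIST samples to DALI's ExternalSource operator: each __next__ call yields a batch of
# (images, labels) as lists of numpy arrays.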
class ExternalMNISTInputIterator(object):
def __init__(self, mnist_ds, batch_size):
self.batch_size = batch_size
self.mnist_ds = mnist_ds
self.indices = list(range(len(self.mnist_ds)))
shuffle(self.indices)

def __iter__(self):
self.i = 0
self.n = len(self.mnist_ds)
return self

def __next__(self):
batch = []
labels = []
for _ in range(self.batch_size):
index = self.indices[self.i]
img, label = self.mnist_ds[index]
batch.append(img.numpy())
labels.append(np.array([label], dtype=np.uint8))
self.i = (self.i + 1) % self.n
return (batch, labels)


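# Minimal DALI pipeline: its graph simply emits the (images, labels) pair produced by the
# external-source iterator above.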
class ExternalSourcePipeline(Pipeline):
def __init__(self, batch_size, eii, num_threads, device_id):
super(ExternalSourcePipeline, self).__init__(batch_size,
num_threads,
device_id,
seed=12)
self.source = ops.ExternalSource(source=eii, num_outputs=2)

def define_graph(self):
images, labels = self.source()
return images, labels


# we extend DALIClassificationIterator with the __len__() function so that we can call len() on it
class DALIClassificationLoader(DALIClassificationIterator):
def __init__(
self,
pipelines,
size=-1,
reader_name=None,
auto_reset=False,
fill_last_batch=True,
dynamic_shape=False,
last_batch_padded=False,
):
super().__init__(pipelines,
size,
reader_name,
auto_reset,
fill_last_batch,
dynamic_shape,
last_batch_padded)

def __len__(self):
batch_count = self._size // (self._num_gpus * self.batch_size)
last_batch = 1 if self._fill_last_batch else 0
return batch_count + last_batch


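# A simple two-layer MNIST classifier; only split_batch() is DALI-specific, unpacking the
# iterator's dict-based batches.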
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()

self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x

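    # DALIClassificationIterator yields a list with one dict per pipeline, holding 'data' and 'label' tensors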
def split_batch(self, batch):
return batch[0]['data'], batch[0]['label'].squeeze().long()

def training_step(self, batch, batch_idx):
x, y = self.split_batch(batch)
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss

def validation_step(self, batch, batch_idx):
x, y = self.split_batch(batch)
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log('valid_loss', loss)
return result

def test_step(self, batch, batch_idx):
x, y = self.split_batch(batch)
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log('test_loss', loss)
return result

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size)
eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size)
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)

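    # one DALI pipeline and one iterator per split; device_id=0 pins all pipelines to the first GPU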
pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
pipe_train.build()
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)

pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
pipe_val.build()
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)

pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0)
pipe_test.build()
test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False)

# ------------
# model
# ------------
model = LitClassifier(args.hidden_dim, args.learning_rate)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)

# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)


if __name__ == '__main__':
cli_main()