
update example/mnist, update colormap in visualization.py #134

Merged: 8 commits, Sep 23, 2015
example/mnist/README.md: 24 additions, 46 deletions
@@ -1,47 +1,25 @@
# Training Neural Networks on MNIST

Machine: Dual Xeon E5-2680 2.8GHz, Dual GTX 980, CUDA 7.0

| | 2 x E5-2680 | 1 x GTX 980 | 2 x GTX 980 |
| --- | --- | --- | --- |
| `mlp.py` | 40K img/sec | 103K img/sec | 60K img/sec |

Dual GPUs are slower than a single GPU here because the workload is too small to amortize the overhead of splitting each batch across devices.

Sample output using a single GTX 980:

```bash
~/mxnet/example/mnist $ python mlp.py
[20:52:47] src/io/iter_mnist.cc:84: MNISTIter: load 60000 images, shuffle=1, shape=(100,784)
[20:52:47] src/io/iter_mnist.cc:84: MNISTIter: load 10000 images, shuffle=1, shape=(100,784)
INFO:root:Start training with 1 devices
INFO:root:Iteration[0] Train-accuracy=0.920833
INFO:root:Iteration[0] Time cost=0.656
INFO:root:Iteration[0] Validation-accuracy=0.961100
INFO:root:Iteration[1] Train-accuracy=0.965317
INFO:root:Iteration[1] Time cost=0.576
INFO:root:Iteration[1] Validation-accuracy=0.963000
INFO:root:Iteration[2] Train-accuracy=0.974817
INFO:root:Iteration[2] Time cost=0.567
INFO:root:Iteration[2] Validation-accuracy=0.965800
INFO:root:Iteration[3] Train-accuracy=0.978433
INFO:root:Iteration[3] Time cost=0.590
INFO:root:Iteration[3] Validation-accuracy=0.970900
INFO:root:Iteration[4] Train-accuracy=0.982583
INFO:root:Iteration[4] Time cost=0.593
INFO:root:Iteration[4] Validation-accuracy=0.973100
INFO:root:Iteration[5] Train-accuracy=0.982217
INFO:root:Iteration[5] Time cost=0.592
INFO:root:Iteration[5] Validation-accuracy=0.971300
INFO:root:Iteration[6] Train-accuracy=0.985817
INFO:root:Iteration[6] Time cost=0.555
INFO:root:Iteration[6] Validation-accuracy=0.969400
INFO:root:Iteration[7] Train-accuracy=0.987033
INFO:root:Iteration[7] Time cost=0.546
INFO:root:Iteration[7] Validation-accuracy=0.974800
INFO:root:Iteration[8] Train-accuracy=0.988333
INFO:root:Iteration[8] Time cost=0.535
INFO:root:Iteration[8] Validation-accuracy=0.975900
INFO:root:Iteration[9] Train-accuracy=0.987983
INFO:root:Iteration[9] Time cost=0.531
INFO:root:Iteration[9] Validation-accuracy=0.968900
```
The [MNIST](http://yann.lecun.com/exdb/mnist/) database of handwritten digits
has a training set of 60,000 examples and a test set of 10,000 examples. Each
example is a 28 × 28 grayscale image. The database is provided by Yann LeCun,
Corinna Cortes, and Christopher J.C. Burges.


## Neural Networks

- [mlp.py](mlp.py) : multilayer perceptron with 3 fully connected layers
- [lenet.py](lenet.py) : LeNet with 2 convolution layers followed by 2 fully
connected layers

## Results


Using a minibatch size of 100 and 20 passes over the data (hyperparameters not fine-tuned).

Machine: Dual Xeon E5-2680 2.8GHz, Dual GTX 980, Intel MKL, and CUDA 7.0

| | val accuracy | 2 x E5-2680 | 1 x GTX 980 | 2 x GTX 980 |
| --- | ---: | ---: | ---: | ---: |
| `mlp.py` | 97.8% | 40K img/sec | 103K img/sec | 60K img/sec |
| `lenet.py` | 99% | 368 img/sec | 22.5K img/sec | 33K img/sec |
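
The multi-GPU columns come from changing only the training context. A minimal sketch, assuming `lenet` is the symbol built in [lenet.py](lenet.py) (it mirrors the commented-out device list in that file):

```python
import mxnet as mx

# a list of contexts makes mxnet split each minibatch across the devices;
# a single context (mx.cpu() or mx.gpu()) reproduces the one-device columns
dev = [mx.gpu(i) for i in range(2)]

model = mx.model.FeedForward(
    ctx=dev, symbol=lenet, num_round=20,
    learning_rate=0.01, momentum=0.9, wd=0.00001)
```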
example/mnist/data.py: 30 additions, 0 deletions
@@ -0,0 +1,30 @@
# pylint: skip-file
""" data iterator for mnist """
import sys
sys.path.insert(0, "../../python/")
sys.path.append("../../tests/python/common")
import get_data
import mxnet as mx

def mnist_iterator(batch_size, input_shape):
    """return train and val iterators for mnist"""
    # download data
    get_data.GetMNIST_ubyte()
    # flatten images into vectors unless the shape is (channel, height, width)
    flat = len(input_shape) != 3

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)

    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)

    return (train_dataiter, val_dataiter)
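
A usage sketch for the factory above; the two calls are exactly how mlp.py and lenet.py below consume it:

```python
# flat 784-dimensional vectors for the multilayer perceptron
train, val = mnist_iterator(batch_size=100, input_shape=(784,))

# (channel, height, width) images for the convolutional net
train, val = mnist_iterator(batch_size=100, input_shape=(1, 28, 28))
```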
example/mnist/lenet.py: 44 additions, 0 deletions
@@ -0,0 +1,44 @@
# pylint: skip-file
from data import mnist_iterator
import mxnet as mx
import logging

## define lenet

# input
data = mx.symbol.Variable('data')
# first conv
conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
relu1 = mx.symbol.Activation(data=conv1, act_type="relu")
pool1 = mx.symbol.Pooling(data=relu1, pool_type="max",
                          kernel=(2,2), stride=(2,2))
# second conv
conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
pool2 = mx.symbol.Pooling(data=relu2, pool_type="max",
                          kernel=(2,2), stride=(2,2))
# first fullc
flatten = mx.symbol.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
relu3 = mx.symbol.Activation(data=fc1, act_type="relu")
# second fullc
fc2 = mx.symbol.FullyConnected(data=relu3, num_hidden=10)
# loss
lenet = mx.symbol.Softmax(data=fc2)

## data

train, val = mnist_iterator(batch_size=100, input_shape=(1,28,28))

## train

logging.basicConfig(level=logging.DEBUG)

# dev = [mx.gpu(i) for i in range(2)]
dev = mx.gpu()

model = mx.model.FeedForward(
    ctx = dev, symbol = lenet, num_round = 20,
    learning_rate = 0.01, momentum = 0.9, wd = 0.00001)

model.fit(X=train, eval_data=val)
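
After `fit`, the trained model can be evaluated on the validation iterator. A hedged sketch, assuming `FeedForward.predict` accepts a data iterator and returns class probabilities (as in this generation of the API):

```python
import numpy as np

prob = model.predict(val)       # shape (num_examples, 10)
pred = np.argmax(prob, axis=1)  # predicted digit per example
```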
example/mnist/mlp.py: 6 additions, 26 deletions
@@ -1,11 +1,7 @@
# pylint: skip-file
import sys
sys.path.insert(0, "../../python/")
sys.path.append("../../tests/python/common")
from data import mnist_iterator
import mxnet as mx
import logging
import numpy as np
import get_data

# define mlp

@@ -19,30 +15,14 @@

# data

batch_size = 100

get_data.GetMNIST_ubyte()
train_dataiter = mx.io.MNISTIter(
    image="data/train-images-idx3-ubyte",
    label="data/train-labels-idx1-ubyte",
    input_shape=(784,),
    batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
    image="data/t10k-images-idx3-ubyte",
    label="data/t10k-labels-idx1-ubyte",
    input_shape=(784,),
    batch_size=batch_size, shuffle=True, flat=True, silent=False)

train, val = mnist_iterator(batch_size=100, input_shape=(784,))

# train

logging.basicConfig(level=logging.DEBUG)

model = mx.model.FeedForward(ctx = mx.cpu(),
                             symbol = mlp,
                             num_round = 10,
                             learning_rate = 0.1,
                             momentum = 0.9,
                             wd = 0.00001)
model = mx.model.FeedForward(
    ctx = mx.cpu(), symbol = mlp, num_round = 20,
    learning_rate = 0.1, momentum = 0.9, wd = 0.00001)

model.fit(X=train_dataiter, eval_data=val_dataiter)
model.fit(X=train, eval_data=val)
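
The `# define mlp` block is collapsed in this view. A hedged reconstruction of what a three-fully-connected-layer network looks like in this symbol API (the hidden sizes 128 and 64 are illustrative assumptions, not read from the diff):

```python
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, num_hidden=128)   # hidden size assumed
act1 = mx.symbol.Activation(data=fc1, act_type="relu")
fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=64)    # hidden size assumed
act2 = mx.symbol.Activation(data=fc2, act_type="relu")
fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=10)    # 10 digit classes
mlp = mx.symbol.Softmax(data=fc3)
```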
python/mxnet/visualization.py: 27 additions, 43 deletions
@@ -55,69 +55,53 @@ def plot_network(title, symbol, shape=None):
    node_attr = {"shape": "box", "fixedsize": "true",
                 "width": "1.3", "height": "0.8034", "style": "filled"}
    dot = Digraph(name=title)
    # color map
    cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
          "#fdb462", "#b3de69", "#fccde5")

    # make nodes
    for i in range(len(nodes)):
        node = nodes[i]
        op = node["op"]
        name = "%s_%d" % (op, i)
        # input data
        if i in heads and op == "null":
            label = node["name"]
            attr = copy.deepcopy(node_attr)
            dot.node(name=name, label=label, **attr)
        attr = copy.deepcopy(node_attr)
        label = op

        if op == "null":
            continue
            if i in heads:
                label = node["name"]
                attr["fillcolor"] = cm[0]
            else:
                continue
        elif op == "Convolution":
            label = "Convolution\n%sx%s/%s, %s" % (_str2tuple(node["param"]["kernel"])[0],
                                                   _str2tuple(node["param"]["kernel"])[1],
                                                   _str2tuple(node["param"]["stride"])[0],
                                                   node["param"]["num_filter"])
            attr = copy.deepcopy(node_attr)
            attr["color"] = "royalblue1"
            dot.node(name=name, label=label, **attr)
            attr["fillcolor"] = cm[1]
        elif op == "FullyConnected":
            label = "FullyConnected\n%s" % node["param"]["num_hidden"]
            attr = copy.deepcopy(node_attr)
            attr["color"] = "royalblue1"
            dot.node(name=name, label=label, **attr)
            attr["fillcolor"] = cm[1]
        elif op == "BatchNorm":
            label = "BatchNorm"
            attr = copy.deepcopy(node_attr)
            attr["color"] = "orchid1"
            dot.node(name=name, label=label, **attr)
        elif op == "Concat":
            label = "Concat"
            attr = copy.deepcopy(node_attr)
            attr["color"] = "seagreen1"
            dot.node(name=name, label=label, **attr)
        elif op == "Flatten":
            label = "Flatten"
            attr = copy.deepcopy(node_attr)
            attr["color"] = "seagreen1"
            dot.node(name=name, label=label, **attr)
        elif op == "Reshape":
            label = "Reshape"
            attr = copy.deepcopy(node_attr)
            attr["color"] = "seagreen1"
            dot.node(name=name, label=label, **attr)
            attr["fillcolor"] = cm[3]
        elif op == "Activation" or op == "LeakyReLU":
            label = "%s\n%s" % (op, node["param"]["act_type"])
            attr["fillcolor"] = cm[2]
        elif op == "Pooling":
            label = "Pooling\n%s, %sx%s/%s" % (node["param"]["pool_type"],
                                               _str2tuple(node["param"]["kernel"])[0],
                                               _str2tuple(node["param"]["kernel"])[1],
                                               _str2tuple(node["param"]["stride"])[0])
            attr = copy.deepcopy(node_attr)
            attr["color"] = "firebrick2"
            dot.node(name=name, label=label, **attr)
        elif op == "Activation" or op == "LeakyReLU":
            label = "%s\n%s" % (op, node["param"]["act_type"])
            attr = copy.deepcopy(node_attr)
            attr["color"] = "salmon"
            dot.node(name=name, label=label, **attr)
            attr["fillcolor"] = cm[4]
        elif op == "Concat" or op == "Flatten" or op == "Reshape":
            attr["fillcolor"] = cm[5]
        elif op == "Softmax":
            attr["fillcolor"] = cm[6]
        else:
            label = op
            attr = copy.deepcopy(node_attr)
            attr["color"] = "olivedrab1"
            dot.node(name=name, label=label, **attr)
            attr["fillcolor"] = cm[7]

        dot.node(name=name, label=label, **attr)

    # add edges
    for i in range(len(nodes)):
@@ -133,7 +117,7 @@ def plot_network(title, symbol, shape=None):
                input_name = "%s_%d" % (input_node["op"], item[0])
                if input_node["op"] != "null" or item[0] in heads:
                    # add shape into label
                    attr = {"dir": "back"}
                    attr = {"dir": "back", "arrowtail": "open"}
                    dot.edge(tail_name=name, head_name=input_name, **attr)

    return dot
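
A usage sketch for the updated plotter; `plot_network` returns the `graphviz` `Digraph` built above, so rendering goes through that package (the tiny two-layer symbol here is only for illustration):

```python
import mxnet as mx
from mxnet.visualization import plot_network

data = mx.symbol.Variable('data')
fc = mx.symbol.FullyConnected(data=data, num_hidden=10)
net = mx.symbol.Softmax(data=fc)

dot = plot_network(title="net", symbol=net)
dot.render("net.gv")  # writes net.gv plus a rendered net.gv.pdf
```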