forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
32 lines (28 loc) · 982 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# Make the shared test helper module `get_data` importable; it is used below
# to download the dataset automatically.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx
def mnist_iterator(batch_size, input_shape):
    """Build the training and validation iterators for MNIST.

    ``get_data.GetMNIST_ubyte()`` is called first to fetch the dataset, then
    two ``mx.io.MNISTIter`` objects are created over the ubyte files under
    ``data/``.

    Parameters
    ----------
    batch_size
        Forwarded unchanged to both iterators.
    input_shape
        Forwarded unchanged to both iterators. Its length also selects the
        ``flat`` option: ``flat`` is False when the shape has exactly three
        entries and True otherwise.

    Returns
    -------
    tuple
        ``(train_iterator, validation_iterator)``. Only the training
        iterator is created with ``shuffle=True``.
    """
    # Fetch the data files before any iterator tries to open them.
    get_data.GetMNIST_ubyte()

    # Options that are identical for the training and validation iterators.
    shared_options = {
        "input_shape": input_shape,
        "batch_size": batch_size,
        "flat": len(input_shape) != 3,
    }

    train_iter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        shuffle=True,
        **shared_options
    )
    val_iter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        **shared_options
    )
    return (train_iter, val_iter)