# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Stolen shamelessly and only slightly modified, from this link:
# https://github.com/google/jax/blob/main/examples/datasets.py
"""Datasets used in examples."""
import array
import gzip
import os
from os import path
import struct
import urllib.request
from jax import jit, random
import jax.numpy as np
from jax.random import permutation
_DATA = "./data/"


def _download(url, filename):
  """Download a url to a file in the JAX data temp directory."""
  if not path.exists(_DATA):
    os.makedirs(_DATA)
  out_file = path.join(_DATA, filename)
  if not path.isfile(out_file):
    urllib.request.urlretrieve(url, out_file)
    print(f"downloaded {url} to {_DATA}")


def _partial_flatten(x):
  """Flatten all but the first dimension of an ndarray."""
  return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)
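
# Illustrative example (comment only, not in the original file): for class
# indices x = [0, 2] and k = 3 classes,
#   _one_hot(np.array([0, 2]), 3) -> [[1., 0., 0.], [0., 0., 1.]]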


def mnist_raw():
  """Download and parse the raw MNIST dataset."""
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

  def parse_labels(fname):
    with gzip.open(fname, "rb") as fh:
      _ = struct.unpack(">II", fh.read(8))
      return np.array(array.array("B", fh.read()), dtype=np.uint8)

  def parse_images(fname):
    with gzip.open(fname, "rb") as fh:
      _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
      return np.array(array.array("B", fh.read()),
                      dtype=np.uint8).reshape(num_data, rows, cols)

  for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                   "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
    _download(base_url + filename, filename)

  train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
  train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
  test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
  test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

  return train_images, train_labels, test_images, test_labels


def mnist(permute_key=None):
  """Download, parse and process MNIST data to unit scale and one-hot labels."""
  train_images, train_labels, test_images, test_labels = mnist_raw()

  train_images = _partial_flatten(train_images) / np.float32(255.)
  test_images = _partial_flatten(test_images) / np.float32(255.)
  train_labels = _one_hot(train_labels, 10)
  test_labels = _one_hot(test_labels, 10)

  if permute_key is not None:
    perm = permutation(permute_key, train_images.shape[0])
    train_images = train_images[perm]
    train_labels = train_labels[perm]

  return train_images, train_labels, test_images, test_labels


def dataloader(key, batch_size, data, labels):
  """Yield (data, labels) minibatches forever, reshuffling every epoch."""
  data_len = data.shape[0]
  while True:
    i = 0
    # Split the key so each pass over the data uses a fresh permutation.
    key, subkey = random.split(key)
    order = permutation(subkey, data_len)
    while i < data_len:
      yield data[order[i:i+batch_size]], labels[order[i:i+batch_size]]
      i += batch_size
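

# A minimal usage sketch, not part of the original module: it assumes this file
# is run directly and simply wires mnist() and dataloader() together to draw a
# few shuffled minibatches. The batch size and number of batches printed are
# arbitrary illustrative choices.
if __name__ == "__main__":
  key = random.PRNGKey(0)
  perm_key, loader_key = random.split(key)

  # Flattened, unit-scaled images with one-hot labels; the training set is
  # shuffled with an explicit PRNG key.
  train_images, train_labels, test_images, test_labels = mnist(perm_key)
  print("train:", train_images.shape, train_labels.shape)
  print("test: ", test_images.shape, test_labels.shape)

  # Draw a couple of minibatches from the infinite, reshuffling generator.
  batches = dataloader(loader_key, batch_size=128,
                       data=train_images, labels=train_labels)
  for _ in range(2):
    xb, yb = next(batches)
    print("batch:", xb.shape, yb.shape)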