-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
36 lines (32 loc) · 1.06 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
from sklearn.cross_validation import train_test_split
from sklearn.datasets import fetch_mldata
def load_noisy_mnist(noise_ratio=0.):
    """
    Load MNIST with a fraction of the training labels replaced by noise.

    Each image is cast to float32 and divided by its own maximum pixel
    value. The data is then split into a training part and a
    10000-sample test part; only training labels are corrupted.

    Parameters
    ----------
    noise_ratio : float
        ratio of noisy labels in training, in the range [0, 1]

    Returns
    -------
    x_train : (60000, 784) np.ndarray
        flattened training images
    x_test : (10000, 784) np.ndarray
        flattened test images
    y_train : (60000,) np.ndarray
        training labels (int32); the first ``int(len(y_train) * noise_ratio)``
        entries are replaced by labels drawn uniformly from 0..9, so a
        replaced label can coincide with the true one
    y_test : (10000,) np.ndarray
        test labels (int32), never corrupted

    Raises
    ------
    ValueError
        if ``noise_ratio`` is outside [0, 1]
    """
    # Fail early with a clear message; a ratio above 1 would otherwise
    # surface as an out-of-bounds IndexError when labels are overwritten.
    if not 0. <= noise_ratio <= 1.:
        raise ValueError(
            "noise_ratio must be in [0, 1], got {!r}".format(noise_ratio))

    # NOTE(review): fetch_mldata and sklearn.cross_validation were removed
    # from later scikit-learn releases -- confirm the pinned version before
    # upgrading (fetch_openml / sklearn.model_selection are the successors).
    mnist = fetch_mldata("MNIST original")

    x = np.float32(mnist.data)
    # Per-image scaling. An all-zero image has maximum 0, so substitute 1
    # there to keep it at zero instead of turning it into NaN.
    row_max = np.max(x, axis=1, keepdims=True)
    x /= np.where(row_max > 0, row_max, 1)

    y = mnist.target
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=10000)

    # Corrupt a prefix of the training labels. The prefix length is derived
    # from the actual training-set size rather than a hard-coded constant.
    # NOTE(review): this relies on train_test_split shuffling the samples so
    # that the prefix is already a random subset -- confirm for the
    # scikit-learn version in use.
    n_noisy = int(len(y_train) * noise_ratio)
    y_train[:n_noisy] = np.random.randint(0, 10, n_noisy)

    y_train = np.int32(y_train)
    y_test = np.int32(y_test)
    return x_train, x_test, y_train, y_test