-
Notifications
You must be signed in to change notification settings - Fork 1
/
readMNIST.m
90 lines (77 loc) · 2.45 KB
/
readMNIST.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
% readMNIST by Siddharth Hegde
%
% Description:
% Read digits and labels from raw MNIST data files
% File format as specified on http://yann.lecun.com/exdb/mnist/
% Note: Removal of the 4 pixel padding around the digits (trimDigits) is
%       currently disabled, so images keep their full size.
% Pixel values are normalised to the [0...1] range
%
% Usage:
% [imgs labels] = readMNIST(imgFile, labelFile, readDigits, offset)
%
% Parameters:
% imgFile = name of the image file
% labelFile = name of the label file
% readDigits = number of digits to be read
% offset = skips the first offset number of digits before reading starts
%
% Returns:
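% imgs = h x w x readDigits sized matrix of digits, with h and w taken from
%        the image file header (28 x 28 for the standard MNIST files, since
%        the padding is not removed)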
% labels = readDigits x 1 matrix containing labels for each digit
%
function [imgs labels] = readMNIST(imgFile, labelFile, readDigits, offset)
% Read readDigits digits and their labels from raw MNIST IDX files,
% skipping the first offset entries of each file.
%
% readDigits and offset may be omitted or passed as []. offset then
% defaults to 0, and readDigits defaults to every digit remaining after
% offset.
%
% Returns imgs (h x w x readDigits, doubles scaled to [0, 1]) and labels
% (readDigits x 1 column of doubles).
%
% Both files are always closed, even when an error is raised. An error is
% raised when a file cannot be opened, when its header/magic number is
% invalid, when it holds fewer than readDigits+offset entries, or when it
% ends before the requested data.
    if nargin < 4 || isempty(offset)
        offset = 0;
    end
    if nargin < 3
        readDigits = [];
    end
    [imgs, readDigits] = readIdxImages(imgFile, readDigits, offset);
    labels = readIdxLabels(labelFile, readDigits, offset);
    % Optional: strip the 4 pixel padding around each digit (disabled).
    %imgs = trimDigits(imgs, 4);
    imgs = normalizePixValue(imgs);
    %[avg num stddev] = getDigitStats(imgs, labels);
end

function [fid, count] = openIdx(fileName, magic, headerMsg)
% Open a big-endian IDX file, check its magic number and return its item
% count. The file is closed again before erroring on an invalid header.
    [fid, msg] = fopen(fileName, 'r', 'b');
    if fid < 0
        error('Cannot open file "%s": %s', fileName, msg);
    end
    header = fread(fid, 2, 'int32');
    % An empty/truncated file yields fewer than 2 values; reject it instead
    % of letting an empty comparison silently pass.
    if numel(header) < 2 || header(1) ~= magic
        fclose(fid);
        error(headerMsg);
    end
    count = header(2);
end

function [imgs, readDigits] = readIdxImages(imgFile, readDigits, offset)
% Read readDigits images (all remaining ones if empty) after skipping offset
% images. Returns an h x w x readDigits double array of raw pixel values.
    [fid, count] = openIdx(imgFile, 2051, 'Invalid image file header');
    closeFile = onCleanup(@() fclose(fid)); %#ok<NASGU> closes fid on any exit
    if isempty(readDigits)
        readDigits = max(count - offset, 0);
    end
    if count < readDigits + offset
        error('Trying to read too many digits');
    end
    dims = fread(fid, 2, 'int32');
    if numel(dims) < 2
        error('Invalid image file header');
    end
    h = dims(1);
    w = dims(2);
    if offset > 0 && fseek(fid, w*h*offset, 'cof') ~= 0
        error('Image file is too short to skip %d digits', offset);
    end
    nPix = h * w * readDigits;
    [raw, nRead] = fread(fid, nPix, 'uint8=>double');
    if nRead < nPix
        error('Image file ended before %d digits could be read', readDigits);
    end
    % IDX stores each image row by row while MATLAB fills arrays column by
    % column, so read as w x h x n and swap the first two dimensions.
    imgs = permute(reshape(raw, [w h readDigits]), [2 1 3]);
end

function labels = readIdxLabels(labelFile, readDigits, offset)
% Read readDigits labels after skipping offset labels. Returns a
% readDigits x 1 column of doubles.
    [fid, count] = openIdx(labelFile, 2049, 'Invalid label file header');
    closeFile = onCleanup(@() fclose(fid)); %#ok<NASGU> closes fid on any exit
    if count < readDigits + offset
        error('Trying to read too many digits');
    end
    if offset > 0 && fseek(fid, offset, 'cof') ~= 0
        error('Label file is too short to skip %d labels', offset);
    end
    [labels, nRead] = fread(fid, readDigits, 'uint8=>double');
    if nRead < readDigits
        error('Label file ended before %d labels could be read', readDigits);
    end
end
function digits = trimDigits(digitsIn, border)
% trimDigits  Remove a border of `border` pixels from every edge of each image.
%
%   digits = trimDigits(digitsIn, border)
%
%   digitsIn - h x w (single image) or h x w x n array of images
%   border   - number of rows/columns to drop from each of the four edges
%   digits   - (h-2*border) x (w-2*border) [x n] array of class double
%
% Slicing all pages at once also works for a single 2-D image. A per-page
% loop driven by size(digitsIn)(3) fails for 2-D input, because that
% dimension does not exist in the size vector.
    digits = double(digitsIn(border+1:end-border, border+1:end-border, :));
end
function digits = normalizePixValue(digits)
% normalizePixValue  Convert pixel values to double and divide them by 255.
%
%   digits = normalizePixValue(digits)
%
% Works on arrays of any size and numeric class. Values in [0, 255] (8-bit
% pixels) map to [0, 1]. A single element-wise division replaces the old
% per-page loop, which only covered size(digits,3) pages and so missed data
% in arrays with more than three dimensions.
    digits = double(digits) ./ 255.0;
end