-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdatasource.cpp
119 lines (100 loc) · 2.56 KB
/
datasource.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include "datasource.h"
#include <QDebug>
// Builds a data source producing vectors of `outputDimension` values and
// one-hot label arrays of `classes` entries. The data/label file names are
// not set here; they are assigned later via setDataFilename()/setLabelFilename().
DataSource::DataSource(int outputDimension, int classes, QObject *parent) : QObject(parent)
{
    // File handles owned by this object; released in the destructor.
    dataFile = new QFile();
    labelFile = new QFile();

    classesAmount = classes;
    vectorDimension = outputDimension;
    // Hard-coded upper bound on the example index accepted by getData()/getLabel().
    trainingExamplesAmount = 2000;

    // The trailing "()" value-initializes every element, i.e. fills with zeros.
    label_arr = new int[classesAmount]();
    data = new double[vectorDimension]();
}
// Releases every heap resource allocated by the constructor.
DataSource::~DataSource()
{
    // Output buffers handed out by getData()/getLabelAsArray() become invalid here.
    delete[] label_arr;
    delete[] data;

    // File handles (not parented to this QObject, so they must be deleted explicitly).
    delete labelFile;
    delete dataFile;
}
// Points the data file handle at `filename`.
// Returns true if the file exists; otherwise logs a message and returns false.
// The name is assigned to the handle in either case.
bool DataSource::setDataFilename(QString filename)
{
    dataFile->setFileName(filename);
    const bool present = dataFile->exists();
    if (!present)
        qDebug() << "datasource.cpp:setDataFilename(): Data file does not exists";
    return present;
}
// Points the label file handle at `filename`.
// Returns true if the file exists; otherwise logs a message and returns false.
// The name is assigned to the handle in either case.
bool DataSource::setLabelFilename(QString filename)
{
    labelFile->setFileName(filename);
    const bool present = labelFile->exists();
    if (!present)
        qDebug() << "datasource.cpp:setLabelFilename(): Label file does not exists";
    return present;
}
double* DataSource::getData(int param) const
{
if(param > trainingExamplesAmount) return NULL;
if(!dataFile->open(QIODevice::ReadOnly)) {
qDebug() << "Can't open file";
return NULL;
}
char bytes[vectorDimension];
// 4 * 4 - 32bit integer
dataFile->read(16 + param * sizeof(bytes));
dataFile->read(bytes, sizeof(bytes));
dataFile->close();
for (int i = 0; i < vectorDimension; ++i) {
data[i] = (unsigned char)bytes[i];
}
return data;
}
double DataSource::getLabel(int param)
{
if(param > trainingExamplesAmount)
qFatal("datasource.cpp:getLabel(): param is out of trainingExamplesAmount");
if(!labelFile->open(QIODevice::ReadOnly)) {
qDebug() << "Can't open file";
return 0;
}
//4 * 2 - 32bit integer
labelFile->read(8 + param * 1);
char a = 255;
labelFile->read(&a, 1);
label = (double)a;
labelFile->close();
return label;
}
// Returns the label of example `y` as a one-hot array of `classesAmount` ints.
// The entry whose index equals getLabel(y) is 1; all others are 0.
// The returned buffer is owned by this object and is overwritten on each call.
int* DataSource::getLabelAsArray(int y)
{
    const double target = getLabel(y);
    for (int cls = 0; cls < classesAmount; ++cls)
        label_arr[cls] = (cls == target) ? 1 : 0;
    return label_arr;
}
int DataSource::getTrainingExampleAmount() const
{
return trainingExamplesAmount;
}
int DataSource::getVectorDimension() const
{
return vectorDimension;
}
int DataSource::getClassesAmount() const
{
return classesAmount;
}