This repository has been archived by the owner on May 27, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
py3testweights.py
86 lines (82 loc) · 3.07 KB
/
py3testweights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/python
import numpy as np
import parsematlab
import loaddata
import testweights
from iwafgui import Error, Info, SaveAs
def _parse_error(label, value):
    """Format a field label and its raw text for display in an error dialog."""
    return label + '\n ' + value.replace('\n', '\n ')


def _parse_test_args(test_args):
    """Parse the 'matrixshape' and 'repetitions' test fields.

    Each entry of `test_args` is unpacked as a (label, text) pair. A string
    returned by `parsematlab.parse` is treated as a parse failure, matching
    how this module treats string results from its other helpers.

    Returns:
        (matrixshape, repetitions, errors) where `errors` is a list with one
        message per field that failed to parse. Scalar parse results are
        wrapped in a one-element list; other results are returned as-is.
    """
    errors = []

    label, value = test_args['matrixshape']
    # 'x' is turned into a separator, presumably so a shape typed as "6x6"
    # parses the same as "6 6".
    matrixshape = parsematlab.parse(value.lower().replace('x', ' '))
    if isinstance(matrixshape, str):
        errors.append(_parse_error(label, value))
    elif np.isscalar(matrixshape):
        matrixshape = [matrixshape]

    label, value = test_args['repetitions']
    repetitions = parsematlab.parse(value)
    if isinstance(repetitions, str):
        errors.append(_parse_error(label, value))
    elif np.isscalar(repetitions):
        # A single repetition count must be iterable too: the report loop
        # pairs each entry with its accuracy, so wrap it like matrixshape.
        repetitions = [repetitions]

    return matrixshape, repetitions, errors


def _load_test_data(fnames, classifier, removeanomalies):
    """Load every selected data file and concatenate the results.

    `loaddata.load_data` is called once per file. Its result is indexed as
    (data, type, samplingrate), and a string result is an error message.

    Returns:
        (data, types) as concatenated numpy arrays on success, or an error
        message string on failure (load error, mismatched sampling rates,
        no files selected, or mismatched channel counts).
    """
    data = []
    types = []
    samplingrate = None
    for fname in fnames:
        result = loaddata.load_data(fname, [0, classifier.shape[0]],
                                    None, True,
                                    removeanomalies=removeanomalies)
        if isinstance(result, str):
            return result
        if samplingrate is None:
            samplingrate = result[2]
        if samplingrate != result[2]:
            return 'Not all data files have the same sampling rate.'
        data.append(result[0])
        types.append(result[1])
    if not data:
        return 'You must select some data upon which to test the weights.'
    try:
        data = np.concatenate(data)
    except ValueError:
        # np.concatenate rejects arrays whose non-concatenated axes differ.
        return 'Not all data files have the same number of channels.'
    return data, np.concatenate(types)


def _format_report(fnames, weightfile, matrixshape, repetitions,
                   correctness, score):
    """Build the accuracy report text shown to the user.

    Lists the data files and the weight file, then one accuracy line per
    repetition count (`correctness` entries are fractions, shown as
    percentages), then the two values in `score` as target and nontarget
    STDEV.
    """
    parts = ['\n'.join(fnames)]
    parts.append('\n\n%s\n\nExpected accuracy for a %s matrix:\n\n' % (
        weightfile, 'x'.join(str(i) for i in matrixshape)))
    for reps, accuracy in zip(repetitions, correctness):
        if reps != 1:
            parts.append('%i repetitions: %0.1f%%\n' % (reps, accuracy * 100))
        else:
            parts.append('1 repetition: %0.1f%%\n' % (accuracy * 100))
    # tuple() lets score be any length-2 sequence, not only a tuple.
    parts.append('\nTarget STDEV: %f\nNontarget STDEV: %f\n' % tuple(score))
    return ''.join(parts)


def testWeights(name, values):
    """Evaluate a saved classifier on the selected data files and report it.

    Reads the selected file list, the weight file, the 'matrixshape' and
    'repetitions' test arguments, and the 'removeanomalies' generation
    option from `values`. It then loads the weights and data, runs
    `testweights.test_weights`, and shows the resulting accuracy via `Info`.
    Expected failures (missing weight file, unparsable arguments, load
    errors, inconsistent data, running out of memory) are shown via `Error`
    and the function returns early.

    Args:
        name: Not used here. Presumably it is required by the GUI callback
            signature (TODO confirm against iwafgui).
        values: Form values keyed by field name. Entries appear to be pairs
            whose second element holds the field's value.

    Returns:
        None.
    """
    _, fnames = values['flist']
    weightfile = values['weightfile'][1]
    if not weightfile:
        Error('You must first generate weights or select a file from which '
              'to load the weights.')
        return

    matrixshape, repetitions, errors = _parse_test_args(
        values['test-args'][1])
    if errors:
        Error('\n\n'.join(errors))
        return

    classifier = loaddata.load_weights(weightfile)
    if isinstance(classifier, str):
        Error(classifier)
        return

    removeanomalies = values['generation-args'][1]['removeanomalies'][1]
    # Loading, testing and reporting can all allocate large arrays, so
    # guard the whole phase against running out of memory.
    try:
        loaded = _load_test_data(fnames, classifier, removeanomalies)
        if isinstance(loaded, str):
            Error(loaded)
            return
        data, types = loaded

        result = testweights.test_weights(data, types, classifier,
                                          matrixshape, repetitions)
        if isinstance(result, str):
            Error(result)
            return
        score, correctness = result

        Info(_format_report(fnames, weightfile, matrixshape, repetitions,
                            correctness, score))
    except MemoryError:
        Error('Could not fit all the selected data in memory.\n'
              'Try loading fewer data files.')
        return