-
Notifications
You must be signed in to change notification settings - Fork 1
/
SVM.py
93 lines (73 loc) · 2.88 KB
/
SVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# SVM.py
# Implements a support vector machine to determine an artist based on an artwork.
import numpy as np
import os
import matplotlib.pyplot as plt
# scikit-learn and scikit-image imports
from sklearn import svm
from skimage.feature import hog
from skimage.util import montage
from sklearn.metrics import plot_confusion_matrix
# custom library
from load_data import load_data
def main():
    '''
    Train an SVM on HOG features of artwork images and report its performance.

    Loads the train/test split through load_data(), extracts HOG descriptors
    with getHOG(), fits an SVC, prints training/testing accuracy together with
    the number of correct/incorrect predictions, and finally displays a
    normalized confusion matrix for the test split.

    VARIABLE KEY:
    Training Data: xtrain
    Training Labels: ttrain
    Testing Data: xtest
    Testing Labels: ttest
    HOG train: x
    HOG test: xt
    Encoded train/test labels: t / tt
    Class names (index == encoded label): labels
    Misclassified masks: mislabeled_(train/test)
    '''
    # set seed for repeatability
    np.random.seed(0)

    print("Loading training and testing data...")
    (xtrain, ttrain), (xtest, ttest) = load_data()

    # reshape training and testing data for HOG
    # (assumes every sample holds exactly 224 * 224 values -- TODO confirm
    # against load_data)
    xtrain = xtrain.reshape(xtrain.shape[0], 224, 224)
    xtest = xtest.reshape(xtest.shape[0], 224, 224)

    # get HOG features of training and testing data
    print("Extracting HOG features...")
    x = getHOG(xtrain)
    xt = getHOG(xtest)

    print("Setting up labels...")
    # Encode both splits against ONE shared class list. Encoding each split
    # with its own np.unique call assigns different integers to the same
    # class whenever the two splits do not contain exactly the same classes,
    # which silently corrupts the test accuracy and the confusion matrix.
    ttrain_flat = np.ravel(ttrain)
    ttest_flat = np.ravel(ttest)
    labels, encoded = np.unique(
        np.concatenate((ttrain_flat, ttest_flat)), return_inverse=True)
    encoded = np.ravel(encoded)
    n_train_labels = ttrain_flat.shape[0]
    t = encoded[:n_train_labels]
    tt = encoded[n_train_labels:]

    print("Training SVM...")
    clf = svm.SVC(C=1.0)
    clf.fit(x, t)

    # performance metrics
    print("\nComputing performance metrics...")
    # Predict once per split and derive every metric from those predictions;
    # the mean of the "correct" mask is the plain accuracy.
    mislabeled_train = clf.predict(x) != t
    mislabeled_test = clf.predict(xt) != tt
    print("Training Accuracy: ", float(np.mean(~mislabeled_train)))
    print("Testing Accuracy:", float(np.mean(~mislabeled_test)))

    n_total = xtrain.shape[0] + xtest.shape[0]
    n_wrong = int(np.count_nonzero(mislabeled_train)
                  + np.count_nonzero(mislabeled_test))
    print("Number of Incorrectly Predicted:", n_wrong, "/",
          n_total, "images")
    print("Number of Correctly Predicted:", n_total - n_wrong, " / ",
          n_total, "images")

    print("\nCreating confusion matrix... ")
    plt.rc('font', size=6)
    plt.rc('figure', titlesize=10)
    _, ax = plt.subplots(figsize=(8, 6))
    plt.subplots_adjust(bottom=0.2, top=0.9, right=0.9, left=0.1)
    ax.set_title("SVM Confusion Matrix")
    # Passing labels= explicitly keeps the matrix rows/columns aligned with
    # display_labels even when a class is absent from the test split.
    # NOTE(review): plot_confusion_matrix was deprecated in scikit-learn 1.0
    # and removed in 1.2; ConfusionMatrixDisplay.from_estimator is the
    # replacement when the file-level import gets updated.
    plot_confusion_matrix(clf, xt, tt,
                          labels=np.arange(labels.shape[0]),
                          normalize='all',
                          display_labels=labels,
                          xticks_rotation='vertical',
                          cmap=plt.cm.Blues,
                          ax=ax)
    plt.show()
def getHOG(data, orientations=45, pixels_per_cell=(16, 16),
           cells_per_block=(14, 14), transform_sqrt=True):
    """Compute the HOG features of a dataset.

    Args:
        data: Iterable of images; each element is passed unchanged to
            skimage's hog() (presumably 2-D grayscale arrays -- verify
            against the caller).
        orientations: Number of orientation bins handed to hog().
        pixels_per_cell: Cell size in pixels handed to hog().
        cells_per_block: Block size in cells handed to hog().
        transform_sqrt: Whether hog() applies square-root compression
            before computing gradients.

    Returns:
        np.ndarray built from the per-image hog() results, in input order.
        The defaults reproduce the values this function used when they were
        hard-coded, so getHOG(data) behaves as before.
    """
    # NOTE(review): for 224x224 inputs, 16x16-pixel cells give a 14x14 cell
    # grid, so the default 14x14 block covers the whole image (one block).
    return np.asarray([
        hog(image,
            orientations=orientations,
            pixels_per_cell=pixels_per_cell,
            cells_per_block=cells_per_block,
            transform_sqrt=transform_sqrt)
        for image in data
    ])
# Run the full train/evaluate pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()