-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknn.py
59 lines (50 loc) · 1.78 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from pprint import pprint as pp
from pprint import pformat as pf
import pickle
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import os
from util.util import argmax_for_dict
from mnist import load_mnist
def experiment_knn_score(X, y, k_values=range(1, 6), train_size=0.8):
    """Score a k-NN classifier on a hold-out split for each candidate k.

    The data is split once, without shuffling, into a training part
    (``train_size`` fraction of the samples) and a validation part. For
    every ``k`` in ``k_values`` a fresh ``KNeighborsClassifier`` is fitted
    on the training part and scored on the validation part. This function
    always recomputes; it performs no pickling itself.

    Args:
        X: Feature matrix accepted by ``train_test_split`` / scikit-learn.
        y: Labels aligned with ``X``.
        k_values: Candidate ``n_neighbors`` values (default 1..5).
        train_size: Fraction of samples used for training (default 0.8).

    Returns:
        dict mapping each k to the value of ``KNeighborsClassifier.score``
        on the validation part.
    """
    # NOTE(review): shuffle=False keeps the original sample order, so the
    # validation set is simply the tail of the data. If X is ordered (e.g.
    # by class) the validation split may be unrepresentative — confirm.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, train_size=train_size, shuffle=False
    )
    scores = {}
    for k in k_values:
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        scores[k] = knn.score(X_val, y_val)
    return scores
def optimize(X, y):
    """Return the best ``n_neighbors`` value, caching the score table on disk.

    Scores are loaded from ``pickles/scores.pickle`` when that file exists.
    Otherwise they are computed with ``experiment_knn_score`` and written to
    the file, so later calls can reuse them. The original code never wrote
    this cache, so it was never actually used.

    Args:
        X: Feature matrix.
        y: Labels aligned with ``X``.

    Returns:
        The key of ``scores`` selected by ``argmax_for_dict``. Presumably
        that is the k with the highest validation score; confirm against
        ``util.util``.
    """
    fname = "pickles/scores.pickle"
    if os.path.exists(fname):
        print(f"{fname} exists")  # debug
        # pickle.load can execute arbitrary code: only load caches we wrote.
        with open(fname, "rb") as f:
            scores = pickle.load(f)
    else:
        print(f"{fname} doesn't exist")  # debug
        scores = experiment_knn_score(X, y)
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        with open(fname, "wb") as f:
            pickle.dump(scores, f)
    # argmax_for_dict returns a pair; only the key (k) is needed here.
    opt_k, _max_score = argmax_for_dict(scores)
    print("opt_k", opt_k)  # debug
    return opt_k
def learn_knn(X, y, k, is_refresh=False):
    """Fit (or load from cache) a ``KNeighborsClassifier`` with ``k`` neighbors.

    A model pickled at ``pickles/knn.pickle`` is reused, unless
    ``is_refresh`` is true or the cached model's ``n_neighbors`` differs
    from ``k``. Otherwise a new model is fitted on ``(X, y)`` and written
    to that path. The ``pickles/`` directory is created if needed.

    NOTE(review): the cache is not keyed on ``X``/``y``; pass
    ``is_refresh=True`` after changing the training data.

    Args:
        X: Training features.
        y: Training labels.
        k: Number of neighbors.
        is_refresh: Ignore any cached model and retrain.

    Returns:
        The fitted (or loaded) classifier.
    """
    fname = "pickles/knn.pickle"
    knn = None
    if not is_refresh and os.path.exists(fname):
        print(f"{fname} exists")  # debug
        # pickle.load can execute arbitrary code: only load caches we wrote.
        with open(fname, "rb") as f:
            knn = pickle.load(f)
        # A cached model built with a different k would silently give the
        # wrong classifier, so discard it and retrain.
        if getattr(knn, "n_neighbors", None) != k:
            print(f"{fname} was trained with a different k; refitting")  # debug
            knn = None
    elif is_refresh:
        print(f"refreshing {fname}")  # debug
    else:
        print(f"{fname} doesn't exist")  # debug
    if knn is None:
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X, y)
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        with open(fname, "wb") as f:
            pickle.dump(knn, f)
    return knn
def knn_predict(learned_knn, X_test, is_refresh=False):
    """Predict labels for ``X_test``, caching the result on disk.

    Predictions pickled at ``pickles/y_pred.pickle`` are reused, unless
    ``is_refresh`` is true or their length differs from the number of rows
    in ``X_test``. Otherwise ``learned_knn.predict(X_test)`` is called and
    its result is written to that path. The ``pickles/`` directory is
    created if needed.

    NOTE(review): the cache is not keyed on the model or the contents of
    ``X_test``; pass ``is_refresh=True`` when either changes.

    Args:
        learned_knn: Any fitted estimator exposing ``predict``.
        X_test: Samples to classify (array-like with ``shape`` or ``len``).
        is_refresh: Ignore any cached predictions and recompute.

    Returns:
        The predictions, as returned by ``learned_knn.predict`` or the cache.
    """
    fname = "pickles/y_pred.pickle"
    # Sparse matrices have .shape but no len(); plain sequences have len().
    n_rows = X_test.shape[0] if hasattr(X_test, "shape") else len(X_test)
    y_pred = None
    if not is_refresh and os.path.exists(fname):
        print(f"{fname} exists")  # debug
        # pickle.load can execute arbitrary code: only load caches we wrote.
        with open(fname, "rb") as f:
            y_pred = pickle.load(f)
        # Predictions for a different-sized input are certainly stale.
        if len(y_pred) != n_rows:
            print(f"{fname} does not match X_test; recomputing")  # debug
            y_pred = None
    elif is_refresh:
        print(f"refreshing {fname}")  # debug
    else:
        print(f"{fname} doesn't exist")  # debug
    if y_pred is None:
        y_pred = learned_knn.predict(X_test)
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        with open(fname, "wb") as f:
            pickle.dump(y_pred, f)
    return y_pred