-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_knn.py
39 lines (31 loc) · 1.37 KB
/
run_knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from l2_distance import l2_distance
def run_knn(k, train_data, train_labels, valid_data):
    """Predict labels for validation data with k-nearest-neighbours.

    Each validation example receives the label that occurs most often among
    its ``k`` nearest training examples, where "nearest" is measured by the
    distances returned from ``l2_distance``.

    Note: N_TRAIN is the number of training examples,
    N_VALID is the number of validation examples,
    and M is the number of features per example.

    Inputs:
        k:            The number of neighbours to use for classification
                      of a validation example (must be >= 1). Values larger
                      than N_TRAIN effectively use all training examples.
        train_data:   The N_TRAIN x M array of training data.
        train_labels: The N_TRAIN x 1 (or length-N_TRAIN) vector of training
                      labels corresponding to the examples in train_data.
                      Binary 0/1 labels reproduce the classic
                      ``mean >= 0.5`` vote; other discrete label sets are
                      also supported.
        valid_data:   The N_VALID x M array of data to predict classes for.

    Outputs:
        valid_labels: The N_VALID x 1 vector of predicted labels for the
                      validation data, with the same dtype as train_labels.
                      Ties in the vote go to the largest label value
                      (for 0/1 labels a tie predicts 1).

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    dist = l2_distance(valid_data, train_data)
    # Indices of the k closest training examples for every validation row.
    nearest = np.argsort(dist, axis=1)[:, :k]

    train_labels = np.asarray(train_labels).reshape(-1)
    # Map labels onto 0..C-1 so votes can be counted by integer indexing.
    # np.unique returns the classes sorted in ascending order.
    classes, label_idx = np.unique(train_labels, return_inverse=True)
    neighbour_idx = label_idx.reshape(-1)[nearest]  # N_VALID x k class ids

    # votes[i, c] = number of neighbours of validation row i that have class c.
    n_valid, n_neighbours = neighbour_idx.shape
    votes = np.zeros((n_valid, classes.size), dtype=np.intp)
    rows = np.repeat(np.arange(n_valid), n_neighbours)
    np.add.at(votes, (rows, neighbour_idx.ravel()), 1)

    # argmax picks the first maximum; scanning the classes in reverse makes
    # ties resolve to the largest label, matching the binary mean >= 0.5 rule.
    winner = classes.size - 1 - np.argmax(votes[:, ::-1], axis=1)
    valid_labels = classes[winner]
    return valid_labels.reshape(-1, 1)