Skip to content

Commit

Permalink
Merge branch 'master' into bug/913-randint-size
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito committed Feb 14, 2022
2 parents e9af142 + 66da5c2 commit 2b617d4
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions examples/classification/demo_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def create_fold(dataset_x, dataset_y, size, seed=None):
data_indices = ht.array(indices[0:size], split=0)
verification_indices = ht.array(indices[size:], split=0)

fold_x = ht.array(dataset_x[data_indices], is_split=0)
fold_y = ht.array(dataset_y[data_indices], is_split=0)
verification_y = ht.array(dataset_y[verification_indices], is_split=0)
verification_x = ht.array(dataset_x[verification_indices], is_split=0)
fold_x = dataset_x[data_indices]
fold_y = dataset_y[data_indices]
verification_y = dataset_y[verification_indices]
verification_x = dataset_x[verification_indices]

# Balance arrays
fold_x.balance_()
Expand Down Expand Up @@ -138,10 +138,11 @@ def verify_algorithm(x, y, split_number, split_size, k, seed=None):

for split_index in range(split_number):
fold_x, fold_y, verification_x, verification_y = create_fold(x, y, split_size, seed)
classifier = KNeighborsClassifier(fold_x, fold_y, k)
classifier = KNeighborsClassifier(k)
classifier.fit(fold_x, fold_y)
result_y = classifier.predict(verification_x)
accuracies.append(calculate_accuracy(result_y, verification_y).item())
return accuracies


print(verify_algorithm(X, Y, 1, 30, 5, 1))
print("Accuracy: {}".format(verify_algorithm(X, Y, 1, 30, 5, 1)))

0 comments on commit 2b617d4

Please sign in to comment.