-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
84 lines (68 loc) · 2.77 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from classifiers.mlp import MajorMlpClassifier
from embeddings.bert import BertSentenceEmbedder
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from classifiers.bert import BertClassifier
import pandas as pd
import numpy as np
from typing import Tuple
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report
from helper import load_data, get_recommendations, plot_confusion_matrix
import matplotlib.pyplot as plt
import os
# Device name passed to every model constructor in this script. Defaults to
# "mps" (PyTorch's Apple Silicon GPU backend); set MAJOR_CLASSIFIER_DEVICE
# (presumably any torch device string, e.g. "cuda" or "cpu") to run elsewhere.
device = os.environ.get("MAJOR_CLASSIFIER_DEVICE", "mps")
def evaluate(load_weights=False, *, seed=2, num_majors=40, top_n=3):
    """
    Performs basic train/test split evaluation.

    Loads the dataset, holds out 20% of it as a test set, then fits (or loads
    weights for) a BERT classifier, an MLP and a KNN classifier. For each one
    it prints the top-``top_n`` accuracy and a scikit-learn classification
    report, and saves a confusion-matrix plot to ``figures/<name>_cm.png``.

    Args:
        load_weights: If True, load saved weights for the BERT and MLP
            classifiers instead of training them. KNN has no saved weights and
            is always fit on the training split.
        seed: ``random_state`` used for the train/test split.
        num_majors: Forwarded to ``load_data(num_majors=...)``.
        top_n: A prediction counts as a hit for the top-N accuracy when the
            true label is among the ``top_n`` most probable classes.
    """
    os.makedirs("figures", exist_ok=True)
    sentences, labels = load_data(num_majors=num_majors)
    embedder = BertSentenceEmbedder(device, padding_length=1000)
    x_train, x_test, y_train, y_test = train_test_split(
        sentences, labels, random_state=seed, shuffle=True, train_size=0.8
    )
    # KNN and the MLP work on precomputed sentence embeddings; the BERT
    # classifier consumes the raw text directly.
    train_embeddings = embedder.transform(x_train)
    test_embeddings = embedder.transform(x_test)

    knn = KNeighborsClassifier()
    mlp = MajorMlpClassifier(device)
    bert_classifier = BertClassifier(
        device=device,
        epochs=25,
    )
    if load_weights:
        # NOTE(review): if these weights were trained on data overlapping this
        # test split, the reported scores are optimistic — confirm.
        mlp.load_weights(os.path.join("weights", "major_classifier"))
        bert_classifier.load_weights(
            os.path.join("weights", "bert_classifier_deployment_weights")
        )
    else:
        bert_classifier.fit(x_train, y_train)
        mlp.fit(train_embeddings, y_train)
    # KNN has no persisted weights, so it is fit in both modes.
    knn.fit(train_embeddings, y_train)

    # Label order of the BERT classifier. Also used as the axis order of every
    # confusion-matrix plot, and assumed (TODO confirm) to be the column order
    # of MajorMlpClassifier.predict_proba.
    class_labels = np.array(bert_classifier.labels)

    def report(name, classifier, x, y, n=top_n, proba_labels=None):
        """Print top-n accuracy and a classification report; save a confusion matrix.

        ``proba_labels`` gives the label for each column of
        ``classifier.predict_proba(x)``; it defaults to the BERT label order.
        """
        if proba_labels is None:
            proba_labels = class_labels
        probs = classifier.predict_proba(x)
        # Labels of the n most probable classes per row, most probable first.
        ordered_choices = proba_labels[(-probs).argsort(-1)[:, :n]]
        preds = ordered_choices[:, 0]
        hits = [label in choices for label, choices in zip(y, ordered_choices)]
        print(name)
        print(f"Top {n} accuracy", np.mean(hits))
        print(classification_report(y, preds))
        plot_confusion_matrix(y, preds, class_labels)
        plt.savefig(os.path.join("figures", f"{name}_cm.png"))
        plt.clf()

    report("bert_classifier", bert_classifier, x_test, y_test)
    # scikit-learn orders predict_proba columns by `classes_`, which need not
    # match the BERT classifier's label order.
    report("KNN", knn, test_embeddings, y_test, proba_labels=knn.classes_)
    report("major_mlp", mlp, test_embeddings, y_test)
def demo(*, top_n=3):
    """
    Interact with a model on the command line.

    Loads the deployment weights for the BERT classifier, then repeatedly
    prompts for a free-text description and prints the recommendations
    returned by ``get_recommendations`` for it. Enter ``q`` or ``quit``, or
    send EOF (Ctrl-D), to exit. Blank input is ignored.

    Args:
        top_n: Number of recommendations requested from ``get_recommendations``.
    """
    # Use the module-level device so demo() and evaluate() run on the same one.
    bert_classifier = BertClassifier(device=device)
    weights_path = os.path.join("weights", "bert_classifier_deployment_weights")
    bert_classifier.load_weights(weights_path)
    labels = bert_classifier.labels  # does not change inside the loop
    while True:
        try:
            command = input("Describe your ideal major: ")
        except EOFError:
            # stdin closed (Ctrl-D or exhausted piped input): exit cleanly
            # instead of crashing with a traceback.
            print()
            break
        text = command.strip()
        if text.lower() in ("q", "quit"):
            break
        if not text:
            # Nothing to classify; prompt again.
            continue
        probs = bert_classifier.predict_proba(text)
        print(get_recommendations(probs, labels, n=top_n)[0])
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate the major classifiers or try the BERT model interactively."
    )
    parser.add_argument(
        "mode",
        nargs="?",
        choices=("evaluate", "demo"),
        default="evaluate",
        help="'evaluate' (default) runs the train/test evaluation; "
        "'demo' starts the interactive prompt.",
    )
    parser.add_argument(
        "--load-weights",
        action="store_true",
        help="In evaluate mode, load saved weights instead of training.",
    )
    args = parser.parse_args()
    # With no arguments this runs evaluate() exactly as before.
    if args.mode == "demo":
        demo()
    else:
        evaluate(load_weights=args.load_weights)