diff --git a/chemfusekit/svm.py b/chemfusekit/svm.py index 6df6ae8..7027177 100644 --- a/chemfusekit/svm.py +++ b/chemfusekit/svm.py @@ -3,14 +3,11 @@ import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns - -from sklearn.model_selection import train_test_split -from sklearn.metrics import confusion_matrix, classification_report from sklearn.svm import SVC from chemfusekit.lldf import LLDFModel +from chemfusekit.__utils import run_split_test, print_confusion_matrix + class SVMSettings: '''Holds the settings for the SVM object.''' @@ -25,6 +22,7 @@ def __init__(self, kernel: str = 'linear', output: bool = False, test_split: boo self.output = output self.test_split = test_split + class SVM: '''Class for Support Vector Machine analysis of the data''' def __init__(self, fused_data: LLDFModel, settings: SVMSettings): @@ -35,22 +33,9 @@ def __init__(self, fused_data: LLDFModel, settings: SVMSettings): def svm(self): '''Performs Support Vector Machine analysis''' - x_data = self.fused_data.x_data - x_train = self.fused_data.x_train - y = self.fused_data.y - - x_train, x_test, y_train, y_test = train_test_split( - x_data, - y, - train_size=0.7, - shuffle=True, - stratify=y - ) - # Linear kernel if self.settings.kernel == 'linear': svm_model = SVC(kernel='linear', probability=True) - # svmlinear.predict_proba(X_train) # Polynomial kernel elif self.settings.kernel == 'poly': svm_model = SVC(kernel='poly', degree=8) @@ -61,40 +46,25 @@ def svm(self): elif self.settings.kernel == 'sigmoid': svm_model = SVC(kernel='sigmoid') else: - raise ValueError(f"SVM: this type of kernel does not exist ({self.settings.type=})") + raise ValueError(f"SVM: this type of kernel does not exist ({self.settings.kernel=})") - svm_model.fit(x_train, y_train) + svm_model.fit(self.fused_data.x_data, self.fused_data.y) self.model = svm_model if self.settings.output: - y_pred = svm_model.predict(x_test) - - # Assuming 'y_true' and 'y_pred' are your true and predicted labels - cm = confusion_matrix(y_test, y_pred) - - # Get unique class labels from y_true - class_labels = sorted(set(y_test)) - - # Plot the confusion matrix using seaborn with custom colormap (Blues) - sns.heatmap( - cm, - annot=True, - fmt='d', - cmap='Blues', - xticklabels=class_labels, - yticklabels=class_labels, - cbar=False, - vmin=0, - vmax=cm.max() + predictions = svm_model.predict(self.fused_data.x_data) + print_confusion_matrix( + self.fused_data.y, + predictions, + "Confusion matrix based on the whole data set" ) - plt.xlabel('Predicted') - plt.ylabel('True') - plt.title('Confusion Matrix based on evaluation set') - plt.show() - - # Print the classification report - print(classification_report(y_test, y_pred, digits=2)) + if self.settings.output and self.settings.test_split: + run_split_test( + self.fused_data.x_data, + self.fused_data.y, + SVC(kernel=self.settings.kernel) + ) def predict(self, x_data: pd.DataFrame): '''Performs SVM prediction once the model is trained'''