From 7f254876fcdbd2ba5c46278d31eb851a50659e8f Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Thu, 23 May 2024 13:20:28 +0200 Subject: [PATCH] fix: wrong training procedure in SVM (#25) * fix: wrong training procedure in SVM * ci(release): 1.1.0-beta.3 [skip ci] ## [1.1.0-beta.3](https://github.com/f-aguzzi/tesi/compare/v1.1.0-beta.2...v1.1.0-beta.3) (2024-05-23) ### Bug Fixes * wrong training procedure in SVM ([74d1741](https://github.com/f-aguzzi/tesi/commit/74d1741743f53eea6cc2d9002005c6426bf4f0d0)) * ci(release): 1.1.1-beta.1 [skip ci] ## [1.1.1-beta.1](https://github.com/f-aguzzi/tesi/compare/v1.1.0...v1.1.1-beta.1) (2024-05-23) ### Bug Fixes * wrong training procedure in SVM ([74d1741](https://github.com/f-aguzzi/tesi/commit/74d1741743f53eea6cc2d9002005c6426bf4f0d0)) ### CI * **release:** 1.1.0-beta.3 [skip ci] ([96988ac](https://github.com/f-aguzzi/tesi/commit/96988acbfd0f015c03385f74f14ea26e98a9b4b2)) --------- Co-authored-by: semantic-release-bot --- CHANGELOG.md | 18 +++++++++++++- chemfusekit/svm.py | 62 ++++++++++++---------------------------------- pyproject.toml | 2 +- 3 files changed, 34 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b9523ce..d183aef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,16 @@ -## [1.1.0](https://github.com/f-aguzzi/tesi/compare/v1.0.0...v1.1.0) (2024-05-20) +## [1.1.1-beta.1](https://github.com/f-aguzzi/tesi/compare/v1.1.0...v1.1.1-beta.1) (2024-05-23) + + +### Bug Fixes + +* wrong training procedure in SVM ([74d1741](https://github.com/f-aguzzi/tesi/commit/74d1741743f53eea6cc2d9002005c6426bf4f0d0)) + + +### CI + +* **release:** 1.1.0-beta.3 [skip ci] ([96988ac](https://github.com/f-aguzzi/tesi/commit/96988acbfd0f015c03385f74f14ea26e98a9b4b2)) + +## [1.1.0-beta.3](https://github.com/f-aguzzi/tesi/compare/v1.1.0-beta.2...v1.1.0-beta.3) (2024-05-23) ### Features @@ -8,6 +20,9 @@ ### Bug Fixes + +* wrong training procedure in SVM ([74d1741](https://github.com/f-aguzzi/tesi/commit/74d1741743f53eea6cc2d9002005c6426bf4f0d0)) + * add missing checks to LLDFSettings ([3d57752](https://github.com/f-aguzzi/tesi/commit/3d577527eefd0b183a66c378d53cb1f1ee506343)), closes [#20](https://github.com/f-aguzzi/tesi/issues/20) @@ -26,6 +41,7 @@ * **release:** 1.1.0-beta.1 [skip ci] ([2deffcc](https://github.com/f-aguzzi/tesi/commit/2deffcc4c8a29d09a4a644558e491d770f71f6dc)), closes [#7](https://github.com/f-aguzzi/tesi/issues/7) * **release:** 1.1.0-beta.2 [skip ci] ([25704dc](https://github.com/f-aguzzi/tesi/commit/25704dc4eb0cbc249bbf85611c5dfe257ebccc30)), closes [#20](https://github.com/f-aguzzi/tesi/issues/20) [#22](https://github.com/f-aguzzi/tesi/issues/22) [#23](https://github.com/f-aguzzi/tesi/issues/23) + ## [1.1.0-beta.2](https://github.com/f-aguzzi/tesi/compare/v1.1.0-beta.1...v1.1.0-beta.2) (2024-05-20) diff --git a/chemfusekit/svm.py b/chemfusekit/svm.py index 6df6ae8..7027177 100644 --- a/chemfusekit/svm.py +++ b/chemfusekit/svm.py @@ -3,14 +3,11 @@ import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns - -from sklearn.model_selection import train_test_split -from sklearn.metrics import confusion_matrix, classification_report from sklearn.svm import SVC from chemfusekit.lldf import LLDFModel +from chemfusekit.__utils import run_split_test, print_confusion_matrix + class SVMSettings: '''Holds the settings for the SVM object.''' @@ -25,6 +22,7 @@ def __init__(self, kernel: str = 'linear', output: bool = False, test_split: boo self.output = output self.test_split = test_split + class SVM: '''Class for Support Vector Machine analysis of the data''' def __init__(self, fused_data: LLDFModel, settings: SVMSettings): @@ -35,22 +33,9 @@ def __init__(self, fused_data: LLDFModel, settings: SVMSettings): def svm(self): '''Performs Support Vector Machine analysis''' - x_data = self.fused_data.x_data - x_train = self.fused_data.x_train - y = self.fused_data.y - - x_train, x_test, y_train, y_test = train_test_split( - x_data, - y, - train_size=0.7, - shuffle=True, - stratify=y - ) - # Linear kernel if self.settings.kernel == 'linear': svm_model = SVC(kernel='linear', probability=True) - # svmlinear.predict_proba(X_train) # Polynomial kernel elif self.settings.kernel == 'poly': svm_model = SVC(kernel='poly', degree=8) @@ -61,40 +46,25 @@ def svm(self): elif self.settings.kernel == 'sigmoid': svm_model = SVC(kernel='sigmoid') else: - raise ValueError(f"SVM: this type of kernel does not exist ({self.settings.type=})") + raise ValueError(f"SVM: this type of kernel does not exist ({self.settings.kernel=})") - svm_model.fit(x_train, y_train) + svm_model.fit(self.fused_data.x_data, self.fused_data.y) self.model = svm_model if self.settings.output: - y_pred = svm_model.predict(x_test) - - # Assuming 'y_true' and 'y_pred' are your true and predicted labels - cm = confusion_matrix(y_test, y_pred) - - # Get unique class labels from y_true - class_labels = sorted(set(y_test)) - - # Plot the confusion matrix using seaborn with custom colormap (Blues) - sns.heatmap( - cm, - annot=True, - fmt='d', - cmap='Blues', - xticklabels=class_labels, - yticklabels=class_labels, - cbar=False, - vmin=0, - vmax=cm.max() + predictions = svm_model.predict(self.fused_data.x_data) + print_confusion_matrix( + self.fused_data.y, + predictions, + "Confusion matrix based on the whole data set" ) - plt.xlabel('Predicted') - plt.ylabel('True') - plt.title('Confusion Matrix based on evaluation set') - plt.show() - - # Print the classification report - print(classification_report(y_test, y_pred, digits=2)) + if self.settings.output and self.settings.test_split: + run_split_test( + self.fused_data.x_data, + self.fused_data.y, + SVC(kernel=self.settings.kernel) + ) def predict(self, x_data: pd.DataFrame): '''Performs SVM prediction once the model is trained''' diff --git a/pyproject.toml b/pyproject.toml index 6fd9e60..c0ed7c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "chemfusekit" -version = "1.1.0" +version = "1.1.1b1" description = "A minimal Python / Jupyter Notebook / Colab library for data fusion and chemometrical analysis." authors = [ { name = "Federico Aguzzi", email = "62149513+f-aguzzi@users.noreply.github.com" }