Skip to content

Commit

Permalink
refactor: moved prediction into base class
Browse files Browse the repository at this point in the history
  • Loading branch information
f-aguzzi committed Jun 4, 2024
1 parent a0eb4e6 commit 57a3497
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 44 deletions.
9 changes: 9 additions & 0 deletions chemfusekit/__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,12 @@ def export_model(self, export_path: str):
else:
raise RuntimeError("You haven't trained the model yet! You cannot export it now.")

def predict(self, x_data: pd.DataFrame):
'''Performs prediction once the model is trained.'''
if x_data is None:
raise TypeError(f"X data for {self.__class__.__name__} prediction must be non-empty.")
if self.model is None:
raise RuntimeError(f"The {self.__class__.__name__} model is not trained yet!")

y_pred = self.model.predict(x_data)
return y_pred
10 changes: 0 additions & 10 deletions chemfusekit/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,3 @@ def knn(self):
algorithm=self.settings.algorithm
)
run_split_test(self.data.x_data, self.data.y, knn_split)

def predict(self, x_data: pd.DataFrame):
'''Performs kNN prediction once the model is trained.'''
if x_data is None:
raise TypeError("X data for kNN prediction must be non-empty.")
if self.model is None:
raise RuntimeError("The kNN model is not trained yet!")

y_pred = self.model.predict(x_data)
return y_pred
12 changes: 1 addition & 11 deletions chemfusekit/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,8 @@ def lda(self):
# Run split tests if required by the user
if self.settings.test_split:
run_split_test(
(scores.drop('Substance', axis=1).values),
scores.drop('Substance', axis=1).values,
self.y,
LD(n_components=self.settings.components),
mode=self.settings.output
)

def predict(self, x_data: pd.DataFrame):
'''Performs LDA prediction once the model is trained.'''
if x_data is None:
raise TypeError("X data for LDA prediction must be non-empty.")
if self.model is None:
raise RuntimeError("The LDA model is not trained yet!")

y_pred = self.model.predict(x_data)
return y_pred
5 changes: 1 addition & 4 deletions chemfusekit/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,7 @@ def lr(self):

def predict(self, x_sample: pd.DataFrame):
'''Performs LR prediction once the model is trained.'''
if self.model is None:
raise RuntimeError("The LR model is not trained yet!")

prediction = self.model.predict(x_sample)
prediction = super().predict(x_sample)
probabilities = self.model.predict_proba(x_sample)

classes = self.model.classes_.reshape((self.model.classes_.shape[0], ))
Expand Down
10 changes: 0 additions & 10 deletions chemfusekit/plsda.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,3 @@ def plsda(self):
x = self.data.x_data
y = self.data.x_train.Substance.astype('category').cat.codes
run_split_test(x, y, PLSR(self.settings.n_components), mode=self.settings.output)

def predict(self, x_data: pd.DataFrame):
'''Performs PLSDA prediction once the model is trained.'''
if x_data is None:
raise TypeError("X data for PLSDA prediction must be non-empty.")
if self.model is None:
raise RuntimeError("The PLSDA model is not trained yet!")

y_pred = self.model.predict(x_data)
return y_pred
9 changes: 0 additions & 9 deletions chemfusekit/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,3 @@ def svm(self):
model=SVC(kernel=self.settings.kernel),
mode=self.settings.output
)

def predict(self, x_data: pd.DataFrame):
'''Performs SVM prediction once the model is trained'''
if self.model is None:
raise RuntimeError("The model hasn't been trained yet!")
if x_data is None:
raise TypeError("X data for prediction cannot be empty.")

return self.model.predict(x_data)
1 change: 1 addition & 0 deletions docs/cookbook/structure.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ classDiagram
__init__(settings, data)
import_model(import_path: str)
export_model(export_path: str)
predict(x_data: pd.DataFrame)
}
class KNN {
Expand Down
1 change: 1 addition & 0 deletions docs/cookbook_versioned_docs/version-2.0.0/structure.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ classDiagram
__init__(settings, data)
import_model(import_path: str)
export_model(export_path: str)
predict(x_data: pd.DataFrame)
}
class KNN {
Expand Down
4 changes: 4 additions & 0 deletions docs/docs/base/baseclassifier.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ BaseClassifier(settings: BaseSettings, data: BaseDataModel)
- `export_model(export_path: str)`: exports a model to file
- *raises*:
- `RuntimeError("You haven't trained the model yet! You cannot export it now.")` when trying to export an untrained model
- `predict(x_data: pd.DataFrame)`: performs prediction through the `model`
- *raises*:
- `TypeError("X data for prediction must be non-empty.")` on empty `x_data`
- `RuntimeError("The model is not trained yet!")` when run with an untrained `model`
4 changes: 4 additions & 0 deletions docs/versioned_docs/version-2.0.0/base/baseclassifier.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ BaseClassifier(settings: BaseSettings, data: BaseDataModel)
- `export_model(export_path: str)`: exports a model to file
- *raises*:
- `RuntimeError("You haven't trained the model yet! You cannot export it now.")` when trying to export an untrained model
- `predict(x_data: pd.DataFrame)`: performs prediction through the `model`
- *raises*:
- `TypeError("X data for prediction must be non-empty.")` on empty `x_data`
- `RuntimeError("The model is not trained yet!")` when run with an untrained `model`

0 comments on commit 57a3497

Please sign in to comment.