Skip to content

Commit

Permalink
feat(LDA): autodetect components from PCADataModel
Browse files Browse the repository at this point in the history
  • Loading branch information
f-aguzzi committed Jun 4, 2024
1 parent c919596 commit a59cd54
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
26 changes: 14 additions & 12 deletions chemfusekit/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from chemfusekit.__utils import graph_output, run_split_test
from chemfusekit.__utils import print_confusion_matrix, print_table, GraphMode
from .__base import BaseDataModel, BaseClassifier, BaseSettings
from .pca import PCADataModel


class LDASettings(BaseSettings):
Expand All @@ -23,19 +24,20 @@ def __init__(self, components: int = 3, output: GraphMode = GraphMode.NONE, test

class LDA(BaseClassifier):
'''Class to store the data, methods and artifacts for Linear Discriminant Analysis'''
def __init__(self, settings: LDASettings, data_model: BaseDataModel):
super().__init__(settings, data_model)
def __init__(self, settings: LDASettings, data: BaseDataModel):
super().__init__(settings, data)
self.settings = settings
self.x_data = data_model.x_data
self.x_train = data_model.x_train
self.y = data_model.y
self.data = data
# Self-detect components if the data is from PCA
if isinstance(data, PCADataModel):
self.settings.components = data.components - 1

def lda(self):
'''Performs Linear Discriminant Analysis'''

lda = LD(n_components=self.settings.components) # N-1 where N are the classes
scores_lda = lda.fit(self.x_data, self.y).transform(self.x_data)
pred = lda.predict(self.x_data)
scores_lda = lda.fit(self.data.x_data, self.y).transform(self.data.x_data)
pred = lda.predict(self.data.x_data)

print_table(
[f"LV{i+1}" for i in range(scores_lda.shape[1])],
Expand Down Expand Up @@ -72,7 +74,7 @@ def lda(self):
self.settings.output
)

pred = lda.predict(self.x_data)
pred = lda.predict(self.data.x_data)
print_confusion_matrix(
y1=self.y,
y2=pred,
Expand All @@ -81,13 +83,13 @@ def lda(self):
)

lv_cols = [f'LV{i+1}' for i in range(self.settings.components)]
scores = pd.DataFrame(data = scores_lda, columns = lv_cols) # latent variables
scores.index = self.x_data.index
scores = pd.DataFrame(data=scores_lda, columns=lv_cols) # latent variables
scores.index = self.data.x_data.index
y_dataframe = pd.DataFrame(self.y, columns=['Substance'])

scores = pd.concat([scores, y_dataframe], axis = 1)

# Store the traiend model
# Store the trained model
self.model = lda

# Show graphs if required by the user
Expand All @@ -102,7 +104,7 @@ def lda(self):
if self.settings.test_split:
run_split_test(
scores.drop('Substance', axis=1).values,
self.y,
self.data.y,
LD(n_components=self.settings.components),
mode=self.settings.output
)
7 changes: 5 additions & 2 deletions chemfusekit/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

class PCADataModel(BaseDataModel):
'''Data model for the PCA outputs.'''
def __init__(self, x_data: pd.DataFrame, x_train: pd.DataFrame, y: np.ndarray, array_scores: np.ndarray):
def __init__(self, x_data: pd.DataFrame, x_train: pd.DataFrame, y: np.ndarray, array_scores: np.ndarray,
components: int):
super().__init__(x_data, x_train, y)
self.array_scores = array_scores
self.components = components


class PCASettings:
Expand Down Expand Up @@ -276,5 +278,6 @@ def export_data(self) -> PCADataModel:
self.data.x_data,
self.data.x_train,
self.data.y,
self.array_scores
self.array_scores,
self.components
)
7 changes: 5 additions & 2 deletions docs/docs/pca/pcadatamodel.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ It inherits from the [`BaseDataModel`](../base/basedatamodel.md).
## Syntax

```python
PCAModel(x_data: pd.DataFrame, x_train: pd.DataFrame, y: np.ndarray, array_scores: np.ndarray)
PCADataModel(x_data: pd.DataFrame, x_train: pd.DataFrame, y: np.ndarray, array_scores: np.ndarray, components: int)
```

## Fields and constructor parameters
Expand All @@ -20,6 +20,9 @@ The first two are `Pandas` `DataFrame` objects:
- `x_data`
- `x_train`

The last two are `NumPy` `ndarray`s:
The next two are `NumPy` `ndarray`s:
- `y`
- `array_scores`

The last is an integer:
- `components`
7 changes: 5 additions & 2 deletions docs/versioned_docs/version-2.0.0/pca/pcadatamodel.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ It inherits from the [`BaseDataModel`](../base/basedatamodel.md).
## Syntax

```python
PCAModel(x_data: pd.DataFrame, x_train: pd.DataFrame, y: np.ndarray, array_scores: np.ndarray)
PCADataModel(x_data: pd.DataFrame, x_train: pd.DataFrame, y: np.ndarray, array_scores: np.ndarray, components: int)
```

## Fields and constructor parameters
Expand All @@ -20,6 +20,9 @@ The first two are `Pandas` `DataFrame` objects:
- `x_data`
- `x_train`

The last two are `NumPy` `ndarray`s:
The next two are `NumPy` `ndarray`s:
- `y`
- `array_scores`

The last is an integer:
- `components`
12 changes: 7 additions & 5 deletions examples/pca_lda_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@
"\n",
"# Print the number of components and the statistics\n",
"print(f\"\\nNumber of components: {pca.components}\\n\")\n",
"pca.pca_stats()"
"pca.pca_stats()\n",
"\n",
"# Export data from PCA\n",
"pca_data = pca.export_data()"
]
},
{
Expand All @@ -134,13 +137,12 @@
"from chemfusekit.lda import LDASettings, LDA, GraphMode\n",
"\n",
"settings = LDASettings(\n",
" components=(pca.components - 1), # one less component than the number determined by PCA\n",
" output=GraphMode.GRAPHIC, # graphs will be printed as pictures\n",
" test_split=True # Run split test\n",
" output=GraphMode.GRAPHIC, # Graphs will be printed\n",
" test_split=True # Run split test\n",
")\n",
"\n",
"# Initialize and run the LDA class\n",
"lda = LDA(settings, lldf.fused_data)\n",
"lda = LDA(settings, pca_data) # components will be determined automatically from the PCA data\n",
"lda.lda()"
]
},
Expand Down

0 comments on commit a59cd54

Please sign in to comment.