Skip to content

Commit

Permalink
Compute labels for fit (#5)
Browse files Browse the repository at this point in the history
* Remove useless imports

* Remove the randomness class

* Forgot to use seed

* Fix it so that the returned labels actually refer to the original data and not to the coreset

* Fix mypy errors
  • Loading branch information
giuliabaldini authored Jul 4, 2024
1 parent 6f225dc commit 208f7d8
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions bico/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,14 @@ def __init__(

@property
def labels_(self) -> np.ndarray:
if not hasattr(self, "_labels"):
if not hasattr(self, "_cluster_centers"):
raise NotFittedError(self._CORESET_ESTIMATOR_ERROR)
elif not hasattr(self, "_labels"):
raise ValueError(
"The labels have not been computed because the coreset "
"was fit using partial_fit. "
"Please call predict on your data to obtain the labels."
)
return self._labels

@property
Expand Down Expand Up @@ -114,6 +120,7 @@ def partial_fit(

def _fit_coreset(
self,
X: Optional[np.ndarray] = None,
) -> None:
if self.coreset_estimator is None:
from sklearn.cluster import KMeans
Expand All @@ -127,10 +134,13 @@ def _fit_coreset(
self._coreset_points, sample_weight=self._coreset_weights
)
self._cluster_centers: np.ndarray = self.coreset_estimator.cluster_centers_
self._labels: np.ndarray = self.coreset_estimator.labels_
if X is not None:
self._labels: np.ndarray = self.coreset_estimator.predict(X)
self._inertia: float = self.coreset_estimator.inertia_

def _compute_coreset(self, fit_coreset: bool = False) -> "BICO":
def _compute_coreset(
self, X: Optional[np.ndarray] = None, fit_coreset: bool = False
) -> "BICO":
if not hasattr(self, "bico_obj_"):
raise NotFittedError(
"This BICO instance is not fitted yet. " "Call `fit` or `partial_fit`."
Expand All @@ -152,7 +162,7 @@ def _compute_coreset(self, fit_coreset: bool = False) -> "BICO":
self._n_features_out = n_found_points

if self.fit_coreset or fit_coreset:
self._fit_coreset()
self._fit_coreset(X)

return self

Expand Down Expand Up @@ -188,7 +198,9 @@ def _fit(
_DLL.addData(self.bico_obj_, c_array, c_n)

if not partial or fit_coreset:
self._compute_coreset(fit_coreset)
self._compute_coreset(
X=_X if not partial else None, fit_coreset=fit_coreset
)

return self

Expand All @@ -204,9 +216,9 @@ def fit_predict(
return self.labels_

def predict(self, X: Sequence[Sequence[float]]) -> Any:
self._fit_coreset()

if self.coreset_estimator is None:
raise NotFittedError(self._CORESET_ESTIMATOR_ERROR)

self._fit_coreset()

return self.coreset_estimator.predict(X)

0 comments on commit 208f7d8

Please sign in to comment.