diff --git a/bico/core.py b/bico/core.py index 5184544..66ec518 100644 --- a/bico/core.py +++ b/bico/core.py @@ -64,8 +64,14 @@ def __init__( @property def labels_(self) -> np.ndarray: - if not hasattr(self, "_labels"): + if not hasattr(self, "_cluster_centers"): raise NotFittedError(self._CORESET_ESTIMATOR_ERROR) + elif not hasattr(self, "_labels"): + raise ValueError( + "The labels have not been computed because the coreset " + "was fit using partial_fit. " + "Please call predict on your data to obtain the labels." + ) return self._labels @property @@ -114,6 +120,7 @@ def partial_fit( def _fit_coreset( self, + X: Optional[np.ndarray] = None, ) -> None: if self.coreset_estimator is None: from sklearn.cluster import KMeans @@ -127,10 +134,13 @@ def _fit_coreset( self._coreset_points, sample_weight=self._coreset_weights ) self._cluster_centers: np.ndarray = self.coreset_estimator.cluster_centers_ - self._labels: np.ndarray = self.coreset_estimator.labels_ + if X is not None: + self._labels: np.ndarray = self.coreset_estimator.predict(X) self._inertia: float = self.coreset_estimator.inertia_ - def _compute_coreset(self, fit_coreset: bool = False) -> "BICO": + def _compute_coreset( + self, X: Optional[np.ndarray] = None, fit_coreset: bool = False + ) -> "BICO": if not hasattr(self, "bico_obj_"): raise NotFittedError( "This BICO instance is not fitted yet. " "Call `fit` or `partial_fit`." @@ -152,7 +162,7 @@ def _compute_coreset(self, fit_coreset: bool = False) -> "BICO": self._n_features_out = n_found_points if self.fit_coreset or fit_coreset: - self._fit_coreset() + self._fit_coreset(X) return self @@ -188,7 +198,9 @@ def _fit( _DLL.addData(self.bico_obj_, c_array, c_n) if not partial or fit_coreset: - self._compute_coreset(fit_coreset) + self._compute_coreset( + X=_X if not partial else None, fit_coreset=fit_coreset + ) return self @@ -204,9 +216,9 @@ def fit_predict( return self.labels_ def predict(self, X: Sequence[Sequence[float]]) -> Any: - self._fit_coreset() - if self.coreset_estimator is None: raise NotFittedError(self._CORESET_ESTIMATOR_ERROR) + self._fit_coreset() + return self.coreset_estimator.predict(X)