diff --git a/sktree/tree/_classes.py b/sktree/tree/_classes.py index 0849c28e3..18699a60c 100644 --- a/sktree/tree/_classes.py +++ b/sktree/tree/_classes.py @@ -3,7 +3,7 @@ import numpy as np from scipy.sparse import issparse -from sklearn.base import ClusterMixin, TransformerMixin +from sklearn.base import ClusterMixin, TransformerMixin, is_classifier from sklearn.cluster import AgglomerativeClustering from sklearn.utils._param_validation import Interval from sklearn.utils.validation import check_is_fitted @@ -234,6 +234,43 @@ def _build_tree( max_depth, random_state, ): + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. + monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported for unsupervised yet." + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + criterion = self.criterion if not isinstance(criterion, UnsupervisedCriterion): criterion = UNSUPERVISED_CRITERIA[self.criterion]() @@ -250,6 +287,7 @@ def _build_tree( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, ) self.tree_ = UnsupervisedTree(self.n_features_in_) @@ -503,8 +541,44 @@ def _build_tree( max_depth, random_state, ): - # TODO: add feature_combinations fix that was used in obliquedecisiontreeclassifier + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. + monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported for unsupervised yet." + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + # TODO: add feature_combinations fix that was used in obliquedecisiontreeclassifier criterion = self.criterion if not isinstance(criterion, UnsupervisedCriterion): criterion = UNSUPERVISED_CRITERIA[self.criterion]() @@ -521,6 +595,7 @@ def _build_tree( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, self.feature_combinations, ) @@ -882,6 +957,45 @@ def _build_tree( else: self.feature_combinations_ = self.feature_combinations + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. + monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + # Build tree criterion = self.criterion if not isinstance(criterion, BaseCriterion): @@ -907,6 +1021,7 @@ def _build_tree( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, self.feature_combinations_, ) @@ -1239,6 +1354,45 @@ def _build_tree( else: self.feature_combinations_ = self.feature_combinations + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. + monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + # Build tree criterion = self.criterion if not isinstance(criterion, BaseCriterion): @@ -1264,6 +1418,7 @@ def _build_tree( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, self.feature_combinations_, ) @@ -1718,6 +1873,45 @@ def _build_tree( random_state : int, RandomState instance or None, default=None Controls the randomness of the estimator. """ + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. + monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + # Build tree criterion = self.criterion if not isinstance(criterion, BaseCriterion): @@ -1743,6 +1937,8 @@ def _build_tree( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, + None, self.min_patch_dims_, self.max_patch_dims_, self.dim_contiguous_, @@ -2193,9 +2389,47 @@ def _build_tree( random_state : int, RandomState instance or None, default=None Controls the randomness of the estimator. """ - n_samples = X.shape[0] + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. + monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + # Build tree criterion = self.criterion if not isinstance(criterion, BaseCriterion): @@ -2221,6 +2455,8 @@ def _build_tree( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, + None, self.min_patch_dims_, self.max_patch_dims_, self.dim_contiguous_, diff --git a/sktree/tree/_oblique_splitter.pxd b/sktree/tree/_oblique_splitter.pxd index 34b1f26f9..85eb10940 100644 --- a/sktree/tree/_oblique_splitter.pxd +++ b/sktree/tree/_oblique_splitter.pxd @@ -83,7 +83,9 @@ cdef class BaseObliqueSplitter(Splitter): self, double impurity, # Impurity of the node SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound ) except -1 nogil diff --git a/sktree/tree/_oblique_splitter.pyx b/sktree/tree/_oblique_splitter.pyx index 1a2c43be8..330d97bc7 100644 --- a/sktree/tree/_oblique_splitter.pyx +++ b/sktree/tree/_oblique_splitter.pyx @@ -169,7 +169,9 @@ cdef class BaseObliqueSplitter(Splitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound ) except -1 nogil: """Find the best_split split on node samples[start:end] @@ -456,7 +458,9 @@ cdef class BestObliqueSplitter(ObliqueSplitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound ) except -1 nogil: """Find the best_split split on node samples[start:end] diff --git a/sktree/tree/unsupervised/_unsup_oblique_splitter.pxd b/sktree/tree/unsupervised/_unsup_oblique_splitter.pxd index 88f765462..141ec79fd 100644 --- a/sktree/tree/unsupervised/_unsup_oblique_splitter.pxd +++ b/sktree/tree/unsupervised/_unsup_oblique_splitter.pxd @@ -55,10 +55,14 @@ cdef class UnsupervisedObliqueSplitter(UnsupervisedSplitter): cdef int node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples) except -1 nogil - cdef int node_split(self, - double impurity, # Impurity of the node - SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil + cdef int node_split( + self, + double impurity, # Impurity of the node + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) except -1 nogil cdef int init( self, const DTYPE_t[:, :] X, diff --git a/sktree/tree/unsupervised/_unsup_oblique_splitter.pyx b/sktree/tree/unsupervised/_unsup_oblique_splitter.pyx index 351e4d080..8b157280b 100644 --- a/sktree/tree/unsupervised/_unsup_oblique_splitter.pyx +++ b/sktree/tree/unsupervised/_unsup_oblique_splitter.pyx @@ -201,8 +201,14 @@ cdef class BestObliqueUnsupervisedSplitter(UnsupervisedObliqueSplitter): proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) except -1 nogil: """Find the best_split split on node samples[start:end] Returns -1 in case of failure to allocate memory (and raise MemoryError) diff --git a/sktree/tree/unsupervised/_unsup_splitter.pxd b/sktree/tree/unsupervised/_unsup_splitter.pxd index da9cf080b..848848be4 100644 --- a/sktree/tree/unsupervised/_unsup_splitter.pxd +++ b/sktree/tree/unsupervised/_unsup_splitter.pxd @@ -40,7 +40,9 @@ cdef class UnsupervisedSplitter(BaseSplitter): self, double impurity, # Impurity of the node SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound ) except -1 nogil cdef void node_value(self, double* dest) noexcept nogil cdef double node_impurity(self) noexcept nogil diff --git a/sktree/tree/unsupervised/_unsup_splitter.pyx b/sktree/tree/unsupervised/_unsup_splitter.pyx index 854d35143..fac866d98 100644 --- a/sktree/tree/unsupervised/_unsup_splitter.pyx +++ b/sktree/tree/unsupervised/_unsup_splitter.pyx @@ -34,9 +34,15 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) noexcept nogil cdef class UnsupervisedSplitter(BaseSplitter): """Base class for unsupervised splitters.""" - def __cinit__(self, UnsupervisedCriterion criterion, SIZE_t max_features, - SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, *argv): + def __cinit__( + self, + UnsupervisedCriterion criterion, + SIZE_t max_features, + SIZE_t min_samples_leaf, + double min_weight_leaf, + object random_state, + *argv + ): """ Parameters ---------- @@ -174,7 +180,9 @@ cdef class BestUnsupervisedSplitter(UnsupervisedSplitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound ) except -1 nogil: """Find the best_split split on node samples[start:end]. diff --git a/sktree/tree/unsupervised/_unsup_tree.pyx b/sktree/tree/unsupervised/_unsup_tree.pyx index 8cb804b18..9a997e435 100644 --- a/sktree/tree/unsupervised/_unsup_tree.pyx +++ b/sktree/tree/unsupervised/_unsup_tree.pyx @@ -295,6 +295,8 @@ cdef class UnsupervisedBestFirstTreeBuilder(UnsupervisedTreeBuilder): bint is_left, Node* parent, SIZE_t depth, + double lower_bound, + double upper_bound, FrontierRecord* res ) except -1 nogil: """Adds node w/ partition ``[start, end)`` to the frontier. """ @@ -325,7 +327,13 @@ cdef class UnsupervisedBestFirstTreeBuilder(UnsupervisedTreeBuilder): ) if not is_leaf: - splitter.node_split(impurity, split_ptr, &n_constant_features) + splitter.node_split( + impurity, + split_ptr, + &n_constant_features, + lower_bound, + upper_bound + ) # assign local copy of SplitRecord to assign # pos, improvement, and impurity scores @@ -348,12 +356,17 @@ cdef class UnsupervisedBestFirstTreeBuilder(UnsupervisedTreeBuilder): # compute values also for split nodes (might become leafs later). splitter.node_value(tree.value + node_id * tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound) res.node_id = node_id res.start = start res.end = end res.depth = depth res.impurity = impurity + res.lower_bound = lower_bound + res.upper_bound = upper_bound + res.middle_value = splitter.criterion.middle_value() if not is_leaf: # is split node