Skip to content

Commit

Permalink
WIP for sklearn monotonicity
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Jul 5, 2023
1 parent 66aa004 commit 7072c31
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 18 deletions.
242 changes: 239 additions & 3 deletions sktree/tree/_classes.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion sktree/tree/_oblique_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ cdef class BaseObliqueSplitter(Splitter):
self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil


Expand Down
8 changes: 6 additions & 2 deletions sktree/tree/_oblique_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ cdef class BaseObliqueSplitter(Splitter):
self,
double impurity,
SplitRecord* split,
SIZE_t* n_constant_features
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil:
"""Find the best split on node samples[start:end]
Expand Down Expand Up @@ -456,7 +458,9 @@ cdef class BestObliqueSplitter(ObliqueSplitter):
self,
double impurity,
SplitRecord* split,
SIZE_t* n_constant_features
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil:
"""Find the best split on node samples[start:end]
Expand Down
12 changes: 8 additions & 4 deletions sktree/tree/unsupervised/_unsup_oblique_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@ cdef class UnsupervisedObliqueSplitter(UnsupervisedSplitter):
cdef int node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) except -1 nogil

cdef int node_split(self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features) except -1 nogil
cdef int node_split(
self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil
cdef int init(
self,
const DTYPE_t[:, :] X,
Expand Down
10 changes: 8 additions & 2 deletions sktree/tree/unsupervised/_unsup_oblique_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,14 @@ cdef class BestObliqueUnsupervisedSplitter(UnsupervisedObliqueSplitter):
proj_mat_indices[proj_i].push_back(feat_i) # Store index of nonzero
proj_mat_weights[proj_i].push_back(weight) # Store weight of nonzero

cdef int node_split(self, double impurity, SplitRecord* split,
SIZE_t* n_constant_features) except -1 nogil:
cdef int node_split(
self,
double impurity,
SplitRecord* split,
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil:
"""Find the best split on node samples[start:end]
Returns -1 in case of failure to allocate memory (and raises MemoryError)
Expand Down
4 changes: 3 additions & 1 deletion sktree/tree/unsupervised/_unsup_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ cdef class UnsupervisedSplitter(BaseSplitter):
self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil
cdef void node_value(self, double* dest) noexcept nogil
cdef double node_impurity(self) noexcept nogil
16 changes: 12 additions & 4 deletions sktree/tree/unsupervised/_unsup_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,15 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) noexcept nogil
cdef class UnsupervisedSplitter(BaseSplitter):
"""Base class for unsupervised splitters."""

def __cinit__(self, UnsupervisedCriterion criterion, SIZE_t max_features,
SIZE_t min_samples_leaf, double min_weight_leaf,
object random_state, *argv):
def __cinit__(
self,
UnsupervisedCriterion criterion,
SIZE_t max_features,
SIZE_t min_samples_leaf,
double min_weight_leaf,
object random_state,
*argv
):
"""
Parameters
----------
Expand Down Expand Up @@ -174,7 +180,9 @@ cdef class BestUnsupervisedSplitter(UnsupervisedSplitter):
self,
double impurity,
SplitRecord* split,
SIZE_t* n_constant_features
SIZE_t* n_constant_features,
double lower_bound,
double upper_bound
) except -1 nogil:
"""Find the best split on node samples[start:end].
Expand Down
15 changes: 14 additions & 1 deletion sktree/tree/unsupervised/_unsup_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ cdef class UnsupervisedBestFirstTreeBuilder(UnsupervisedTreeBuilder):
bint is_left,
Node* parent,
SIZE_t depth,
double lower_bound,
double upper_bound,
FrontierRecord* res
) except -1 nogil:
"""Adds node w/ partition ``[start, end)`` to the frontier. """
Expand Down Expand Up @@ -325,7 +327,13 @@ cdef class UnsupervisedBestFirstTreeBuilder(UnsupervisedTreeBuilder):
)

if not is_leaf:
splitter.node_split(impurity, split_ptr, &n_constant_features)
splitter.node_split(
impurity,
split_ptr,
&n_constant_features,
lower_bound,
upper_bound
)

# assign local copy of SplitRecord to assign
# pos, improvement, and impurity scores
Expand All @@ -348,12 +356,17 @@ cdef class UnsupervisedBestFirstTreeBuilder(UnsupervisedTreeBuilder):

# compute values also for split nodes (might become leaves later).
splitter.node_value(tree.value + node_id * tree.value_stride)
if splitter.with_monotonic_cst:
splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound)

res.node_id = node_id
res.start = start
res.end = end
res.depth = depth
res.impurity = impurity
res.lower_bound = lower_bound
res.upper_bound = upper_bound
res.middle_value = splitter.criterion.middle_value()

if not is_leaf:
# is split node
Expand Down

0 comments on commit 7072c31

Please sign in to comment.