From 7e9dc2269e8fb237240a996bc4ed6ef6c9e966ec Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Mon, 9 Sep 2024 12:19:35 -0400
Subject: [PATCH] MAINT Clean up Cython files (#321)

* Clean up Cython files in the oblique and morf splitters
* Migrate `self._validate_data` to `validate_data` in the scikit-learn developer API
* Update spin to v0.12+
* Update C to the C11 standard (the diff changes `c_std=c99` to `c_std=c11`; `cpp_std` stays at c++14)
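For context on the `validate_data` bullet: newer scikit-learn developer APIs
replace the `BaseEstimator._validate_data` method with the free function
`sklearn.utils.validation.validate_data`, which takes the estimator as its
first argument. A minimal before/after sketch (the `dtype` argument here is
illustrative, not taken from this diff):

    # before
    X = self._validate_data(X, dtype=np.float32)

    # after
    from sklearn.utils.validation import validate_data
    X = validate_data(self, X, dtype=np.float32)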
---------

Signed-off-by: Adam Li
---
 .spin/cmds.py                            |  5 ++
 build_requirements.txt                   |  2 +-
 meson.build                              |  2 +-
 pyproject.toml                           |  2 +-
 treeple/__init__.py                      |  1 +
 treeple/_lib/meson.build                 | 19 ++++++
 treeple/_lib/sklearn_fork                |  2 +-
 treeple/ensemble/_honest_forest.py       |  8 ++-
 treeple/ensemble/_unsupervised_forest.py | 12 ++--
 treeple/meson.build                      |  1 +
 treeple/neighbors.py                     | 15 +++--
 treeple/tree/_classes.py                 | 82 ++++++++++++++++++------
 treeple/tree/_neighbors.py               |  4 --
 treeple/tree/_oblique_splitter.pxd       |  6 --
 treeple/tree/_oblique_splitter.pyx       | 42 ++++--------
 treeple/tree/_utils.pxd                  | 38 +++++++++--
 treeple/tree/_utils.pyx                  | 65 ++++++++++++++-----
 treeple/tree/manifold/_morf_splitter.pxd | 10 +--
 treeple/tree/manifold/_morf_splitter.pyx | 14 ++--
 19 files changed, 222 insertions(+), 108 deletions(-)

diff --git a/.spin/cmds.py b/.spin/cmds.py
index 7a80393d0..b5631b0e6 100644
--- a/.spin/cmds.py
+++ b/.spin/cmds.py
@@ -5,6 +5,7 @@ import click
 
 from spin import util
 from spin.cmds import meson
+from spin.cmds.meson import build_dir_option
 
 
 def get_git_revision_hash(submodule) -> str:
@@ -145,14 +146,18 @@ def setup_submodule(forcesubmodule=False):
 @click.option(
     "--forcesubmodule", is_flag=True, help="Force submodule pull.", envvar="FORCE_SUBMODULE"
 )
+@build_dir_option
 @click.pass_context
 def build(
     ctx,
+    *,
     meson_args,
     jobs=None,
     clean=False,
     verbose=False,
     gcov=False,
+    quiet=False,
+    build_dir=None,
     forcesubmodule=False,
 ):
     """Build treeple using submodules.
diff --git a/build_requirements.txt b/build_requirements.txt
index 95bc6c98e..ec63cfb3b 100644
--- a/build_requirements.txt
+++ b/build_requirements.txt
@@ -8,5 +8,5 @@ click
 rich-click
 doit
 pydevtool
-spin
+spin>=0.12
 build
diff --git a/meson.build b/meson.build
index 26f909dea..07ec4c9c2 100644
--- a/meson.build
+++ b/meson.build
@@ -8,7 +8,7 @@ project(
   license: 'PolyForm Noncommercial 1.0.0',
   meson_version: '>= 1.1.0',
   default_options: [
-    'c_std=c99',
+    'c_std=c11',
     'cpp_std=c++14',
   ],
 )
diff --git a/pyproject.toml b/pyproject.toml
index 596d2408b..c0a50d95a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ build = [
     'twine',
     'meson',
     'meson-python',
-    'spin',
+    'spin>=0.12',
     'doit',
     'scikit-learn>=1.5.0',
     'Cython>=3.0.10',
diff --git a/treeple/__init__.py b/treeple/__init__.py
index 2a70afefe..dafad7deb 100644
--- a/treeple/__init__.py
+++ b/treeple/__init__.py
@@ -22,6 +22,7 @@
 # https://github.com/ContinuumIO/anaconda-issues/issues/11294
 os.environ.setdefault("KMP_INIT_AT_FORK", "FALSE")
 
+
 try:
     # This variable is injected in the __builtins__ by the build
     # process. It is used to enable importing subpackages of sklearn when
diff --git a/treeple/_lib/meson.build b/treeple/_lib/meson.build
index 5dd37c868..ae83cf4a5 100644
--- a/treeple/_lib/meson.build
+++ b/treeple/_lib/meson.build
@@ -94,3 +94,22 @@ foreach ext: extensions
     subdir: 'treeple/_lib/sklearn/utils/',
   )
 endforeach
+
+
+# python_sources = [
+#   '__init__.py',
+# ]
+
+# py.install_sources(
+#   python_sources,
+#   subdir: 'treeple/_lib'  # Folder relative to site-packages to install to
+# )
+
+# tempita = files('./sklearn/_build_utils/tempita.py')
+
+# # Copy all the .py files to the install dir, rather than using
+# # py.install_sources and needing to list them explicitly one by one
+# # install_subdir('sklearn', install_dir: py.get_install_dir())
+# install_subdir('sklearn', install_dir: join_paths(py.get_install_dir(), 'treeple/_lib'))
+
+# subdir('sklearn')
diff --git a/treeple/_lib/sklearn_fork b/treeple/_lib/sklearn_fork
index ac5cb8abd..e4b9728cb 160000
--- a/treeple/_lib/sklearn_fork
+++ b/treeple/_lib/sklearn_fork
@@ -1 +1 @@
-Subproject commit ac5cb8abd5c9b425c3c02a2be1d91296adf643a3
+Subproject commit e4b9728cb8667d0a40ed0c6c45f0414811f5f1f8
diff --git a/treeple/ensemble/_honest_forest.py b/treeple/ensemble/_honest_forest.py
index 96c010625..447371b37 100644
--- a/treeple/ensemble/_honest_forest.py
+++ b/treeple/ensemble/_honest_forest.py
@@ -720,8 +720,12 @@ def oob_samples_(self):
             oob_samples.append(_oob_samples)
         return oob_samples
 
-    def _more_tags(self):
-        return {"multioutput": False}
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in HRF
+        tags = super().__sklearn_tags__()
+        tags.classifier_tags.multi_output = False
+        tags.input_tags.allow_nan = False
+        return tags
 
     def decision_path(self, X):
         """
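Note: the `_more_tags` -> `__sklearn_tags__` conversions in this patch follow
the scikit-learn 1.6-style developer API, which replaces tag dicts with a
structured dataclass. A minimal sketch of the pattern (the estimator here is
hypothetical, and the exact tag attributes available depend on the vendored
scikit-learn fork):

    from sklearn.base import BaseEstimator, ClassifierMixin

    class MyClassifier(ClassifierMixin, BaseEstimator):
        # old style: return a dict of tag overrides
        # def _more_tags(self):
        #     return {"allow_nan": False}

        # new style: mutate the Tags dataclass returned by super()
        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.input_tags.allow_nan = False
            return tags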
diff --git a/treeple/ensemble/_unsupervised_forest.py b/treeple/ensemble/_unsupervised_forest.py
index a66c330af..980c1ebbd 100644
--- a/treeple/ensemble/_unsupervised_forest.py
+++ b/treeple/ensemble/_unsupervised_forest.py
@@ -21,7 +21,12 @@
 )
 from sklearn.metrics import calinski_harabasz_score
 from sklearn.utils.parallel import Parallel, delayed
-from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_random_state
+from sklearn.utils.validation import (
+    _check_sample_weight,
+    check_is_fitted,
+    check_random_state,
+    validate_data,
+)
 
 from .._lib.sklearn.ensemble._forest import BaseForest
 from .._lib.sklearn.tree._tree import DTYPE
@@ -85,10 +90,9 @@ def fit(self, X, y=None, sample_weight=None):
         self : object
             Returns the instance itself.
         """
-        self._validate_params()
-
         # Validate or convert input data
-        X = self._validate_data(
+        X = validate_data(
+            self,
             X,
             dtype=DTYPE,
             # accept_sparse="csc",
         )
diff --git a/treeple/meson.build b/treeple/meson.build
index 3d1715dbe..4801d0536 100644
--- a/treeple/meson.build
+++ b/treeple/meson.build
@@ -103,6 +103,7 @@ scikit_learn_cython_args = [
   '-X language_level=3', '-X boundscheck=' + boundscheck, '-X wraparound=False',
   '-X initializedcheck=False', '-X nonecheck=False', '-X cdivision=True',
   '-X profile=False',
+  '-X embedsignature=True',
   # Needed for cython imports across subpackages, e.g. cluster pyx that
   # cimports metrics pxd
   '--include-dir', meson.global_build_root(),
diff --git a/treeple/neighbors.py b/treeple/neighbors.py
index 473b4363f..b16e732f9 100644
--- a/treeple/neighbors.py
+++ b/treeple/neighbors.py
@@ -5,8 +5,9 @@
 from sklearn.base import BaseEstimator, MetaEstimatorMixin
 from sklearn.exceptions import NotFittedError
 from sklearn.neighbors import NearestNeighbors
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import check_is_fitted, validate_data
 
+from treeple.tree import DecisionTreeClassifier
 from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix
 
 
@@ -31,13 +32,19 @@ class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin):
         The number of parallel jobs to run for neighbors, by default None.
     """
 
-    def __init__(self, estimator, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None):
+    def __init__(self, estimator=None, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None):
         self.estimator = estimator
         self.n_neighbors = n_neighbors
         self.algorithm = algorithm
         self.radius = radius
         self.n_jobs = n_jobs
 
+    def get_estimator(self):
+        if self.estimator is None:
+            return DecisionTreeClassifier(random_state=0)
+        else:
+            return copy(self.estimator)
+
     def fit(self, X, y=None):
         """Fit the nearest neighbors estimator from the training dataset.
@@ -56,9 +63,9 @@ def fit(self, X, y=None):
         self : object
             Fitted estimator.
         """
-        X, y = self._validate_data(X, y, accept_sparse="csc")
+        X, y = validate_data(self, X, y, accept_sparse="csc")
 
-        self.estimator_ = copy(self.estimator)
+        self.estimator_ = self.get_estimator()
         try:
             check_is_fitted(self.estimator_)
         except NotFittedError:
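Note: with `estimator=None` now the default, `fit` falls back to a fresh
`DecisionTreeClassifier(random_state=0)` when no estimator is supplied. A
usage sketch, assuming synthetic data:

    import numpy as np
    from treeple.neighbors import NearestNeighborsMetaEstimator

    X = np.random.rand(100, 4)
    y = np.random.randint(0, 2, size=100)
    nn = NearestNeighborsMetaEstimator(n_neighbors=3)  # no estimator supplied
    nn.fit(X, y)  # internally fits the default DecisionTreeClassifier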
diff --git a/treeple/tree/_classes.py b/treeple/tree/_classes.py
index 16eb6ea52..aa93d4c08 100644
--- a/treeple/tree/_classes.py
+++ b/treeple/tree/_classes.py
@@ -8,7 +8,7 @@
 from sklearn.cluster import AgglomerativeClustering
 from sklearn.utils import check_random_state
 from sklearn.utils._param_validation import Interval
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import check_is_fitted, validate_data
 
 from .._lib.sklearn.tree import (
     BaseDecisionTree,
@@ -216,7 +216,7 @@ def fit(self, X, y=None, sample_weight=None, check_input=True):
         if check_input:
             # TODO: allow X to be sparse
             check_X_params = dict(dtype=DTYPE)  # , accept_sparse="csc"
-            X = self._validate_data(X, validate_separately=(check_X_params))
+            X = validate_data(self, X, validate_separately=(check_X_params))
 
         if issparse(X):
             X.sort_indices()
@@ -378,6 +378,13 @@ def _assign_labels(self, affinity_matrix):
         predict_labels = cluster.fit_predict(affinity_matrix)
         return predict_labels
 
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
+        # However, for MORF it is not supported
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = False
+        return tags
+
 
 class UnsupervisedObliqueDecisionTree(UnsupervisedDecisionTree):
     """Unsupervised oblique decision tree.
@@ -577,6 +584,13 @@ def _build_tree(
         builder.build(self.tree_, X, sample_weight)
 
         return self
 
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
+        # However, for MORF it is not supported
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = False
+        return tags
+
 
 class ObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
     """An oblique decision tree classifier.
@@ -820,7 +834,7 @@ class ObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
 
     tree_type = "oblique"
 
-    _parameter_constraints = {
+    _parameter_constraints: dict = {
         **DecisionTreeClassifier._parameter_constraints,
         "feature_combinations": [
             Interval(Real, 1.0, None, closed="left"),
@@ -1070,6 +1084,13 @@ def _update_tree(self, X, y, sample_weight):
         self._prune_tree()
 
         return self
 
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
+        # However, for MORF it is not supported
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = False
+        return tags
+
 
 class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
     """An oblique decision tree Regressor.
@@ -1283,7 +1304,7 @@ class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
 
     tree_type = "oblique"
 
-    _parameter_constraints = {
+    _parameter_constraints: dict = {
         **DecisionTreeRegressor._parameter_constraints,
         "feature_combinations": [
             Interval(Real, 1.0, None, closed="left"),
@@ -1450,6 +1471,13 @@ def _build_tree(
         builder.build(self.tree_, X, y, sample_weight, None)
 
         return self
 
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
+        # However, for MORF it is not supported
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = False
+        return tags
+
 
 class PatchObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
     """A oblique decision tree classifier that operates over patches of data.
@@ -1684,7 +1712,7 @@ class PatchObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier)
     """
 
     tree_type = "oblique"
-    _parameter_constraints = {
+    _parameter_constraints: dict = {
         **DecisionTreeClassifier._parameter_constraints,
         "min_patch_dims": ["array-like", None],
         "max_patch_dims": ["array-like", None],
@@ -1798,8 +1826,8 @@ def _build_tree(
             self.feature_combinations_ = 1
 
         if self.feature_weight is not None:
-            self.feature_weight = self._validate_data(
-                self.feature_weight, ensure_2d=True, dtype=DTYPE
+            self.feature_weight = validate_data(
+                self, self.feature_weight, ensure_2d=True, dtype=DTYPE
             )
             if self.feature_weight.shape != X.shape:
                 raise ValueError(
@@ -1927,11 +1955,13 @@ def _build_tree(
 
         return self
 
-    def _more_tags(self):
+    def __sklearn_tags__(self):
         # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
         # However, for MORF it is not supported
-        allow_nan = False
-        return {"multilabel": True, "allow_nan": allow_nan}
+        tags = super().__sklearn_tags__()
+        tags.classifier_tags.multi_label = True
+        tags.input_tags.allow_nan = False
+        return tags
 
     @property
     def _inheritable_fitted_attribute(self):
@@ -2166,7 +2196,7 @@ class PatchObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
     """
 
     tree_type = "oblique"
-    _parameter_constraints = {
+    _parameter_constraints: dict = {
         **DecisionTreeRegressor._parameter_constraints,
         "min_patch_dims": ["array-like", None],
         "max_patch_dims": ["array-like", None],
@@ -2277,8 +2307,8 @@ def _build_tree(
             self.feature_combinations_ = 1
 
         if self.feature_weight is not None:
-            self.feature_weight = self._validate_data(
-                self.feature_weight, ensure_2d=True, dtype=DTYPE
+            self.feature_weight = validate_data(
+                self, self.feature_weight, ensure_2d=True, dtype=DTYPE
             )
             if self.feature_weight.shape != X.shape:
                 raise ValueError(
@@ -2407,11 +2437,13 @@ def _build_tree(
 
         return self
 
-    def _more_tags(self):
+    def __sklearn_tags__(self):
         # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
         # However, for MORF it is not supported
-        allow_nan = False
-        return {"multilabel": True, "allow_nan": allow_nan}
+        tags = super().__sklearn_tags__()
+        tags.regressor_tags.multi_label = True
+        tags.input_tags.allow_nan = False
+        return tags
 
 
 class ExtraObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier):
@@ -2669,7 +2701,7 @@ class ExtraObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier)
 
     tree_type = "oblique"
 
-    _parameter_constraints = {
+    _parameter_constraints: dict = {
         **DecisionTreeClassifier._parameter_constraints,
         "feature_combinations": [
             Interval(Real, 1.0, None, closed="left"),
@@ -2846,6 +2878,13 @@ def _inheritable_fitted_attribute(self):
             "feature_combinations_",
         ]
 
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
+        # However, for MORF it is not supported
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = False
+        return tags
+
 
 class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
     """An oblique decision tree Regressor.
@@ -3069,7 +3108,7 @@ class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
         -0.26552594, -0.00642017, -0.07108117, -0.40726765, -0.40315294])
     """
 
-    _parameter_constraints = {
+    _parameter_constraints: dict = {
         **DecisionTreeRegressor._parameter_constraints,
         "feature_combinations": [
             Interval(Real, 1.0, None, closed="left"),
@@ -3237,3 +3276,10 @@ def _build_tree(
         builder.build(self.tree_, X, y, sample_weight)
 
         return self
+
+    def __sklearn_tags__(self):
+        # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values
+        # However, for MORF it is not supported
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = False
+        return tags
diff --git a/treeple/tree/_neighbors.py b/treeple/tree/_neighbors.py
index 94f2c8f18..93d8ff1a0 100644
--- a/treeple/tree/_neighbors.py
+++ b/treeple/tree/_neighbors.py
@@ -64,7 +64,3 @@ def compute_similarity_matrix(self, X):
             The similarity matrix among the samples.
         """
         return compute_forest_similarity_matrix(self, X)
-
-    def _more_tags(self):
-        # XXX: no treeple estimators support NaNs as of now
-        return {"allow_nan": False}
""" return compute_forest_similarity_matrix(self, X) - - def _more_tags(self): - # XXX: no treeple estimators support NaNs as of now - return {"allow_nan": False} diff --git a/treeple/tree/_oblique_splitter.pxd b/treeple/tree/_oblique_splitter.pxd index 124a66dd6..65ca16e14 100644 --- a/treeple/tree/_oblique_splitter.pxd +++ b/treeple/tree/_oblique_splitter.pxd @@ -83,12 +83,6 @@ cdef class BaseObliqueSplitter(Splitter): SplitRecord* split, ) except -1 nogil - cdef inline void fisher_yates_shuffle_memview( - self, - intp_t[::1] indices_to_sample, - intp_t grid_size, - uint32_t* random_state - ) noexcept nogil cdef class ObliqueSplitter(BaseObliqueSplitter): # The splitter searches in the input space for a linear combination of features and a threshold diff --git a/treeple/tree/_oblique_splitter.pyx b/treeple/tree/_oblique_splitter.pyx index ca77a30ac..0cceac664 100644 --- a/treeple/tree/_oblique_splitter.pyx +++ b/treeple/tree/_oblique_splitter.pyx @@ -11,6 +11,7 @@ from libcpp.vector cimport vector from .._lib.sklearn.tree._criterion cimport Criterion from .._lib.sklearn.tree._utils cimport rand_int, rand_uniform +from ._utils cimport fisher_yates_shuffle cdef float64_t INFINITY = np.inf @@ -46,8 +47,12 @@ cdef class BaseObliqueSplitter(Splitter): def __setstate__(self, d): pass - cdef int node_reset(self, intp_t start, intp_t end, - float64_t* weighted_n_node_samples) except -1 nogil: + cdef int node_reset( + self, + intp_t start, + intp_t end, + float64_t* weighted_n_node_samples + ) except -1 nogil: """Reset splitter on node samples[start:end]. Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -62,17 +67,7 @@ cdef class BaseObliqueSplitter(Splitter): weighted_n_node_samples : ndarray, dtype=float64_t pointer The total weight of those samples """ - - self.start = start - self.end = end - - self.criterion.init(self.y, - self.sample_weight, - self.weighted_n_samples, - self.samples) - self.criterion.set_sample_pointers(start, end) - - weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples + Splitter.node_reset(self, start, end, weighted_n_node_samples) # Clear all projection vectors for i in range(self.max_features): @@ -102,8 +97,8 @@ cdef class BaseObliqueSplitter(Splitter): intp_t end, const intp_t[:] samples, float32_t[:] feature_values, - vector[float32_t]* proj_vec_weights, # weights of the vector (max_features,) - vector[intp_t]* proj_vec_indices # indices of the features (max_features,) + vector[float32_t]* proj_vec_weights, # weights of the vector (n_non_zeros,) + vector[intp_t]* proj_vec_indices # indices of the features (n_non_zeros,) ) noexcept nogil: """Compute the feature values for the samples[start:end] range. @@ -126,19 +121,6 @@ cdef class BaseObliqueSplitter(Splitter): feature_values[idx] = 0.0 feature_values[idx] += self.X[samples[idx], col_idx] * col_weight - cdef inline void fisher_yates_shuffle_memview( - self, - intp_t[::1] indices_to_sample, - intp_t grid_size, - uint32_t* random_state, - ) noexcept nogil: - cdef intp_t i, j - - # XXX: should this be `i` or `i+1`? for valid Fisher-Yates? 
diff --git a/treeple/tree/_utils.pxd b/treeple/tree/_utils.pxd
index c814cc166..ba2707791 100644
--- a/treeple/tree/_utils.pxd
+++ b/treeple/tree/_utils.pxd
@@ -1,3 +1,5 @@
+from libcpp.vector cimport vector
+
 import numpy as np
 
 cimport numpy as cnp
@@ -7,15 +9,41 @@
 cnp.import_array()
 
 from .._lib.sklearn.tree._splitter cimport SplitRecord
 from .._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int32_t, intp_t, uint32_t
 
+ctypedef fused vector_or_memview:
+    vector[intp_t]
+    intp_t[::1]
+    intp_t[:]
+
+
+cdef void fisher_yates_shuffle(
+    vector_or_memview indices_to_sample,
+    intp_t grid_size,
+    uint32_t* random_state,
+) noexcept nogil
+
 
-cdef int rand_weighted_binary(float64_t p0, uint32_t* random_state) noexcept nogil
+cdef int rand_weighted_binary(
+    float64_t p0,
+    uint32_t* random_state
+) noexcept nogil
 
 cpdef unravel_index(
-    intp_t index, cnp.ndarray[intp_t, ndim=1] shape
+    intp_t index,
+    cnp.ndarray[intp_t, ndim=1] shape
 )
 
-cpdef ravel_multi_index(intp_t[:] coords, const intp_t[:] shape)
+cpdef ravel_multi_index(
+    intp_t[:] coords,
+    const intp_t[:] shape
+)
 
-cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] coords) noexcept nogil
+cdef void unravel_index_cython(
+    intp_t index,
+    const intp_t[:] shape,
+    vector_or_memview coords
+) noexcept nogil
 
-cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) noexcept nogil
+cdef intp_t ravel_multi_index_cython(
+    vector_or_memview coords,
+    const intp_t[:] shape
+) noexcept nogil
diff --git a/treeple/tree/_utils.pyx b/treeple/tree/_utils.pyx
index 197b82ecf..7ce48977b 100644
--- a/treeple/tree/_utils.pyx
+++ b/treeple/tree/_utils.pyx
@@ -11,10 +11,40 @@
 cimport numpy as cnp
 
 cnp.import_array()
 
-from .._lib.sklearn.tree._utils cimport rand_uniform
+from .._lib.sklearn.tree._utils cimport rand_int, rand_uniform
 
 
-cdef inline int rand_weighted_binary(float64_t p0, uint32_t* random_state) noexcept nogil:
+cdef inline void fisher_yates_shuffle(
+    vector_or_memview indices_to_sample,
+    intp_t grid_size,
+    uint32_t* random_state,
+) noexcept nogil:
+    """Shuffle the indices in place using the Fisher-Yates algorithm.
+
+    Parameters
+    ----------
+    indices_to_sample : A C++ vector or 1D memoryview
+        The indices to shuffle.
+    grid_size : intp_t
+        The size of the grid to shuffle. This is explicitly passed in
+        to support the templated `vector_or_memview` type, which allows
+        for both C++ vectors and Cython memoryviews. Getting the length
+        of the two types requires different APIs.
+    random_state : uint32_t*
+        The random state.
+    """
+    cdef intp_t i, j
+
+    # Standard forward Fisher-Yates: for each i in [0, grid_size - 2], draw j
+    # uniformly from [i, grid_size - 1] (rand_int samples from [low, high)).
+    for i in range(0, grid_size - 1):
+        j = rand_int(i, grid_size, random_state)
+        indices_to_sample[j], indices_to_sample[i] = \
+            indices_to_sample[i], indices_to_sample[j]
+
+
+cdef inline int rand_weighted_binary(
+    float64_t p0,
+    uint32_t* random_state
+) noexcept nogil:
     """Sample from integers 0 and 1 with different probabilities.
 
     Parameters
@@ -54,7 +84,9 @@
     index = np.intp(index)
     shape = np.array(shape)
     coords = np.empty(shape.shape[0], dtype=np.intp)
-    unravel_index_cython(index, shape, coords)
+    cdef const intp_t[:] shape_memview = shape
+    cdef intp_t[:] coords_memview = coords
+    unravel_index_cython(index, shape_memview, coords_memview)
     return coords
@@ -83,7 +115,11 @@ cpdef ravel_multi_index(intp_t[:] coords, const intp_t[:] shape):
     return ravel_multi_index_cython(coords, shape)
 
 
-cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] coords) noexcept nogil:
+cdef inline void unravel_index_cython(
+    intp_t index,
+    const intp_t[:] shape,
+    vector_or_memview coords
+) noexcept nogil:
     """Converts a flat index into a tuple of coordinate arrays.
 
     Parameters
     ----------
@@ -92,13 +128,9 @@ cdef inline void unravel_index_cython(
         The flat index to be converted.
     shape : numpy.ndarray[intp_t, ndim=1]
         The shape of the array into which the flat index should be converted.
-    coords : numpy.ndarray[intp_t, ndim=1]
-        A preinitialized memoryview array of coordinate arrays to be converted.
-
-    Returns
-    -------
-    numpy.ndarray[intp_t, ndim=1]
-        An array of coordinate arrays, with each coordinate array having the same shape as the input `shape`.
+    coords : intp_t[:] or vector[intp_t]
+        A preinitialized array of coordinates to store the result of the
+        unraveled `index`.
     """
     cdef intp_t ndim = shape.shape[0]
     cdef intp_t j, size
@@ -109,13 +141,16 @@
         index //= size
 
 
-cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) noexcept nogil:
-    """Converts a tuple of coordinate arrays into a flat index.
+cdef inline intp_t ravel_multi_index_cython(
+    vector_or_memview coords,
+    const intp_t[:] shape
+) noexcept nogil:
+    """Converts a tuple of coordinates into a single flat index.
 
     Parameters
     ----------
-    coords : numpy.ndarray[intp_t, ndim=1]
-        An array of coordinate arrays to be converted.
+    coords : intp_t[:] or vector[intp_t]
+        An array of coordinates to be converted into a single flat index.
     shape : numpy.ndarray[intp_t, ndim=1]
         The shape of the array into which the coordinates should be converted.
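Note: the `*_cython` helpers mirror NumPy's C-order index conversions while
remaining callable inside `nogil` sections. Their Python-visible wrappers
should therefore agree with NumPy (a small sanity check, assuming a C-order
layout):

    import numpy as np
    from treeple.tree._utils import ravel_multi_index, unravel_index

    shape = np.array([4, 5], dtype=np.intp)
    coords = unravel_index(7, shape)         # -> array([1, 2]); 7 == 1 * 5 + 2
    flat = ravel_multi_index(coords, shape)  # -> 7
    assert tuple(coords) == np.unravel_index(7, (4, 5))
    assert flat == np.ravel_multi_index((1, 2), (4, 5))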
diff --git a/treeple/tree/manifold/_morf_splitter.pxd b/treeple/tree/manifold/_morf_splitter.pxd
index a0a61a4de..2b65fd3ba 100644
--- a/treeple/tree/manifold/_morf_splitter.pxd
+++ b/treeple/tree/manifold/_morf_splitter.pxd
@@ -32,14 +32,6 @@ cdef class PatchSplitter(BestObliqueSplitter):
     # an input data vector. The input data is vectorized, so `data_height` and
     # `data_width` are used to determine the vectorized indices corresponding to
     # (x,y) coordinates in the original un-vectorized data.
-
-    cdef public intp_t max_patch_height    # Maximum height of the patch to sample
-    cdef public intp_t max_patch_width     # Maximum width of the patch to sample
-    cdef public intp_t min_patch_height    # Minimum height of the patch to sample
-    cdef public intp_t min_patch_width     # Minimum width of the patch to sample
-    cdef public intp_t data_height         # Height of the input data
-    cdef public intp_t data_width          # Width of the input data
-
     cdef public intp_t ndim               # The number of dimensions of the input data
 
     cdef const intp_t[:] data_dims        # The dimensions of the input data
@@ -56,7 +48,7 @@ cdef class PatchSplitter(BestObliqueSplitter):
 
     cdef intp_t[::1] _index_data_buffer
     cdef intp_t[::1] _index_patch_buffer
-    cdef intp_t[:] patch_dims_buff         # A buffer to store the dimensions of the sampled patch
+    cdef intp_t[:] patch_sampled_size      # A buffer to store the dimensions of the sampled patch
     cdef intp_t[:] unraveled_patch_point   # A buffer to store the unraveled patch point
 
     # All oblique splitters (i.e. non-axis aligned splitters) require a
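Note: `patch_sampled_size` (previously `patch_dims_buff`) stores the per-axis
size of the patch sampled for each projection vector, and the splitter's index
arithmetic is C-order ravelling of patch offsets into the vectorized feature
space. A NumPy sketch for a 2D image (the dimensions are made up for
illustration):

    import numpy as np

    data_dims = (8, 10)   # image stored as one vectorized row of 80 features
    patch_dims = (2, 3)   # sampled patch: height x width
    top_left = 27         # vectorized index of the patch's top-left corner
    ty, tx = np.unravel_index(top_left, data_dims)
    # vectorized feature indices covered by the patch
    idx = [np.ravel_multi_index((ty + dy, tx + dx), data_dims)
           for dy in range(patch_dims[0]) for dx in range(patch_dims[1])]
    # -> [27, 28, 29, 37, 38, 39]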
diff --git a/treeple/tree/manifold/_morf_splitter.pyx b/treeple/tree/manifold/_morf_splitter.pyx
index d6c8d0121..f1eaf2918 100644
--- a/treeple/tree/manifold/_morf_splitter.pyx
+++ b/treeple/tree/manifold/_morf_splitter.pyx
@@ -151,7 +151,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
         self.data_dims = data_dims
 
         # create a buffer for storing the patch dimensions sampled per projection matrix
-        self.patch_dims_buff = np.zeros(data_dims.shape[0], dtype=np.intp)
+        self.patch_sampled_size = np.zeros(data_dims.shape[0], dtype=np.intp)
         self.unraveled_patch_point = np.zeros(data_dims.shape[0], dtype=np.intp)
 
         # store the min and max patch dimension constraints
@@ -237,7 +237,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
                 top_left_patch_seed = rand_int(0, delta_patch_dim, random_state)
 
                 # write to buffer
-                self.patch_dims_buff[idx] = patch_dim
+                self.patch_sampled_size[idx] = patch_dim
                 patch_size *= patch_dim
             elif self.boundary == "wrap":
                 # add circular boundary conditions
@@ -251,7 +251,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
 
                 # resample the patch dimension due to padding
                 patch_dim = min(patch_dim, min(dim+1, self.data_dims[idx] + patch_dim - dim - 1))
-                self.patch_dims_buff[idx] = patch_dim
+                self.patch_sampled_size[idx] = patch_dim
                 patch_size *= patch_dim
 
             # TODO: make this work
@@ -283,7 +283,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
         cdef intp_t top_left_patch_seed
 
         # size of the sampled patch, which is just the size of the n-dim patch
-        # (\prod_i self.patch_dims_buff[i])
+        # (\prod_i self.patch_sampled_size[i])
         cdef intp_t patch_size
 
         for proj_i in range(0, max_features):
@@ -299,7 +299,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
                 proj_i,
                 patch_size,
                 top_left_patch_seed,
-                self.patch_dims_buff
+                self.patch_sampled_size
             )
 
     cdef void sample_proj_vec(
@@ -389,7 +389,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
                 if not self.dim_contiguous[idx]:
                     row_index += (
                         (self.unraveled_patch_point[idx] // other_dims_offset) %
-                        self.patch_dims_buff[idx]
+                        self.patch_sampled_size[idx]
                     ) * other_dims_offset
                     other_dims_offset //= self.data_dims[idx]
@@ -445,7 +445,7 @@ cdef class BestPatchSplitterTester(BestPatchSplitter):
     """A class to expose a Python interface for testing."""
     cpdef sample_top_left_seed_cpdef(self):
         top_left_patch_seed, patch_size = self.sample_top_left_seed()
-        patch_dims = np.array(self.patch_dims_buff, dtype=np.intp)
+        patch_dims = np.array(self.patch_sampled_size, dtype=np.intp)
         return top_left_patch_seed, patch_size, patch_dims
 
     cpdef sample_projection_vector(