diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py
index 2db5b574ade2a..622e1160faefb 100644
--- a/sklearn/ensemble/gradient_boosting.py
+++ b/sklearn/ensemble/gradient_boosting.py
@@ -86,6 +86,21 @@ def predict(self, X):
         y.fill(self.quantile)
         return y
 
+## TODO REMOVE
+class LinexEstimator(BaseEstimator):
+    """An estimator predicting the constant minimizing the Linex loss of the training targets."""
+
+    def fit(self, X, y, sample_weight=None):
+        alpha = 0.1
+        if sample_weight is None:
+            self.opt = (1 / alpha) * np.log(np.mean(np.exp(y * alpha)))
+        else:
+            self.opt = (1 / alpha) * np.log(np.average(np.exp(y * alpha), weights=sample_weight))
+
+    def predict(self, X):
+        y = np.empty((X.shape[0], 1), dtype=np.float64)
+        y.fill(self.opt)
+        return y
 
 
 class MeanEstimator(BaseEstimator):
     """An estimator predicting the mean of the training targets."""
@@ -397,6 +412,43 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
                                    np.minimum(np.abs(diff_minus_median), gamma))
 
 
+# TODO REMOVE
+class LinexLossFunction(RegressionLossFunction):
+    def __init__(self, n_classes, grad_factor=100):
+        super(LinexLossFunction, self).__init__(n_classes)
+        self.grad_factor = grad_factor
+
+    def init_estimator(self):
+        return LinexEstimator()
+
+    def __call__(self, y, pred, sample_weight=None):
+        alpha = 0.1
+        if sample_weight is None:
+            return np.mean(np.exp((y - pred.ravel()) * alpha) - alpha * (y - pred.ravel()) - 1)
+        else:
+            return (1.0 / sample_weight.sum() *
+                    np.sum(sample_weight * (np.exp((y - pred.ravel()) * alpha) - alpha * (y - pred.ravel()) - 1)))
+
+    def negative_gradient(self, y, pred, **kargs):
+        """Compute the negative gradient alpha * (exp(alpha * (y - pred)) - 1), scaled by grad_factor."""
+
+        alpha = 0.1
+        pred = pred.ravel()
+        grad = alpha * (np.exp(alpha * (y - pred)) - 1)
+        return self.grad_factor * grad
+
+    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
+                                residual, pred, sample_weight):
+        alpha = 0.1
+        terminal_region = np.where(terminal_regions == leaf)[0]
+        diff = y.take(terminal_region, axis=0) - pred.take(terminal_region, axis=0)
+        if sample_weight is None:
+            tree.value[leaf, 0, 0] = (1 / alpha) * np.log(np.mean(np.exp(diff * alpha)))  ## TODO
+        else:
+            sample_weight = sample_weight.take(terminal_region, axis=0)
+            tree.value[leaf, 0, 0] = (1 / alpha) * np.log(np.average(np.exp(diff * alpha), weights=sample_weight))
+
+
 class QuantileLossFunction(RegressionLossFunction):
     """Loss function for quantile regression.
@@ -653,6 +705,7 @@ def _score_to_decision(self, score):
                   'lad': LeastAbsoluteError,
                   'huber': HuberLossFunction,
                   'quantile': QuantileLossFunction,
+                  'linex': LinexLossFunction,  # TODO REMOVE
                   'deviance': None,  # for both, multinomial and binomial
                   'exponential': ExponentialLoss,
                   }
@@ -724,7 +777,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
                  min_samples_split, min_samples_leaf, min_weight_fraction_leaf,
                  max_depth, min_impurity_split, init, subsample, max_features,
                  random_state, alpha=0.9, verbose=0, max_leaf_nodes=None,
-                 warm_start=False, presort='auto'):
+                 warm_start=False, presort='auto', grad_factor=100):
 
         self.n_estimators = n_estimators
         self.learning_rate = learning_rate
@@ -744,6 +797,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
         self.max_leaf_nodes = max_leaf_nodes
         self.warm_start = warm_start
         self.presort = presort
+        self.grad_factor = grad_factor  ## TODO REMOVE
 
         self.estimators_ = np.empty((0, 0), dtype=np.object)
 
@@ -762,6 +816,7 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask,
             residual = loss.negative_gradient(y, y_pred, k=k,
                                               sample_weight=sample_weight)
 
+            max_depth = self.max_depth
             # induce regression tree on residuals
             tree = DecisionTreeRegressor(
                 criterion=self.criterion,
@@ -824,6 +879,8 @@ def _check_params(self):
 
         if self.loss in ('huber', 'quantile'):
             self.loss_ = loss_class(self.n_classes_, self.alpha)
+        elif self.loss == 'linex':
+            self.loss_ = loss_class(self.n_classes_, self.grad_factor)
         else:
             self.loss_ = loss_class(self.n_classes_)
 
@@ -1831,14 +1888,14 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
     Elements of Statistical Learning Ed. 2, Springer, 2009.
     """
 
-    _SUPPORTED_LOSS = ('ls', 'lad', 'huber', 'quantile')
+    _SUPPORTED_LOSS = ('ls', 'lad', 'huber', 'quantile', 'linex')
 
     def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
                  subsample=1.0, criterion='friedman_mse', min_samples_split=2,
                  min_samples_leaf=1, min_weight_fraction_leaf=0.,
                  max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
                  max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
-                 warm_start=False, presort='auto'):
+                 warm_start=False, presort='auto', grad_factor=100):  ## TODO REMOVE
 
         super(GradientBoostingRegressor, self).__init__(
             loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
@@ -1849,7 +1906,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
             max_features=max_features, min_impurity_split=min_impurity_split,
             random_state=random_state, alpha=alpha, verbose=verbose,
             max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
-            presort=presort)
+            presort=presort, grad_factor=grad_factor)  ## TODO REMOVE
 
     def predict(self, X):
         """Predict regression target for X.
diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd
index cf6d32d1b7fe1..ec0cbceb455a6 100644
--- a/sklearn/tree/_criterion.pxd
+++ b/sklearn/tree/_criterion.pxd
@@ -41,14 +41,21 @@ cdef class Criterion:
    cdef double weighted_n_left     # Weighted number of samples in the left node
    cdef double weighted_n_right    # Weighted number of samples in the right node
 
+
    cdef double* sum_total          # For classification criteria, the sum of the
                                    # weighted count of each label. For regression,
                                    # the sum of w*y. sum_total[k] is equal to
                                    # sum_{i=start}^{end-1} w[samples[i]]*y[samples[i], k],
-                                   # where k is output index. 
+                                   # where k is output index.
    cdef double* sum_left           # Same as above, but for the left side of the split
    cdef double* sum_right          # same as above, but for the right side of the split
 
+    # For RegressionCriterionAuxiliaryForLinex: arrays needed to compute the Linex impurity
+    cdef double* exp_sum_total
+    cdef double* exp_sum_left
+    cdef double* exp_sum_right
+    cdef DOUBLE_t max_value
+
    # The criterion object is maintained such that left and right collected
    # statistics correspond to samples[start:pos] and samples[pos:end].
diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx
index 26c40dc8d6616..1c2582ec918db 100644
--- a/sklearn/tree/_criterion.pyx
+++ b/sklearn/tree/_criterion.pyx
@@ -20,7 +20,7 @@ from libc.stdlib cimport calloc
 from libc.stdlib cimport free
 from libc.string cimport memcpy
 from libc.string cimport memset
-from libc.math cimport fabs
+from libc.math cimport fabs, exp, log as logarithm
 
 import numpy as np
 cimport numpy as np
@@ -44,6 +44,9 @@ cdef class Criterion:
         free(self.sum_total)
         free(self.sum_left)
         free(self.sum_right)
+        free(self.exp_sum_total)
+        free(self.exp_sum_left)
+        free(self.exp_sum_right)
 
     def __getstate__(self):
         return {}
@@ -200,12 +203,11 @@ cdef class Criterion:
         self.children_impurity(&impurity_left, &impurity_right)
 
         return ((self.weighted_n_node_samples / self.weighted_n_samples) *
-                (impurity - (self.weighted_n_right / 
+                (impurity - (self.weighted_n_right /
                              self.weighted_n_node_samples * impurity_right)
-                         - (self.weighted_n_left / 
+                         - (self.weighted_n_left /
                              self.weighted_n_node_samples * impurity_left)))
 
-
 cdef class ClassificationCriterion(Criterion):
     """Abstract criterion for classification."""
@@ -266,7 +268,7 @@ cdef class ClassificationCriterion(Criterion):
         self.sum_left = <double*> calloc(n_elements, sizeof(double))
         self.sum_right = <double*> calloc(n_elements, sizeof(double))
 
-        if (self.sum_total == NULL or 
+        if (self.sum_total == NULL or
                 self.sum_left == NULL or
                 self.sum_right == NULL):
             raise MemoryError()
@@ -496,7 +498,6 @@ cdef class ClassificationCriterion(Criterion):
             dest += self.sum_stride
             sum_total += self.sum_stride
 
-
 cdef class Entropy(ClassificationCriterion):
     """Cross Entropy impurity criterion.
 
@@ -577,7 +578,6 @@ cdef class Entropy(ClassificationCriterion):
         impurity_left[0] = entropy_left / self.n_outputs
         impurity_right[0] = entropy_right / self.n_outputs
 
-
 cdef class Gini(ClassificationCriterion):
     """Gini Index impurity criterion.
 
@@ -671,6 +671,373 @@ cdef class Gini(ClassificationCriterion):
         impurity_left[0] = gini_left / self.n_outputs
         impurity_right[0] = gini_right / self.n_outputs
 
+cdef class RegressionCriterionAuxiliaryForLinex(Criterion):
+    """
+    Inspired by RegressionCriterion. Maintains exp_sum_total, exp_sum_left and
+    exp_sum_right in order to compute the Linex impurity and its proxy improvement.
+    """
+
+    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples):
+
+        self.y = NULL
+        self.y_stride = 0
+        self.sample_weight = NULL
+
+        self.samples = NULL
+        self.start = 0
+        self.pos = 0
+        self.end = 0
+
+        self.n_outputs = n_outputs
+        self.n_samples = n_samples
+        self.n_node_samples = 0
+        self.weighted_n_node_samples = 0.0
+        self.weighted_n_left = 0.0
+        self.weighted_n_right = 0.0
+
+        # Allocate accumulators. Make sure they are NULL, not uninitialized,
+        # before an exception can be raised (which triggers __dealloc__).
+        self.sum_total = NULL
+        self.sum_left = NULL
+        self.sum_right = NULL
+        self.exp_sum_total = NULL
+        self.exp_sum_left = NULL
+        self.exp_sum_right = NULL
+
+        # Allocate memory for the accumulators
+        self.sum_total = <double*> calloc(n_outputs, sizeof(double))
+        self.sum_left = <double*> calloc(n_outputs, sizeof(double))
+        self.sum_right = <double*> calloc(n_outputs, sizeof(double))
+        self.exp_sum_total = <double*> calloc(n_outputs, sizeof(double))
+        self.exp_sum_left = <double*> calloc(n_outputs, sizeof(double))
+        self.exp_sum_right = <double*> calloc(n_outputs, sizeof(double))
+
+        if (self.sum_total == NULL or
+                self.sum_left == NULL or
+                self.sum_right == NULL or
+                self.exp_sum_total == NULL or
+                self.exp_sum_left == NULL or
+                self.exp_sum_right == NULL):
+            raise MemoryError()
+
+    def __reduce__(self):
+        return (type(self), (self.n_outputs, self.n_samples), self.__getstate__())
+
+    cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight,
+                   double weighted_n_samples, SIZE_t* samples, SIZE_t start,
+                   SIZE_t end) nogil:
+        """Initialize the criterion at node samples[start:end] and
+        children samples[start:start] and samples[start:end]."""
+        # Initialize fields
+        self.y = y
+        self.y_stride = y_stride
+        self.sample_weight = sample_weight
+        self.samples = samples
+        self.start = start
+        self.end = end
+        self.n_node_samples = end - start
+        self.weighted_n_samples = weighted_n_samples
+        self.weighted_n_node_samples = 0.
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef DOUBLE_t y_ik
+        cdef DOUBLE_t w_y_ik
+        cdef DOUBLE_t w = 1.0
+        cdef DOUBLE_t alpha = 0.1
+
+        memset(self.sum_total, 0, self.n_outputs * sizeof(double))
+        memset(self.exp_sum_total, 0, self.n_outputs * sizeof(double))
+
+        self.max_value = 0
+
+        for p in range(start, end):
+            i = samples[p]
+
+            if sample_weight != NULL:
+                w = sample_weight[i]
+
+            for k in range(self.n_outputs):
+                y_ik = y[i * y_stride + k]
+                w_y_ik = w * y_ik
+                if y_ik > self.max_value:
+                    self.max_value = y_ik
+                self.sum_total[k] += w_y_ik
+                self.exp_sum_total[k] += w * exp(alpha * y_ik)
+
+            self.weighted_n_node_samples += w
+
+        # Reset to pos=start
+        self.reset()
+
+    cdef void reset(self) nogil:
+        """Reset the criterion at pos=start."""
+        cdef SIZE_t n_bytes = self.n_outputs * sizeof(double)
+        memset(self.sum_left, 0, n_bytes)
+        memcpy(self.sum_right, self.sum_total, n_bytes)
+        memset(self.exp_sum_left, 0, n_bytes)
+        memcpy(self.exp_sum_right, self.exp_sum_total, n_bytes)
+
+        self.weighted_n_left = 0.0
+        self.weighted_n_right = self.weighted_n_node_samples
+        self.pos = self.start
+
+    cdef void reverse_reset(self) nogil:
+        """Reset the criterion at pos=end."""
+        cdef SIZE_t n_bytes = self.n_outputs * sizeof(double)
+
+        memset(self.sum_right, 0, n_bytes)
+        memcpy(self.sum_left, self.sum_total, n_bytes)
+
+        memset(self.exp_sum_right, 0, n_bytes)
+        memcpy(self.exp_sum_left, self.exp_sum_total, n_bytes)
+
+        self.weighted_n_right = 0.0
+        self.weighted_n_left = self.weighted_n_node_samples
+        self.pos = self.end
+
+    cdef void update(self, SIZE_t new_pos) nogil:
+        """Updated statistics by moving samples[pos:new_pos] to the left."""
+
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+        cdef double* sum_total = self.sum_total
+
+        cdef double* exp_sum_left = self.exp_sum_left
+        cdef double* exp_sum_right = self.exp_sum_right
+        cdef double* exp_sum_total = self.exp_sum_total
+
+        cdef double* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+
+        cdef DOUBLE_t* y = self.y
+        cdef SIZE_t pos = self.pos
+        cdef SIZE_t end = self.end
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef DOUBLE_t w = 1.0
+        cdef DOUBLE_t alpha = 0.1
+        cdef DOUBLE_t y_ik
+
+        if (new_pos - pos) <= (end - new_pos):
+            for p in range(pos, new_pos):
+                i = samples[p]
+
+                if sample_weight != NULL:
+                    w = sample_weight[i]
+
+                for k in range(self.n_outputs):
+                    y_ik = y[i * self.y_stride + k]
+                    sum_left[k] += w * y_ik
+                    exp_sum_left[k] += w * exp(alpha * y_ik)
+
+                self.weighted_n_left += w
+        else:
+            self.reverse_reset()
+
+            for p in range(end - 1, new_pos - 1, -1):
+                i = samples[p]
+
+                if sample_weight != NULL:
+                    w = sample_weight[i]
+
+                for k in range(self.n_outputs):
+                    y_ik = y[i * self.y_stride + k]
+                    sum_left[k] -= w * y_ik
+                    exp_sum_left[k] -= w * exp(alpha * y_ik)
+
+                self.weighted_n_left -= w
+
+        self.weighted_n_right = (self.weighted_n_node_samples -
+                                 self.weighted_n_left)
+        for k in range(self.n_outputs):
+            sum_right[k] = sum_total[k] - sum_left[k]
+            exp_sum_right[k] = exp_sum_total[k] - exp_sum_left[k]
+
+        self.pos = new_pos
+
+    cdef double node_impurity(self) nogil:
+        pass
+
+    cdef void children_impurity(self, double* impurity_left,
+                                double* impurity_right) nogil:
+        pass
+
+    cdef void node_value(self, double* dest) nogil:
+        """Compute the node value of samples[start:end] into dest."""
+
+        cdef DOUBLE_t alpha = 0.1
+        cdef SIZE_t k
+
+        for k in range(self.n_outputs):
+            dest[k] = (1 / alpha) * logarithm(self.exp_sum_total[k] / self.weighted_n_node_samples)
+
+cdef class LinexLeavesCriterionHalfMSESplit(RegressionCriterionAuxiliaryForLinex):
+    """
+    A criterion computing leaf values from the Linex loss. Splits use the
+    exact Linex impurity whenever the max_value of the samples is small
+    enough for the loss to be computed; otherwise the MSE proxy is used.
+    """
+
+    cdef double proxy_impurity_improvement(self) nogil:
+        """
+        Fall back to the MSE proxy when the max_value of the node samples is
+        too high, because evaluating the exponential in the actual Linex
+        impurity would overflow.
+        """
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+
+        cdef SIZE_t k
+        cdef double proxy_impurity_left = 0.0
+        cdef double proxy_impurity_right = 0.0
+        cdef double impurity_left
+        cdef double impurity_right
+
+        if self.max_value >= 90:
+            for k in range(self.n_outputs):
+                proxy_impurity_left += sum_left[k] * sum_left[k]
+                proxy_impurity_right += sum_right[k] * sum_right[k]
+
+            return (proxy_impurity_left / self.weighted_n_left +
+                    proxy_impurity_right / self.weighted_n_right)
+        else:
+            self.children_impurity(&impurity_left, &impurity_right)
+
+            return (- self.weighted_n_right * impurity_right
+                    - self.weighted_n_left * impurity_left)
+
+    cdef double node_impurity(self) nogil:
+        """Evaluate the impurity of the current node, i.e. the impurity of
+        samples[start:end]."""
+
+        cdef double* sum_total = self.sum_total
+        cdef double* exp_sum_total = self.exp_sum_total
+        cdef double impurity
+        cdef SIZE_t k
+
+        cdef DOUBLE_t alpha = 0.1
+
+        impurity = 0.
+
+        for k in range(self.n_outputs):
+            impurity += (self.weighted_n_node_samples * logarithm(exp_sum_total[k] / self.weighted_n_node_samples) - alpha * sum_total[k])
+
+        return impurity / self.n_outputs
+
+    cdef void children_impurity(self, double* impurity_left,
+                                double* impurity_right) nogil:
+        """Evaluate the impurity in children nodes, i.e. the impurity of
+        the left child (samples[start:pos]) and the impurity of the right
+        child (samples[pos:end])."""
+
+        cdef DOUBLE_t* y = self.y
+        cdef DOUBLE_t* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+        cdef SIZE_t pos = self.pos
+        cdef SIZE_t start = self.start
+
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+        cdef double* exp_sum_left = self.exp_sum_left
+        cdef double* exp_sum_right = self.exp_sum_right
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef DOUBLE_t w = 1.0
+        cdef DOUBLE_t y_ik
+        cdef DOUBLE_t alpha = 0.1
+
+        impurity_left[0] = 0.
+        impurity_right[0] = 0.
+
+        for k in range(self.n_outputs):
+            impurity_left[0] += (self.weighted_n_left * logarithm(exp_sum_left[k] / self.weighted_n_left) - alpha * sum_left[k])
+            impurity_right[0] += (self.weighted_n_right * logarithm(exp_sum_right[k] / self.weighted_n_right) - alpha * sum_right[k])
+
+        impurity_left[0] /= self.n_outputs
+        impurity_right[0] /= self.n_outputs
+
+cdef class LinexLeavesCriterionMSESplit(RegressionCriterionAuxiliaryForLinex):
+    """
+    A criterion computing leaf values from the Linex loss, while always
+    splitting with the MSE proxy.
+    """
+
+    cdef double node_impurity(self) nogil:
+        """Evaluate the impurity of the current node, i.e. the impurity of
+        samples[start:end]."""
+
+        cdef double* sum_total = self.sum_total
+        cdef double* exp_sum_total = self.exp_sum_total
+        cdef double impurity
+        cdef SIZE_t k
+
+        cdef DOUBLE_t alpha = 0.1
+
+        impurity = 0.
+
+        for k in range(self.n_outputs):
+            impurity += (self.weighted_n_node_samples * logarithm(exp_sum_total[k] / self.weighted_n_node_samples) - alpha * sum_total[k])
+
+        return impurity / self.n_outputs
+
+    cdef double proxy_impurity_improvement(self) nogil:
+        """
+        The proxy from the MSE criterion; it works well in practice for the
+        Linex loss.
+        """
+
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+
+        cdef SIZE_t k
+        cdef double proxy_impurity_left = 0.0
+        cdef double proxy_impurity_right = 0.0
+
+        for k in range(self.n_outputs):
+            proxy_impurity_left += sum_left[k] * sum_left[k]
+            proxy_impurity_right += sum_right[k] * sum_right[k]
+
+        return (proxy_impurity_left / self.weighted_n_left +
+                proxy_impurity_right / self.weighted_n_right)
+
+    cdef void children_impurity(self, double* impurity_left,
+                                double* impurity_right) nogil:
+        """Evaluate the impurity in children nodes, i.e. the impurity of
+        the left child (samples[start:pos]) and the impurity of the right
+        child (samples[pos:end])."""
+
+        cdef DOUBLE_t* y = self.y
+        cdef DOUBLE_t* sample_weight = self.sample_weight
+        cdef SIZE_t* samples = self.samples
+        cdef SIZE_t pos = self.pos
+        cdef SIZE_t start = self.start
+
+        cdef double* sum_left = self.sum_left
+        cdef double* sum_right = self.sum_right
+        cdef double* exp_sum_left = self.exp_sum_left
+        cdef double* exp_sum_right = self.exp_sum_right
+
+        cdef SIZE_t i
+        cdef SIZE_t p
+        cdef SIZE_t k
+        cdef DOUBLE_t w = 1.0
+        cdef DOUBLE_t y_ik
+        cdef DOUBLE_t alpha = 0.1
+
+        impurity_left[0] = 0.
+        impurity_right[0] = 0.
+
+        for k in range(self.n_outputs):
+            impurity_left[0] += (self.weighted_n_left * logarithm(exp_sum_left[k] / self.weighted_n_left) - alpha * sum_left[k])
+            impurity_right[0] += (self.weighted_n_right * logarithm(exp_sum_right[k] / self.weighted_n_right) - alpha * sum_right[k])
+
+        impurity_left[0] /= self.n_outputs
+        impurity_right[0] /= self.n_outputs
+
 
 cdef class RegressionCriterion(Criterion):
     """Abstract regression criterion.
@@ -728,7 +1095,7 @@ cdef class RegressionCriterion(Criterion):
         self.sum_left = <double*> calloc(n_outputs, sizeof(double))
         self.sum_right = <double*> calloc(n_outputs, sizeof(double))
 
-        if (self.sum_total == NULL or 
+        if (self.sum_total == NULL or
                 self.sum_left == NULL or
                 self.sum_right == NULL):
             raise MemoryError()
@@ -853,7 +1220,7 @@ cdef class RegressionCriterion(Criterion):
                 self.weighted_n_left -= w
 
-        self.weighted_n_right = (self.weighted_n_node_samples - 
+        self.weighted_n_right = (self.weighted_n_node_samples -
                                  self.weighted_n_left)
 
         for k in range(self.n_outputs):
             sum_right[k] = sum_total[k] - sum_left[k]
@@ -875,7 +1242,6 @@ cdef class RegressionCriterion(Criterion):
         for k in range(self.n_outputs):
             dest[k] = self.sum_total[k] / self.weighted_n_node_samples
 
-
 cdef class MSE(RegressionCriterion):
     """Mean squared error impurity criterion.
 
@@ -964,7 +1330,7 @@ cdef class MSE(RegressionCriterion):
 
         for k in range(self.n_outputs):
             impurity_left[0] -= (sum_left[k] / self.weighted_n_left) ** 2.0
-            impurity_right[0] -= (sum_right[k] / self.weighted_n_right) ** 2.0 
+            impurity_right[0] -= (sum_right[k] / self.weighted_n_right) ** 2.0
 
         impurity_left[0] /= self.n_outputs
         impurity_right[0] /= self.n_outputs
@@ -1263,7 +1629,6 @@ cdef class MAE(RegressionCriterion):
         impurity_right[0] /= ((self.weighted_n_right) *
                               self.n_outputs)
 
-
 cdef class FriedmanMSE(MSE):
     """Mean squared error impurity criterion with improvement score by Friedman
 
@@ -1320,5 +1685,5 @@ cdef class FriedmanMSE(MSE):
         diff = (self.weighted_n_right * total_sum_left -
                 self.weighted_n_left * total_sum_right) / self.n_outputs
 
-        return (diff * diff / (self.weighted_n_left * self.weighted_n_right * 
+        return (diff * diff / (self.weighted_n_left * self.weighted_n_right *
                                self.weighted_n_node_samples))
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index c3567e864c10b..d1eccb3a5ccf8 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -57,7 +57,9 @@
 DOUBLE = _tree.DOUBLE
 
 CRITERIA_CLF = {"gini": _criterion.Gini, "entropy": _criterion.Entropy}
-CRITERIA_REG = {"mse": _criterion.MSE, "friedman_mse": _criterion.FriedmanMSE,
+CRITERIA_REG = {"mse": _criterion.MSE, "friedman_mse": _criterion.FriedmanMSE,
+                "linex": _criterion.LinexLeavesCriterionMSESplit,
+                "pureLinex": _criterion.LinexLeavesCriterionHalfMSESplit,
                 "mae": _criterion.MAE}
 
 DENSE_SPLITTERS = {"best": _splitter.BestSplitter,
@@ -1029,7 +1029,6 @@ def fit(self, X, y, sample_weight=None, check_input=True,
             X_idx_sorted=X_idx_sorted)
         return self
 
-
 class ExtraTreeClassifier(DecisionTreeClassifier):
     """An extremely randomized tree classifier.
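
The patch hard-codes `alpha = 0.1` everywhere. As a sanity check (an illustration only, not part of the diff), the closed form relied on by `LinexEstimator.fit`, `_update_terminal_region` and `node_value` — the constant `c` minimizing the mean Linex loss `exp(alpha*(y - c)) - alpha*(y - c) - 1` is `(1/alpha) * log(mean(exp(alpha*y)))` — can be verified numerically in plain NumPy:

```python
import numpy as np

alpha = 0.1  # hard-coded throughout the patch


def linex_loss(y, c, alpha=0.1):
    """Mean Linex loss of a constant prediction c: exp(a*r) - a*r - 1, r = y - c."""
    r = y - c
    return np.mean(np.exp(alpha * r) - alpha * r - 1)


rng = np.random.RandomState(0)
y = rng.normal(size=1000)

# Closed-form minimizer, as computed in LinexEstimator.fit.
c_star = (1 / alpha) * np.log(np.mean(np.exp(alpha * y)))

# Brute-force minimization on a fine grid around the sample mean.
grid = np.linspace(y.mean() - 1, y.mean() + 1, 10001)
c_grid = grid[np.argmin([linex_loss(y, c) for c in grid])]

assert abs(c_star - c_grid) < 1e-3  # the two minimizers agree
```

Setting the derivative of the mean loss to zero gives `exp(alpha * c) = mean(exp(alpha * y))`, which is exactly the expression above; for `alpha > 0` the asymmetry of the loss pushes `c_star` slightly above the sample mean.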
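The same algebra explains the node impurity used by both Linex criteria: at the optimal constant, the total within-node Linex loss collapses to `W * log(exp_sum / W) - alpha * sum_total`, which is what `node_impurity` and `children_impurity` compute from the running accumulators. A small NumPy check (again an illustration with made-up weights, not code from the patch):

```python
import numpy as np

alpha = 0.1
rng = np.random.RandomState(1)
y = rng.normal(size=200)
w = rng.uniform(0.5, 1.5, size=200)

W = w.sum()                               # weighted_n_node_samples
sum_total = np.sum(w * y)                 # sum_total[k]
exp_sum = np.sum(w * np.exp(alpha * y))   # exp_sum_total[k]

# Impurity as computed in node_impurity / children_impurity.
impurity = W * np.log(exp_sum / W) - alpha * sum_total

# Total Linex loss at the node's optimal constant prediction (node_value).
c_star = (1 / alpha) * np.log(exp_sum / W)
r = y - c_star
total_loss = np.sum(w * (np.exp(alpha * r) - alpha * r - 1))

assert np.isclose(impurity, total_loss)
```

So minimizing the summed children impurities minimizes the post-split training loss directly, which is why `LinexLeavesCriterionHalfMSESplit` prefers the exact impurity whenever the exponentials are safe to evaluate.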
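Finally, a hypothetical usage sketch. It only runs against a build of this experimental branch (the `linex` loss and criteria are marked TODO REMOVE and are not part of any scikit-learn release); the data here is synthetic:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.RandomState(42)
X = rng.uniform(size=(500, 4))
y = X.sum(axis=1) + rng.normal(scale=0.1, size=500)

# loss='linex' selects LinexLossFunction; criterion='linex' maps to
# LinexLeavesCriterionMSESplit (MSE splits, Linex leaf values), while
# criterion='pureLinex' would map to LinexLeavesCriterionHalfMSESplit.
est = GradientBoostingRegressor(loss='linex', criterion='linex',
                                grad_factor=100, n_estimators=100)
est.fit(X, y)
print(est.predict(X[:5]))
```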