Modification for including Linex Loss in criterion (to build trees)
Edouard360 committed Dec 5, 2016
1 parent fc3bec7 commit b00d6f0
Showing 4 changed files with 448 additions and 20 deletions.
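For context: the Linex (linear-exponential) loss added by this commit is L(y, p) = exp(alpha * (y - p)) - alpha * (y - p) - 1, which for alpha > 0 penalizes under-prediction exponentially and over-prediction roughly linearly. A minimal usage sketch against this branch only (`loss='linex'` and `grad_factor` are the additions in the diff below; the toy data and settings are illustrative):

```python
# Usage sketch against this branch; 'linex' and grad_factor are the
# additions introduced by this commit, the toy data is illustrative.
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.uniform(size=(200, 3))
y = X.sum(axis=1) + rng.normal(scale=0.1, size=200)

est = GradientBoostingRegressor(loss='linex', grad_factor=100,
                                n_estimators=50, random_state=0)
est.fit(X, y)
print(est.predict(X[:5]))
```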
65 changes: 61 additions & 4 deletions sklearn/ensemble/gradient_boosting.py
@@ -86,6 +86,21 @@ def predict(self, X):
        y.fill(self.quantile)
        return y

## TODO REMOVE
class LinexEstimator(BaseEstimator):
    """An estimator predicting the constant that minimizes the Linex loss
    of the training targets."""

    def fit(self, X, y, sample_weight=None):
        # alpha is hard-coded to match the default used in LinexLossFunction.
        alpha = 0.1
        if sample_weight is None:
            self.opt = (1 / alpha) * np.log(np.mean(np.exp(y * alpha)))
        else:
            # Weighted mean of exp(alpha * y), normalized by the weight sum.
            self.opt = (1 / alpha) * np.log(np.average(np.exp(y * alpha),
                                                       weights=sample_weight))

    def predict(self, X):
        y = np.empty((X.shape[0], 1), dtype=np.float64)
        y.fill(self.opt)
        return y
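The constant stored in `self.opt` is the exact minimizer of the mean Linex loss: setting d/dc mean[exp(alpha * (y - c)) - alpha * (y - c) - 1] = 0 gives mean[exp(alpha * (y - c))] = 1, hence c* = (1/alpha) * log(mean(exp(alpha * y))). A standalone numeric check of that closed form (alpha = 0.1 as hard-coded above; the data is synthetic):

```python
# Standalone check that the closed-form constant matches a brute-force
# minimizer of the empirical mean Linex loss (synthetic data, alpha = 0.1).
import numpy as np

alpha = 0.1
y = np.random.RandomState(0).normal(size=1000)

closed_form = (1 / alpha) * np.log(np.mean(np.exp(alpha * y)))

grid = np.linspace(y.min(), y.max(), 10001)
losses = [np.mean(np.exp(alpha * (y - c)) - alpha * (y - c) - 1) for c in grid]
brute_force = grid[int(np.argmin(losses))]

assert abs(closed_form - brute_force) < 1e-3
```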

class MeanEstimator(BaseEstimator):
    """An estimator predicting the mean of the training targets."""
@@ -397,6 +412,43 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
            np.minimum(np.abs(diff_minus_median), gamma))


# TODO REMOVE
class LinexLossFunction(RegressionLossFunction):
    """Linex loss function for regression: exponential penalty on one side
    of the residual, approximately linear on the other."""

    def __init__(self, n_classes, grad_factor=100):
        super(LinexLossFunction, self).__init__(n_classes)
        self.grad_factor = grad_factor

    def init_estimator(self):
        return LinexEstimator()

    def __call__(self, y, pred, sample_weight=None):
        alpha = 0.1
        diff = y - pred.ravel()
        if sample_weight is None:
            return np.mean(np.exp(diff * alpha) - alpha * diff - 1)
        else:
            return (1.0 / sample_weight.sum() *
                    np.sum(sample_weight *
                           (np.exp(diff * alpha) - alpha * diff - 1)))

    def negative_gradient(self, y, pred, **kargs):
        """Compute alpha * (exp(alpha * (y - pred)) - 1), scaled by
        grad_factor."""
        alpha = 0.1
        pred = pred.ravel()
        grad = alpha * (np.exp(alpha * (y - pred)) - 1)
        return self.grad_factor * grad

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
                                residual, pred, sample_weight):
        alpha = 0.1
        terminal_region = np.where(terminal_regions == leaf)[0]
        diff = (y.take(terminal_region, axis=0) -
                pred.take(terminal_region, axis=0))
        if sample_weight is None:
            tree.value[leaf, 0, 0] = (1 / alpha) * np.log(
                np.mean(np.exp(diff * alpha)))
        else:
            # Take the weights only after the None check, and normalize by
            # the weight sum.
            sample_weight = sample_weight.take(terminal_region, axis=0)
            tree.value[leaf, 0, 0] = (1 / alpha) * np.log(
                np.average(np.exp(diff * alpha), weights=sample_weight))
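As a sanity check on `negative_gradient` above: for L = exp(alpha * (y - p)) - alpha * (y - p) - 1 we have -dL/dp = alpha * (exp(alpha * (y - p)) - 1), which is exactly what the method returns before the `grad_factor` scaling. A quick finite-difference verification (standalone sketch, not part of the diff):

```python
# Finite-difference check of the Linex negative gradient used above
# (standalone sketch; alpha = 0.1 as hard-coded in the loss).
import numpy as np

alpha, eps = 0.1, 1e-6
y = np.array([1.0, -2.0, 0.5])
pred = np.array([0.3, 0.1, -0.4])

def linex(y, pred):
    d = y - pred
    return np.exp(alpha * d) - alpha * d - 1

analytic = alpha * (np.exp(alpha * (y - pred)) - 1)
numeric = -(linex(y, pred + eps) - linex(y, pred - eps)) / (2 * eps)
assert np.allclose(analytic, numeric, atol=1e-8)
```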


class QuantileLossFunction(RegressionLossFunction):
    """Loss function for quantile regression.
@@ -653,6 +705,7 @@ def _score_to_decision(self, score):
    'lad': LeastAbsoluteError,
    'huber': HuberLossFunction,
    'quantile': QuantileLossFunction,
    'linex': LinexLossFunction,  # TODO REMOVE
    'deviance': None,  # for both, multinomial and binomial
    'exponential': ExponentialLoss,
}
@@ -724,7 +777,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
                 min_samples_split, min_samples_leaf, min_weight_fraction_leaf,
                 max_depth, min_impurity_split, init, subsample, max_features,
                 random_state, alpha=0.9, verbose=0, max_leaf_nodes=None,
-                warm_start=False, presort='auto'):
+                warm_start=False, presort='auto', grad_factor=100):

        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
@@ -744,6 +797,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
        self.max_leaf_nodes = max_leaf_nodes
        self.warm_start = warm_start
        self.presort = presort
        self.grad_factor = grad_factor  ## TODO REMOVE

        self.estimators_ = np.empty((0, 0), dtype=np.object)

@@ -762,6 +816,7 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask,
            residual = loss.negative_gradient(y, y_pred, k=k,
                                              sample_weight=sample_weight)

            max_depth = self.max_depth
            # induce regression tree on residuals
            tree = DecisionTreeRegressor(
                criterion=self.criterion,
@@ -824,6 +879,8 @@ def _check_params(self):

        if self.loss in ('huber', 'quantile'):
            self.loss_ = loss_class(self.n_classes_, self.alpha)
        elif self.loss == 'linex':
            self.loss_ = loss_class(self.n_classes_, self.grad_factor)
        else:
            self.loss_ = loss_class(self.n_classes_)

@@ -1831,14 +1888,14 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
    Elements of Statistical Learning Ed. 2, Springer, 2009.
    """

-    _SUPPORTED_LOSS = ('ls', 'lad', 'huber', 'quantile')
+    _SUPPORTED_LOSS = ('ls', 'lad', 'huber', 'quantile', 'linex')

    def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
                 subsample=1.0, criterion='friedman_mse', min_samples_split=2,
                 min_samples_leaf=1, min_weight_fraction_leaf=0.,
                 max_depth=3, min_impurity_split=1e-7, init=None, random_state=None,
                 max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
-                warm_start=False, presort='auto'):
+                warm_start=False, presort='auto', grad_factor=100):  ## TODO REMOVE

        super(GradientBoostingRegressor, self).__init__(
            loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
@@ -1849,7 +1906,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
            max_features=max_features, min_impurity_split=min_impurity_split,
            random_state=random_state, alpha=alpha, verbose=verbose,
            max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
-            presort=presort)
+            presort=presort, grad_factor=grad_factor)  ## TODO REMOVE

    def predict(self, X):
        """Predict regression target for X.
9 changes: 8 additions & 1 deletion sklearn/tree/_criterion.pxd
@@ -41,14 +41,21 @@ cdef class Criterion:
    cdef double weighted_n_left    # Weighted number of samples in the left node
    cdef double weighted_n_right   # Weighted number of samples in the right node


    cdef double* sum_total     # For classification criteria, the sum of the
                               # weighted count of each label. For regression,
                               # the sum of w*y. sum_total[k] is equal to
                               # sum_{i=start}^{end-1} w[samples[i]]*y[samples[i], k],
                               # where k is the output index.
    cdef double* sum_left      # Same as above, but for the left side of the split
    cdef double* sum_right     # Same as above, but for the right side of the split

    # For our RegressionCriterionAuxiliaryForLinex, we need these arrays
    # to compute the impurity
    cdef double* exp_sum_total
    cdef double* exp_sum_left
    cdef double* exp_sum_right
    cdef DOUBLE_t max_value

    # The criterion object is maintained such that left and right collected
    # statistics correspond to samples[start:pos] and samples[pos:end].
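The matching _criterion.pyx changes are not shown in this view, but the declared running sums are sufficient to evaluate the Linex split criterion in constant time per candidate split. A plain-Python sketch of how such statistics could be combined (an assumption about the unseen implementation, not a transcription of it; max_value is presumably an offset used to keep exp() numerically stable):

```python
# Assumed usage of the per-node sums declared above (not the actual .pyx
# code): with weighted_n, sum_wy = sum_i w_i * y_i and
# exp_sum = sum_i w_i * exp(alpha * y_i), the Linex-optimal node value and
# the node's mean loss at that value follow in O(1).
import numpy as np

def linex_node_stats(sum_wy, exp_sum, weighted_n, alpha=0.1):
    value = (1.0 / alpha) * np.log(exp_sum / weighted_n)  # optimal constant
    mean_y = sum_wy / weighted_n
    # Plugging the optimum back into the loss, the mean loss collapses to
    # alpha * (value - mean_y), which is non-negative by Jensen's inequality.
    impurity = alpha * (value - mean_y)
    return value, impurity
```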

