[YDF] Oblique: Add integer weights
PiperOrigin-RevId: 701890429
rstz authored and copybara-github committed Dec 2, 2024
1 parent bd66459 commit fd34b13
Showing 9 changed files with 189 additions and 25 deletions.
@@ -234,22 +234,32 @@ message DecisionTreeTrainingConfig {
BinaryWeights binary = 7;
ContinuousWeights continuous = 8;
PowerOfTwoWeights power_of_two = 9;
IntegerWeights integer = 10;
}

// Weights sample in {-1, 1} (default in "Sparse Projection Oblique Random
// Weights sampled in {-1, 1} (default in "Sparse Projection Oblique Random
// Forests" (Tomita et al, 2020))).
message BinaryWeights {}

// Weights sample in [-1, 1]. Consistently gives better quality models than
// Weights sampled in [-1, 1]. Consistently gives better quality models than
// binary weights.
message ContinuousWeights {}

// Weights sample in powers of two.
// Weights sampled uniformly in the exponent space, i.e. the weights are of
// the form $s * 2^i$ with the integer exponent $i$ sampled uniformly in
// [min_exponent, max_exponent] and the sign $s$ sampled uniformly in {-1,
// 1}.
message PowerOfTwoWeights {
optional int32 min_exponent = 1 [default = -3];
optional int32 max_exponent = 2 [default = 3];
}

// Weights sampled uniformly in the integer range [minimum, maximum].
message IntegerWeights {
optional int32 minimum = 1 [default = -5];
optional int32 maximum = 2 [default = 5];
}

// Normalization applied on the features, before applying the sparse oblique
// projections.
optional Normalization normalization = 5 [default = NONE];
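
The proto comments above fully specify how each nonzero projection weight is drawn. A minimal NumPy sketch of the four sampling schemes (illustrative only — these function names are not part of YDF; the authoritative implementation is the `gen_weight` switch in `oblique.cc` further down):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_binary() -> float:
    # BinaryWeights: weights in {-1, 1}.
    return float(rng.choice([-1.0, 1.0]))

def sample_continuous() -> float:
    # ContinuousWeights: weights uniform in [-1, 1].
    return float(rng.uniform(-1.0, 1.0))

def sample_power_of_two(min_exponent: int = -3, max_exponent: int = 3) -> float:
    # PowerOfTwoWeights: s * 2^i with the integer exponent i uniform in
    # [min_exponent, max_exponent] (inclusive) and the sign s in {-1, 1}.
    sign = float(rng.choice([-1.0, 1.0]))
    exponent = int(rng.integers(min_exponent, max_exponent + 1))
    return sign * 2.0**exponent

def sample_integer(minimum: int = -5, maximum: int = 5) -> float:
    # IntegerWeights: integer weights uniform in [minimum, maximum], inclusive.
    return float(rng.integers(minimum, maximum + 1))
```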
@@ -316,6 +316,8 @@ Increasing this value increases training and inference time (on average). This v
kHParamSplitAxisSparseObliqueWeightsContinuous);
param->mutable_categorical()->add_possible_values(
kHParamSplitAxisSparseObliqueWeightsPowerOfTwo);
param->mutable_categorical()->add_possible_values(
kHParamSplitAxisSparseObliqueWeightsInteger);
param->mutable_conditional()->set_control_field(kHParamSplitAxis);
param->mutable_conditional()->mutable_categorical()->add_values(
kHParamSplitAxisSparseOblique);
@@ -328,9 +330,12 @@ Possible values:
- `BINARY`: The oblique weights are sampled in {-1,1} (default).
- `CONTINUOUS`: The oblique weights are sampled in [-1,1].
- `POWER_OF_TWO`: The oblique weights are powers of two. The exponents are sampled
uniformly in [$0, $1], the sign is uniformly sampled)",
uniformly in [$0, $1], the sign is uniformly sampled.
- `INTEGER`: The weights are integers sampled uniformly from the range [$2, $3].)",
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent));
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum));
}

{
@@ -363,6 +368,36 @@ power-of-two weights i.e. `sparse_oblique_weights=POWER_OF_TWO`. Minimum exponen
power-of-two weights i.e. `sparse_oblique_weights=POWER_OF_TWO`. Maximum exponent of the weights)");
}

{
ASSIGN_OR_RETURN(
auto param,
get_params(kHParamSplitAxisSparseObliqueWeightsIntegerMinimum));
param->mutable_integer()->set_default_value(
config.sparse_oblique_split().integer().minimum());
param->mutable_conditional()->set_control_field(
kHParamSplitAxisSparseObliqueWeights);
param->mutable_conditional()->mutable_categorical()->add_values(
kHParamSplitAxisSparseObliqueWeightsInteger);
param->mutable_documentation()->set_description(
R"(For sparse oblique splits i.e. `split_axis=SPARSE_OBLIQUE` with
integer weights i.e. `sparse_oblique_weights=INTEGER`. Minimum value of the weights.)");
}

{
ASSIGN_OR_RETURN(
auto param,
get_params(kHParamSplitAxisSparseObliqueWeightsIntegerMaximum));
param->mutable_integer()->set_default_value(
config.sparse_oblique_split().integer().maximum());
param->mutable_conditional()->set_control_field(
kHParamSplitAxisSparseObliqueWeights);
param->mutable_conditional()->mutable_categorical()->add_values(
kHParamSplitAxisSparseObliqueWeightsInteger);
param->mutable_documentation()->set_description(
R"(For sparse oblique splits i.e. `split_axis=SPARSE_OBLIQUE` with
integer weights i.e. `sparse_oblique_weights=INTEGER`. Maximum value of the weights.)");
}

{
ASSIGN_OR_RETURN(
auto param, get_params(kHParamSplitAxisSparseObliqueMaxNumProjections));
@@ -788,13 +823,16 @@ absl::Status SetHyperParameters(
dt_config->mutable_sparse_oblique_split()->mutable_continuous();
} else if (value == kHParamSplitAxisSparseObliqueWeightsPowerOfTwo) {
dt_config->mutable_sparse_oblique_split()->mutable_power_of_two();
} else if (value == kHParamSplitAxisSparseObliqueWeightsInteger) {
dt_config->mutable_sparse_oblique_split()->mutable_integer();
} else {
return absl::InvalidArgumentError(absl::StrCat(
"Unknown value for parameter ",
kHParamSplitAxisSparseObliqueWeights, ". Possible values are: ",
kHParamSplitAxisSparseObliqueWeightsBinary, ", ",
kHParamSplitAxisSparseObliqueWeightsContinuous, " and ",
kHParamSplitAxisSparseObliqueWeightsPowerOfTwo, "."));
kHParamSplitAxisSparseObliqueWeightsContinuous, ", ",
kHParamSplitAxisSparseObliqueWeightsPowerOfTwo, "and",
kHParamSplitAxisSparseObliqueWeightsInteger, "."));
}
} else {
return absl::InvalidArgumentError(
@@ -845,6 +883,46 @@ absl::Status SetHyperParameters(
}
}

{
const auto hparam = generic_hyper_params->Get(
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum);
if (hparam.has_value()) {
const auto hparam_value = hparam.value().value().integer();
if (dt_config->has_sparse_oblique_split() &&
dt_config->sparse_oblique_split().has_integer()) {
dt_config->mutable_sparse_oblique_split()
->mutable_integer()
->set_minimum(hparam_value);
} else {
return absl::InvalidArgumentError(absl::StrCat(
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
" only works with sparse oblique trees "
"(`split_axis=SPARSE_OBLIQUE`) and integer weights "
"(`sparse_oblique_weights=INTEGER`)"));
}
}
}

{
const auto hparam = generic_hyper_params->Get(
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum);
if (hparam.has_value()) {
const auto hparam_value = hparam.value().value().integer();
if (dt_config->has_sparse_oblique_split() &&
dt_config->sparse_oblique_split().has_integer()) {
dt_config->mutable_sparse_oblique_split()
->mutable_integer()
->set_maximum(hparam_value);
} else {
return absl::InvalidArgumentError(absl::StrCat(
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum,
" only works with sparse oblique trees "
"(`split_axis=SPARSE_OBLIQUE`) and integer weights "
"(`sparse_oblique_weights=INTEGER`)"));
}
}
}

// MHLD oblique trees
{
const auto hparam =
@@ -80,6 +80,11 @@ constexpr char kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent[] =
"sparse_oblique_weights_power_of_two_min_exponent";
constexpr char kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent[] =
"sparse_oblique_weights_power_of_two_max_exponent";
constexpr char kHParamSplitAxisSparseObliqueWeightsInteger[] = "INTEGER";
constexpr char kHParamSplitAxisSparseObliqueWeightsIntegerMinimum[] =
"sparse_oblique_weights_integer_minimum";
constexpr char kHParamSplitAxisSparseObliqueWeightsIntegerMaximum[] =
"sparse_oblique_weights_integer_maximum";

constexpr char kHParamSplitAxisSparseObliqueNormalization[] =
"sparse_oblique_normalization";
@@ -71,7 +71,9 @@ TEST(GenericParameters, GiveValidAndInvalidHyperparameters) {
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
EXPECT_OK(GetGenericHyperParameterSpecification(
config, &hparam_def, valid_hyperparameters, invalid_hyperparameters));
EXPECT_EQ(hparam_def.fields().size(), 2);
@@ -114,7 +116,9 @@ TEST(GenericParameters, MissingValidHyperparameters) {
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
absl::Status status = GetGenericHyperParameterSpecification(
config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
EXPECT_THAT(status, test::StatusIs(absl::StatusCode::kInternal,
@@ -155,7 +159,9 @@ TEST(GenericParameters, MissingInvalidHyperparameters) {
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
absl::Status status = GetGenericHyperParameterSpecification(
config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
EXPECT_THAT(status, test::StatusIs(
@@ -197,7 +203,9 @@ TEST(GenericParameters, UnknownValidHyperparameter) {
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
absl::Status status = GetGenericHyperParameterSpecification(
config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
EXPECT_THAT(
@@ -240,6 +248,8 @@ TEST(GenericParameters, UnknownInvalidHyperparameter) {
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum,
"does_not_exist_invalid"};
absl::Status status = GetGenericHyperParameterSpecification(
config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
@@ -283,7 +293,9 @@ TEST(GenericParameters, ExistingHyperparameter) {
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
absl::Status status = GetGenericHyperParameterSpecification(
config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
EXPECT_THAT(
41 changes: 29 additions & 12 deletions yggdrasil_decision_forests/learner/decision_tree/oblique.cc
@@ -678,16 +678,33 @@ void SampleProjection(const absl::Span<const int>& features,

const auto gen_weight = [&](const int feature) -> float {
float weight = unif1m1(*random);
if (oblique_config.has_binary() ||
oblique_config.weights_case() == oblique_config.WEIGHTS_NOT_SET) {
weight = (weight >= 0) ? 1.f : -1.f;
} else if (oblique_config.has_power_of_two()) {
float sign = (weight >= 0) ? 1.f : -1.f;
int exponent =
absl::Uniform<int>(absl::IntervalClosed, *random,
oblique_config.power_of_two().min_exponent(),
oblique_config.power_of_two().max_exponent());
weight = sign * std::pow(2, exponent);
switch (oblique_config.weights_case()) {
case (proto::DecisionTreeTrainingConfig::SparseObliqueSplit::WeightsCase::
kBinary): {
weight = (weight >= 0) ? 1.f : -1.f;
break;
}
case (proto::DecisionTreeTrainingConfig::SparseObliqueSplit::WeightsCase::
kPowerOfTwo): {
float sign = (weight >= 0) ? 1.f : -1.f;
int exponent =
absl::Uniform<int>(absl::IntervalClosed, *random,
oblique_config.power_of_two().min_exponent(),
oblique_config.power_of_two().max_exponent());
weight = sign * std::pow(2, exponent);
break;
}
case (proto::DecisionTreeTrainingConfig::SparseObliqueSplit::WeightsCase::
kInteger): {
weight = absl::Uniform(absl::IntervalClosed, *random,
oblique_config.integer().minimum(),
oblique_config.integer().maximum());
break;
}
default: {
// Continuous weights: keep the uniform draw from [-1, 1].
break;
}
}

if (config_link.per_columns_size() > 0 &&
@@ -698,8 +715,8 @@
if (direction_increasing == (weight < 0)) {
weight = -weight;
}
// As soon as one selected feature is monotonic, the oblique split becomes
// monotonic.
// As soon as one selected feature is monotonic, the oblique split
// becomes monotonic.
*monotonic_direction = 1;
}

18 changes: 18 additions & 0 deletions yggdrasil_decision_forests/learner/decision_tree/training.cc
@@ -3837,6 +3837,8 @@ void SetDefaultHyperParameters(proto::DecisionTreeTrainingConfig* config) {

config->mutable_internal()->set_sorting_strategy(sorting_strategy);

// The binary_weight hyperparameter is deprecated in favor of the more
// general weights hyperparameter.
if (config->sparse_oblique_split().has_binary_weight()) {
if (config->sparse_oblique_split().binary_weight()) {
config->mutable_sparse_oblique_split()->mutable_binary();
@@ -3845,6 +3847,14 @@ void SetDefaultHyperParameters(proto::DecisionTreeTrainingConfig* config) {
}
config->mutable_sparse_oblique_split()->clear_binary_weight();
}

// By default, we use binary weights.
if (config->has_sparse_oblique_split() &&
config->sparse_oblique_split().weights_case() ==
proto::DecisionTreeTrainingConfig::SparseObliqueSplit::
WEIGHTS_NOT_SET) {
config->mutable_sparse_oblique_split()->mutable_binary();
}
}

template <class T, class S, class C>
@@ -4102,6 +4112,14 @@ absl::Status DecisionTreeTrain(
dt_config.sparse_oblique_split().power_of_two().min_exponent(),
dt_config.sparse_oblique_split().power_of_two().max_exponent()));
}
if (dt_config.sparse_oblique_split().integer().minimum() >
dt_config.sparse_oblique_split().integer().maximum()) {
return absl::InvalidArgumentError(absl::Substitute(
"The minimum value for sparse oblique integer weights cannot "
"be larger than the maximum value. Got minimum: $0, maximum: $1",
dt_config.sparse_oblique_split().integer().minimum(),
dt_config.sparse_oblique_split().integer().maximum()));
}
}

if (dt_config.has_honest()) {
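
The integer-range check added to `DecisionTreeTrain` above rejects inverted bounds before training starts. A quick sketch of the failure mode, assuming the Python bindings surface the `absl::InvalidArgumentError` as a Python exception:

```python
import ydf

learner = ydf.CartLearner(
    label="label",
    split_axis="SPARSE_OBLIQUE",
    sparse_oblique_weights="INTEGER",
    sparse_oblique_weights_integer_minimum=10,
    sparse_oblique_weights_integer_maximum=5,  # maximum < minimum
)
# learner.train(ds) is expected to fail with a message like:
#   The minimum value for sparse oblique integer weights cannot be larger
#   than the maximum value. Got minimum: 10, maximum: 5
```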
@@ -763,6 +763,8 @@ IsolationForestLearner::GetGenericHyperParameterSpecification() const {
decision_tree::kHParamSplitAxisSparseObliqueWeights,
decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
decision_tree::kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
decision_tree::kHParamSplitAxisSparseObliqueWeightsIntegerMaximum,
};
// Remove not yet implemented hyperparameters
// TODO: b/345425508 - Implement more hyperparameters for isolation forests.
2 changes: 1 addition & 1 deletion yggdrasil_decision_forests/port/python/CHANGELOG.md
@@ -21,7 +21,7 @@
learner constructor argument. See the [feature selection tutorial]() for
more details.
- Add standalone prediction evaluation `ydf.evaluate_predictions()`.
- Add option "POWER_OF_TWO" for sparse oblique weights.
- Add options "POWER_OF_TWO" and "INTEGER" for sparse oblique weights.

### Fix

22 changes: 22 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
@@ -1122,6 +1122,28 @@ def test_oblique_weights_power_of_two(self):
acceptable_weights_2 = [x * 2**y for x in (1.0, -1.0) for y in range(4, 8)]
self.assertTrue(all(x in acceptable_weights_2 for x in root_weights_2))

def test_oblique_weights_integer(self):
learner = specialized_learners.CartLearner(
label="label",
max_depth=2,
split_axis="SPARSE_OBLIQUE",
sparse_oblique_weights="INTEGER",
)
f1 = np.linspace(-1, 1, 200) ** 2
f2 = np.linspace(1.5, -0.5, 200) ** 2
label = (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)
ds = {"f1": f1, "f2": f2, "label": label}
model = learner.train(ds)
root_weights = model.get_tree(0).root.condition.weights
acceptable_weights = [x * y for x in (1.0, -1.0) for y in range(0, 6)]
self.assertTrue(all(x in acceptable_weights for x in root_weights))
learner.hyperparameters["sparse_oblique_weights_integer_minimum"] = 7
learner.hyperparameters["sparse_oblique_weights_integer_maximum"] = 14
model_2 = learner.train(ds)
root_weights_2 = model_2.get_tree(0).root.condition.weights
acceptable_weights_2 = [x * y for x in (1.0, -1.0) for y in range(7, 15)]
self.assertTrue(all(x in acceptable_weights_2 for x in root_weights_2))


class GradientBoostedTreesLearnerTest(LearnerTest):
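
The new `test_oblique_weights_integer` above doubles as a usage example: every weight of the root oblique condition is an integer from the configured range, sign included. A condensed sketch against the public API, assuming `ydf.CartLearner` forwards the same hyperparameters as the internal `specialized_learners` module used in the test:

```python
import numpy as np
import ydf

# Same synthetic dataset as the test above.
f1 = np.linspace(-1, 1, 200) ** 2
f2 = np.linspace(1.5, -0.5, 200) ** 2
ds = {"f1": f1, "f2": f2, "label": (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)}

model = ydf.CartLearner(
    label="label",
    max_depth=2,
    split_axis="SPARSE_OBLIQUE",
    sparse_oblique_weights="INTEGER",
).train(ds)

# With the default range [-5, 5], every root weight is one of -5, ..., 5.
print(model.get_tree(0).root.condition.weights)
```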
