diff --git a/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto b/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto
index b89e4969..a1eed710 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto
+++ b/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto
@@ -234,22 +234,32 @@ message DecisionTreeTrainingConfig {
       BinaryWeights binary = 7;
       ContinuousWeights continuous = 8;
       PowerOfTwoWeights power_of_two = 9;
+      IntegerWeights integer = 10;
     }
 
-    // Weights sample in {-1, 1} (default in "Sparse Projection Oblique Random
+    // Weights sampled in {-1, 1} (default in "Sparse Projection Oblique Random
    // Forests" (Tomita et al, 2020))).
     message BinaryWeights {}
 
-    // Weights sample in [-1, 1]. Consistently gives better quality models than
+    // Weights sampled in [-1, 1]. Consistently gives better quality models than
     // binary weights.
     message ContinuousWeights {}
 
-    // Weights sample in powers of two.
+    // Weights sampled uniformly in the exponent space, i.e. the weights are of
+    // the form $s * 2^i$ with the integer exponent $i$ sampled uniformly in
+    // [min_exponent, max_exponent] and the sign $s$ sampled uniformly in {-1,
+    // 1}.
     message PowerOfTwoWeights {
       optional int32 min_exponent = 1 [default = -3];
       optional int32 max_exponent = 2 [default = 3];
     }
+
+    // Weights sampled uniformly in the integer range [minimum, maximum].
+    message IntegerWeights {
+      optional int32 minimum = 1 [default = -5];
+      optional int32 maximum = 2 [default = 5];
+    }
 
     // Normalization applied on the features, before applying the sparse oblique
     // projections.
     optional Normalization normalization = 5 [default = NONE];
diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc
index 7fc8ddba..94f370e7 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc
+++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc
@@ -316,6 +316,8 @@ Increasing this value increases training and inference time (on average). This v
         kHParamSplitAxisSparseObliqueWeightsContinuous);
     param->mutable_categorical()->add_possible_values(
         kHParamSplitAxisSparseObliqueWeightsPowerOfTwo);
+    param->mutable_categorical()->add_possible_values(
+        kHParamSplitAxisSparseObliqueWeightsInteger);
     param->mutable_conditional()->set_control_field(kHParamSplitAxis);
     param->mutable_conditional()->mutable_categorical()->add_values(
         kHParamSplitAxisSparseOblique);
@@ -328,9 +330,12 @@
 Possible values:
 - `BINARY`: The oblique weights are sampled in {-1,1} (default).
 - `CONTINUOUS`: The oblique weights are be sampled in [-1,1].
 - `POWER_OF_TWO`: The oblique weights are powers of two. The exponents are sampled
-  uniformly in [$0, $1], the sign is uniformly sampled)",
+  uniformly in [$0, $1], the sign is uniformly sampled.
+- `INTEGER`: The weights are integers sampled uniformly from the range [$2, $3].)",
         kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
-        kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent));
+        kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+        kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+        kHParamSplitAxisSparseObliqueWeightsIntegerMaximum));
   }
 
   {
@@ -363,6 +368,36 @@ power-of-two weights i.e. `sparse_oblique_weights=POWER_OF_TWO`. Minimum exponent
 power-of-two weights i.e. `sparse_oblique_weights=POWER_OF_TWO`.
 Maximum exponent of the weights)");
   }
+  {
+    ASSIGN_OR_RETURN(
+        auto param,
+        get_params(kHParamSplitAxisSparseObliqueWeightsIntegerMinimum));
+    param->mutable_integer()->set_default_value(
+        config.sparse_oblique_split().integer().minimum());
+    param->mutable_conditional()->set_control_field(
+        kHParamSplitAxisSparseObliqueWeights);
+    param->mutable_conditional()->mutable_categorical()->add_values(
+        kHParamSplitAxisSparseObliqueWeightsInteger);
+    param->mutable_documentation()->set_description(
+        R"(For sparse oblique splits i.e. `split_axis=SPARSE_OBLIQUE` with
+integer weights i.e. `sparse_oblique_weights=INTEGER`. Minimum value of the weights.)");
+  }
+
+  {
+    ASSIGN_OR_RETURN(
+        auto param,
+        get_params(kHParamSplitAxisSparseObliqueWeightsIntegerMaximum));
+    param->mutable_integer()->set_default_value(
+        config.sparse_oblique_split().integer().maximum());
+    param->mutable_conditional()->set_control_field(
+        kHParamSplitAxisSparseObliqueWeights);
+    param->mutable_conditional()->mutable_categorical()->add_values(
+        kHParamSplitAxisSparseObliqueWeightsInteger);
+    param->mutable_documentation()->set_description(
+        R"(For sparse oblique splits i.e. `split_axis=SPARSE_OBLIQUE` with
+integer weights i.e. `sparse_oblique_weights=INTEGER`. Maximum value of the weights.)");
+  }
+
   {
     ASSIGN_OR_RETURN(
         auto param, get_params(kHParamSplitAxisSparseObliqueMaxNumProjections));
@@ -788,13 +823,16 @@ absl::Status SetHyperParameters(
       dt_config->mutable_sparse_oblique_split()->mutable_continuous();
     } else if (value == kHParamSplitAxisSparseObliqueWeightsPowerOfTwo) {
       dt_config->mutable_sparse_oblique_split()->mutable_power_of_two();
+    } else if (value == kHParamSplitAxisSparseObliqueWeightsInteger) {
+      dt_config->mutable_sparse_oblique_split()->mutable_integer();
     } else {
       return absl::InvalidArgumentError(absl::StrCat(
           "Unknown value for parameter ", kHParamSplitAxisSparseObliqueWeights,
           ". Possible values are: ", kHParamSplitAxisSparseObliqueWeightsBinary,
           ", ",
-          kHParamSplitAxisSparseObliqueWeightsContinuous, " and ",
-          kHParamSplitAxisSparseObliqueWeightsPowerOfTwo, "."));
+          kHParamSplitAxisSparseObliqueWeightsContinuous, ", ",
+          kHParamSplitAxisSparseObliqueWeightsPowerOfTwo, " and ",
+          kHParamSplitAxisSparseObliqueWeightsInteger, "."));
     }
   } else {
     return absl::InvalidArgumentError(
@@ -845,6 +883,46 @@ absl::Status SetHyperParameters(
     }
   }
 
+  {
+    const auto hparam = generic_hyper_params->Get(
+        kHParamSplitAxisSparseObliqueWeightsIntegerMinimum);
+    if (hparam.has_value()) {
+      const auto hparam_value = hparam.value().value().integer();
+      if (dt_config->has_sparse_oblique_split() &&
+          dt_config->sparse_oblique_split().has_integer()) {
+        dt_config->mutable_sparse_oblique_split()
+            ->mutable_integer()
+            ->set_minimum(hparam_value);
+      } else {
+        return absl::InvalidArgumentError(absl::StrCat(
+            kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+            " only works with sparse oblique trees "
+            "(`split_axis=SPARSE_OBLIQUE`) and integer weights "
+            "(`sparse_oblique_weights=INTEGER`)"));
+      }
+    }
+  }
+
+  {
+    const auto hparam = generic_hyper_params->Get(
+        kHParamSplitAxisSparseObliqueWeightsIntegerMaximum);
+    if (hparam.has_value()) {
+      const auto hparam_value = hparam.value().value().integer();
+      if (dt_config->has_sparse_oblique_split() &&
+          dt_config->sparse_oblique_split().has_integer()) {
+        dt_config->mutable_sparse_oblique_split()
+            ->mutable_integer()
+            ->set_maximum(hparam_value);
+      } else {
+        return absl::InvalidArgumentError(absl::StrCat(
+            kHParamSplitAxisSparseObliqueWeightsIntegerMaximum,
+            " only works with sparse oblique trees "
+            "(`split_axis=SPARSE_OBLIQUE`) and integer weights "
+            "(`sparse_oblique_weights=INTEGER`)"));
+      }
+    }
+  }
+
   // Mhld oblique trees
   {
     const auto hparam =
diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h
index d3ffc5f0..674b45ea 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h
+++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h
@@ -80,6 +80,11 @@ constexpr char kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent[] =
     "sparse_oblique_weights_power_of_two_min_exponent";
 constexpr char kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent[] =
     "sparse_oblique_weights_power_of_two_max_exponent";
+constexpr char kHParamSplitAxisSparseObliqueWeightsInteger[] = "INTEGER";
+constexpr char kHParamSplitAxisSparseObliqueWeightsIntegerMinimum[] =
+    "sparse_oblique_weights_integer_minimum";
+constexpr char kHParamSplitAxisSparseObliqueWeightsIntegerMaximum[] =
+    "sparse_oblique_weights_integer_maximum";
 constexpr char kHParamSplitAxisSparseObliqueNormalization[] =
     "sparse_oblique_normalization";
diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc
index 47c4b5ba..2ccc921c 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc
+++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc
@@ -71,7 +71,9 @@ TEST(GenericParameters, GiveValidAndInvalidHyperparameters) {
       kHParamHonestRatioLeafExamples,
       kHParamHonestFixedSeparation,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
-      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
+      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
   EXPECT_OK(GetGenericHyperParameterSpecification(
       config, &hparam_def, valid_hyperparameters, invalid_hyperparameters));
   EXPECT_EQ(hparam_def.fields().size(), 2);
@@ -114,7 +116,9 @@ TEST(GenericParameters, MissingValidHyperparameters) {
       kHParamHonestRatioLeafExamples,
       kHParamHonestFixedSeparation,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
-      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
+      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
   absl::Status status = GetGenericHyperParameterSpecification(
       config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
   EXPECT_THAT(status, test::StatusIs(absl::StatusCode::kInternal,
@@ -155,7 +159,9 @@ TEST(GenericParameters, MissingInvalidHyperparameters) {
       kHParamHonestRatioLeafExamples,
       kHParamHonestFixedSeparation,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
-      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
+      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
   absl::Status status = GetGenericHyperParameterSpecification(
       config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
   EXPECT_THAT(status, test::StatusIs(
@@ -197,7 +203,9 @@ TEST(GenericParameters, UnknownValidHyperparameter) {
       kHParamHonestRatioLeafExamples,
       kHParamHonestFixedSeparation,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
-      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
+      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
   absl::Status status = GetGenericHyperParameterSpecification(
       config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
   EXPECT_THAT(
@@ -240,6 +248,8 @@ TEST(GenericParameters, UnknownInvalidHyperparameter) {
       kHParamHonestFixedSeparation,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMaximum,
       "does_not_exist_invalid"};
   absl::Status status = GetGenericHyperParameterSpecification(
       config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
@@ -283,7 +293,9 @@ TEST(GenericParameters, ExistingHyperparameter) {
       kHParamHonestRatioLeafExamples,
       kHParamHonestFixedSeparation,
       kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
-      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent};
+      kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      kHParamSplitAxisSparseObliqueWeightsIntegerMaximum};
   absl::Status status = GetGenericHyperParameterSpecification(
       config, &hparam_def, valid_hyperparameters, invalid_hyperparameters);
   EXPECT_THAT(
diff --git a/yggdrasil_decision_forests/learner/decision_tree/oblique.cc b/yggdrasil_decision_forests/learner/decision_tree/oblique.cc
index 068d3c83..363f5c13 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/oblique.cc
+++ b/yggdrasil_decision_forests/learner/decision_tree/oblique.cc
@@ -678,16 +678,33 @@ void SampleProjection(const absl::Span<const int>& features,
 
   const auto gen_weight = [&](const int feature) -> float {
     float weight = unif1m1(*random);
-    if (oblique_config.has_binary() ||
-        oblique_config.weights_case() == oblique_config.WEIGHTS_NOT_SET) {
-      weight = (weight >= 0) ? 1.f : -1.f;
-    } else if (oblique_config.has_power_of_two()) {
-      float sign = (weight >= 0) ? 1.f : -1.f;
-      int exponent =
-          absl::Uniform(absl::IntervalClosed, *random,
-                        oblique_config.power_of_two().min_exponent(),
-                        oblique_config.power_of_two().max_exponent());
-      weight = sign * std::pow(2, exponent);
+    switch (oblique_config.weights_case()) {
+      case proto::DecisionTreeTrainingConfig::SparseObliqueSplit::WeightsCase::
+          kBinary: {
+        weight = (weight >= 0) ? 1.f : -1.f;
+        break;
+      }
+      case proto::DecisionTreeTrainingConfig::SparseObliqueSplit::WeightsCase::
+          kPowerOfTwo: {
+        float sign = (weight >= 0) ? 1.f : -1.f;
+        int exponent =
+            absl::Uniform(absl::IntervalClosed, *random,
+                          oblique_config.power_of_two().min_exponent(),
+                          oblique_config.power_of_two().max_exponent());
+        weight = sign * std::pow(2, exponent);
+        break;
+      }
+      case proto::DecisionTreeTrainingConfig::SparseObliqueSplit::WeightsCase::
+          kInteger: {
+        weight = absl::Uniform(absl::IntervalClosed, *random,
+                               oblique_config.integer().minimum(),
+                               oblique_config.integer().maximum());
+        break;
+      }
+      default: {
+        // Continuous weights: keep the weight drawn uniformly from [-1, 1].
+        break;
+      }
     }
 
     if (config_link.per_columns_size() > 0 &&
@@ -698,8 +715,8 @@ void SampleProjection(const absl::Span<const int>& features,
       if (direction_increasing == (weight < 0)) {
         weight = -weight;
       }
-      // As soon as one selected feature is monotonic, the oblique split becomes
-      // monotonic.
+      // As soon as one selected feature is monotonic, the oblique split
+      // becomes monotonic.
       *monotonic_direction = 1;
     }
 
diff --git a/yggdrasil_decision_forests/learner/decision_tree/training.cc b/yggdrasil_decision_forests/learner/decision_tree/training.cc
index 59d6882c..25b3392c 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/training.cc
+++ b/yggdrasil_decision_forests/learner/decision_tree/training.cc
@@ -3837,6 +3837,8 @@ void SetDefaultHyperParameters(proto::DecisionTreeTrainingConfig* config) {
 
   config->mutable_internal()->set_sorting_strategy(sorting_strategy);
 
+  // The binary_weight hyperparameter is deprecated in favor of the more
+  // general weights hyperparameter.
   if (config->sparse_oblique_split().has_binary_weight()) {
     if (config->sparse_oblique_split().binary_weight()) {
       config->mutable_sparse_oblique_split()->mutable_binary();
@@ -3845,6 +3847,14 @@ void SetDefaultHyperParameters(proto::DecisionTreeTrainingConfig* config) {
     }
     config->mutable_sparse_oblique_split()->clear_binary_weight();
   }
+
+  // By default, we use binary weights.
+  if (config->has_sparse_oblique_split() &&
+      config->sparse_oblique_split().weights_case() ==
+          proto::DecisionTreeTrainingConfig::SparseObliqueSplit::
+              WEIGHTS_NOT_SET) {
+    config->mutable_sparse_oblique_split()->mutable_binary();
+  }
 }
 
 template
@@ -4102,6 +4112,14 @@ absl::Status DecisionTreeTrain(
         dt_config.sparse_oblique_split().power_of_two().min_exponent(),
         dt_config.sparse_oblique_split().power_of_two().max_exponent()));
   }
+  if (dt_config.sparse_oblique_split().integer().minimum() >
+      dt_config.sparse_oblique_split().integer().maximum()) {
+    return absl::InvalidArgumentError(absl::Substitute(
+        "The minimum value for sparse oblique integer weights cannot "
+        "be larger than the maximum value. Got minimum: $0, maximum: $1",
+        dt_config.sparse_oblique_split().integer().minimum(),
+        dt_config.sparse_oblique_split().integer().maximum()));
+  }
   }
 
   if (dt_config.has_honest()) {
diff --git a/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc b/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc
index 08cbe232..9278b7f2 100644
--- a/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc
+++ b/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc
@@ -763,6 +763,8 @@ IsolationForestLearner::GetGenericHyperParameterSpecification() const {
       decision_tree::kHParamSplitAxisSparseObliqueWeights,
       decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
       decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
+      decision_tree::kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
+      decision_tree::kHParamSplitAxisSparseObliqueWeightsIntegerMaximum,
   };
   // Remove not yet implemented hyperparameters
   // TODO: b/345425508 - Implement more hyperparameters for isolation forests.
diff --git a/yggdrasil_decision_forests/port/python/CHANGELOG.md b/yggdrasil_decision_forests/port/python/CHANGELOG.md
index 77d1840b..ce73b449 100644
--- a/yggdrasil_decision_forests/port/python/CHANGELOG.md
+++ b/yggdrasil_decision_forests/port/python/CHANGELOG.md
@@ -21,7 +21,7 @@
   learner constructor argument. See the
   [feature selection tutorial]() for more details.
 - Add standalone prediction evaluation `ydf.evaluate_predictions()`.
-- Add option "POWER_OF_TWO" for sparse oblique weights.
+- Add options "POWER_OF_TWO" and "INTEGER" for sparse oblique weights.
 
 ### Fix
diff --git a/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py b/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
index 87f18cf6..85a15d2e 100644
--- a/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
+++ b/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
@@ -1122,6 +1122,28 @@ def test_oblique_weights_power_of_two(self):
     acceptable_weights_2 = [x * 2**y for x in (1.0, -1.0) for y in range(4, 8)]
     self.assertTrue(all(x in acceptable_weights_2 for x in root_weights_2))
 
+  def test_oblique_weights_integer(self):
+    learner = specialized_learners.CartLearner(
+        label="label",
+        max_depth=2,
+        split_axis="SPARSE_OBLIQUE",
+        sparse_oblique_weights="INTEGER",
+    )
+    f1 = np.linspace(-1, 1, 200) ** 2
+    f2 = np.linspace(1.5, -0.5, 200) ** 2
+    label = (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)
+    ds = {"f1": f1, "f2": f2, "label": label}
+    model = learner.train(ds)
+    root_weights = model.get_tree(0).root.condition.weights
+    acceptable_weights = [x * y for x in (1.0, -1.0) for y in range(0, 6)]
+    self.assertTrue(all(x in acceptable_weights for x in root_weights))
+    learner.hyperparameters["sparse_oblique_weights_integer_minimum"] = 7
+    learner.hyperparameters["sparse_oblique_weights_integer_maximum"] = 14
+    model_2 = learner.train(ds)
+    root_weights_2 = model_2.get_tree(0).root.condition.weights
+    acceptable_weights_2 = [x * y for x in (1.0, -1.0) for y in range(7, 15)]
+    self.assertTrue(all(x in acceptable_weights_2 for x in root_weights_2))
+
 
 class GradientBoostedTreesLearnerTest(LearnerTest):
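
Usage sketch (not part of the patch): the snippet below shows how the new INTEGER oblique weights can be driven from the Python API, mirroring `test_oblique_weights_integer` above. It assumes the published `ydf` Python package (the same generated `CartLearner` the test imports via `specialized_learners`); the dataset is synthetic placeholder data.

import numpy as np
import ydf  # Yggdrasil Decision Forests, Python port.

# Synthetic dataset: two quadratic features and a thresholded linear mix.
f1 = np.linspace(-1, 1, 200) ** 2
f2 = np.linspace(1.5, -0.5, 200) ** 2
ds = {"f1": f1, "f2": f2, "label": (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)}

# Train a small CART whose oblique projection coefficients are integers.
# With the proto defaults, each weight is drawn uniformly from {-5, ..., 5};
# a draw of 0 effectively drops that feature from the projection.
learner = ydf.CartLearner(
    label="label",
    max_depth=2,
    split_axis="SPARSE_OBLIQUE",
    sparse_oblique_weights="INTEGER",
)
# Narrow the sampling range to {-3, ..., 3} through the generic
# hyperparameters introduced by this change.
learner.hyperparameters["sparse_oblique_weights_integer_minimum"] = -3
learner.hyperparameters["sparse_oblique_weights_integer_maximum"] = 3

model = learner.train(ds)
# The root condition is an oblique projection with small integer weights,
# e.g. something of the shape "2 * f1 - 3 * f2 >= threshold".
print(model.get_tree(0).root.condition)

Setting the minimum above the maximum is rejected at training time by the new check in `DecisionTreeTrain`.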