[YDF] Deprecate SparseObliqueSplit.binary_weights hyperparameter
In favor of `SparseObliqueSplit.weights`.
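
For reference, a minimal sketch of how the replacement surfaces through the YDF Python API (mirroring the tests added in this commit; the toy dataset is illustrative):

```python
import numpy as np
import ydf  # assumes the public `ydf` package, which exposes CartLearner

# Toy dataset, as in the tests added below.
f1 = np.linspace(-1, 1, 50) ** 2
f2 = np.linspace(1.5, -0.5, 50) ** 2
ds = {"f1": f1, "f2": f2, "label": (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)}

learner = ydf.CartLearner(
    label="label",
    split_axis="SPARSE_OBLIQUE",
    # Replaces the deprecated proto field `binary_weight`; "BINARY" remains
    # the default when this hyperparameter is not set.
    sparse_oblique_weights="CONTINUOUS",
)
model = learner.train(ds)
```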

PiperOrigin-RevId: 701854322
rstz authored and copybara-github committed Dec 2, 2024
1 parent e7616c5 commit e7c0eb7
Showing 6 changed files with 91 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ Changelog under `yggdrasil_decision_forests/port/python/CHANGELOG.md`.
- Add support for distributed training for ranking gradient boosted tree
models.
- Add support for AVRO data file using the "avro:" prefix.
- Deprecated `SparseObliqueSplit.binary_weights` hyperparameter in favor of
`SparseObliqueSplit.weights`.

### Misc

@@ -220,10 +220,28 @@ message DecisionTreeTrainingConfig {
// This value is best tuned for each dataset.
optional float projection_density_factor = 3 [default = 2];

// Deprecated, use `weights` instead.
//
// If true, the weight will be sampled in {-1,1} (default in "Sparse
// Projection Oblique Random Forests" (Tomita et al, 2020)). If false, the
// weight will be sampled in [-1,1].
optional bool binary_weight = 4 [default = true];
optional bool binary_weight = 4 [default = true, deprecated = true];

// Weights to apply to the projections.
//
// Continuous weights generally give better performance.
oneof weights {
BinaryWeights binary = 7;
ContinuousWeights continuous = 8;
}

// Weights sampled in {-1, 1} (default in "Sparse Projection Oblique Random
// Forests" (Tomita et al, 2020)).
message BinaryWeights {}

// Weights sampled in [-1, 1]. Consistently gives better quality models than
// binary weights.
message ContinuousWeights {}

// Normalization applied on the features, before applying the sparse oblique
// projections.
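
For illustration, a sketch of the proto-level migration (the `decision_tree_pb2` module path is an assumption based on the repository layout):

```python
from yggdrasil_decision_forests.learner.decision_tree import decision_tree_pb2

config = decision_tree_pb2.DecisionTreeTrainingConfig()

# Before (deprecated): config.sparse_oblique_split.binary_weight = False
# After: select exactly one member of the `weights` oneof.
config.sparse_oblique_split.continuous.SetInParent()
assert config.sparse_oblique_split.WhichOneof("weights") == "continuous"
```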
@@ -308,9 +308,7 @@ Increasing this value increases training and inference time (on average). This v
ASSIGN_OR_RETURN(auto param,
get_params(kHParamSplitAxisSparseObliqueWeights));
param->mutable_categorical()->set_default_value(
config.sparse_oblique_split().binary_weight()
? kHParamSplitAxisSparseObliqueWeightsBinary
: kHParamSplitAxisSparseObliqueWeightsContinuous);
kHParamSplitAxisSparseObliqueWeightsBinary);
param->mutable_categorical()->add_possible_values(
kHParamSplitAxisSparseObliqueWeightsBinary);
param->mutable_categorical()->add_possible_values(
@@ -744,9 +742,9 @@ absl::Status SetHyperParameters(
if (dt_config->has_sparse_oblique_split()) {
const auto& value = hparam.value().value().categorical();
if (value == kHParamSplitAxisSparseObliqueWeightsBinary) {
dt_config->mutable_sparse_oblique_split()->set_binary_weight(true);
dt_config->mutable_sparse_oblique_split()->mutable_binary();
} else if (value == kHParamSplitAxisSparseObliqueWeightsContinuous) {
dt_config->mutable_sparse_oblique_split()->set_binary_weight(false);
dt_config->mutable_sparse_oblique_split()->mutable_continuous();
} else {
return absl::InvalidArgumentError(absl::StrCat(
"Unknown value for parameter ",
4 changes: 3 additions & 1 deletion yggdrasil_decision_forests/learner/decision_tree/oblique.cc
@@ -676,7 +676,9 @@ void SampleProjection(const absl::Span<const int>& features,

const auto gen_weight = [&](const int feature) -> float {
float weight = unif1m1(*random);
if (dt_config.sparse_oblique_split().binary_weight()) {
if (dt_config.sparse_oblique_split().has_binary() ||
dt_config.sparse_oblique_split().weights_case() ==
dt_config.sparse_oblique_split().WEIGHTS_NOT_SET) {
weight = (weight >= 0) ? 1.f : -1.f;
}

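In other words, the weight draw reduces to the following (a Python paraphrase for illustration only; leaving the `weights` oneof unset preserves the historical binary behavior):

```python
import random

def gen_weight(weights_case: str, rng: random.Random) -> float:
  """Uniform draw in [-1, 1], binarized to {-1, 1} unless continuous."""
  weight = rng.uniform(-1.0, 1.0)
  if weights_case in ("binary", "not_set"):
    return 1.0 if weight >= 0 else -1.0
  return weight
```
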
20 changes: 20 additions & 0 deletions yggdrasil_decision_forests/learner/decision_tree/training.cc
@@ -3836,6 +3836,15 @@ void SetDefaultHyperParameters(proto::DecisionTreeTrainingConfig* config) {
}

config->mutable_internal()->set_sorting_strategy(sorting_strategy);

if (config->sparse_oblique_split().has_binary_weight()) {
if (config->sparse_oblique_split().binary_weight()) {
config->mutable_sparse_oblique_split()->mutable_binary();
} else {
config->mutable_sparse_oblique_split()->mutable_continuous();
}
config->mutable_sparse_oblique_split()->clear_binary_weight();
}
}

template <class T, class S, class C>
@@ -4064,6 +4073,17 @@ absl::Status DecisionTreeTrain(
"pure_serving_model=true.");
}

// Check if oblique splits are correctly specified
if (dt_config.sparse_oblique_split().has_binary_weight() &&
dt_config.sparse_oblique_split().weights_case() !=
dt_config.sparse_oblique_split().WEIGHTS_NOT_SET) {
return absl::InvalidArgumentError(
"Both sparse_oblique_split.binary_weights and "
"sparse_oblique_split.weights are set. Setting "
"sparse_oblique_split.binary_weights is deprecated and replaced by "
"just setting sparse_oblique_split.weights.");
}

if (dt_config.has_honest()) {
// Split the examples in two parts. One ("selected_examples_buffer") will be
// used to infer the structure of the trees while the second
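To make the failure mode concrete, a sketch of a configuration the new validation rejects (same assumed `decision_tree_pb2` module as above):

```python
config = decision_tree_pb2.DecisionTreeTrainingConfig()
config.sparse_oblique_split.binary_weight = True      # deprecated field
config.sparse_oblique_split.continuous.SetInParent()  # new oneof
# DecisionTreeTrain() now rejects this config with InvalidArgumentError,
# since the deprecated field and the new oneof are both set.
```
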
44 changes: 44 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
@@ -1052,6 +1052,50 @@ def test_tuner_predefined(self):
    self.assertIsNotNone(logs)
    self.assertLen(logs.trials, 5)

  def test_oblique_weights_default(self):
    learner = specialized_learners.CartLearner(
        label="label",
        max_depth=2,
        split_axis="SPARSE_OBLIQUE",
    )
    f1 = np.linspace(-1, 1, 50) ** 2
    f2 = np.linspace(1.5, -0.5, 50) ** 2
    label = (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)
    ds = {"f1": f1, "f2": f2, "label": label}
    model = learner.train(ds)
    root_weights = model.get_tree(0).root.condition.weights
    # Binary weights are the default: all projection weights are in {-1, 1}.
    self.assertTrue(all(x in (-1.0, 1.0) for x in root_weights))

  def test_oblique_weights_binary(self):
    learner = specialized_learners.CartLearner(
        label="label",
        max_depth=2,
        split_axis="SPARSE_OBLIQUE",
        sparse_oblique_weights="BINARY",
    )
    f1 = np.linspace(-1, 1, 50) ** 2
    f2 = np.linspace(1.5, -0.5, 50) ** 2
    label = (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)
    ds = {"f1": f1, "f2": f2, "label": label}
    model = learner.train(ds)
    root_weights = model.get_tree(0).root.condition.weights
    self.assertTrue(all(x in (-1.0, 1.0) for x in root_weights))

  def test_oblique_weights_continuous(self):
    learner = specialized_learners.CartLearner(
        label="label",
        max_depth=2,
        split_axis="SPARSE_OBLIQUE",
        sparse_oblique_weights="CONTINUOUS",
    )
    f1 = np.linspace(-1, 1, 50) ** 2
    f2 = np.linspace(1.5, -0.5, 50) ** 2
    label = (0.2 * f1 + 0.7 * f2 >= 0.25).astype(int)
    ds = {"f1": f1, "f2": f2, "label": label}
    model = learner.train(ds)
    root_weights = model.get_tree(0).root.condition.weights
    # Continuous weights fall strictly inside (-1, 1) with probability 1,
    # so not all of them can be exactly -1 or 1.
    self.assertFalse(all(x in (-1.0, 1.0) for x in root_weights))


class GradientBoostedTreesLearnerTest(LearnerTest):
