Skip to content

Commit

Permalink
Feature selection (part 2)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699129506
  • Loading branch information
achoum authored and copybara-github committed Nov 22, 2024
1 parent 2b71b90 commit 075e567
Show file tree
Hide file tree
Showing 17 changed files with 406 additions and 7 deletions.
7 changes: 7 additions & 0 deletions documentation/public/docs/glossary.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,10 @@ The **default NDCG** is computed by averaging the gain over all the examples.
See section 3 of
[From RankNet to LambdaRank to LambdaMART: An Overview](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf)
for more details.

## Feature selection

Feature selection algorithms identify and remove unnecessary input features,
improving model quality and speeding up subsequent training. See the
[Wikipedia article](https://en.wikipedia.org/wiki/Feature_selection) for more
details.
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ cc_library_ydf(
deps = [
":abstract_model",
":hyperparameter_cc_proto",
"//yggdrasil_decision_forests/dataset:data_spec",
"//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
"//yggdrasil_decision_forests/metric:report",
"//yggdrasil_decision_forests/model/decision_tree",
Expand All @@ -156,10 +157,13 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:html_content",
"//yggdrasil_decision_forests/utils:plot",
"//yggdrasil_decision_forests/utils:protobuf",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)

Expand Down
37 changes: 37 additions & 0 deletions yggdrasil_decision_forests/model/abstract_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ void AbstractModel::ExportProto(const AbstractModel& model,
*proto->mutable_hyperparameter_optimizer_logs() =
model.hyperparameter_optimizer_logs_.value();
}

if (model.feature_selection_logs_.has_value()) {
*proto->mutable_feature_selection_logs() =
model.feature_selection_logs_.value();
}
}

void AbstractModel::ImportProto(const proto::AbstractModel& proto,
Expand Down Expand Up @@ -143,6 +148,9 @@ void AbstractModel::ImportProto(const proto::AbstractModel& proto,
model->hyperparameter_optimizer_logs_ =
proto.hyperparameter_optimizer_logs();
}
if (proto.has_feature_selection_logs()) {
model->feature_selection_logs_ = proto.feature_selection_logs();
}
}

metric::proto::EvaluationResults AbstractModel::Evaluate(
Expand Down Expand Up @@ -921,6 +929,28 @@ void AbstractModel::AppendDescriptionAndStatistics(
if (hyperparameter_optimizer_logs_.has_value()) {
AppendHyperparameterOptimizerLogs(description);
}

if (feature_selection_logs_.has_value()) {
AppendFeatureSelectionLogs(description);
}
}

// Appends a human-readable summary of the feature selection logs to
// "description".
//
// After a "Feature selection logs:" header, one entry is written per logged
// iteration: the iteration index, its score, the comma-separated list of
// features, and the metrics as space-separated "name:value" pairs.
//
// Does nothing if the model has no feature selection logs.
void AbstractModel::AppendFeatureSelectionLogs(std::string* description) const {
  // Dereferencing an empty optional is undefined behavior; bail out instead of
  // relying on every caller to check first.
  if (!feature_selection_logs_.has_value()) {
    return;
  }

  absl::StrAppend(description, "Feature selection logs:\n\n");
  for (int iteration_idx = 0;
       iteration_idx < feature_selection_logs_->iterations_size();
       iteration_idx++) {
    const auto& iteration =
        feature_selection_logs_->iterations()[iteration_idx];
    absl::StrAppendFormat(
        description,
        "Iteration:%d Score:%g\n\tFeatures: %s\n\tMetrics:", iteration_idx,
        iteration.score(), absl::StrJoin(iteration.features(), ","));
    // NOTE: "metrics" is a proto map; its iteration order is not guaranteed, so
    // the order of the pairs below may differ between runs.
    for (const auto& metric : iteration.metrics()) {
      // The leading space separates consecutive "name:value" pairs; without it
      // they run together (e.g. "accuracy:0.9loss:0.1").
      absl::StrAppendFormat(description, " %s:%g", metric.first, metric.second);
    }
    absl::StrAppend(description, "\n");
  }
}

void AbstractModel::AppendHyperparameterOptimizerLogs(
Expand Down Expand Up @@ -1246,6 +1276,12 @@ void AbstractModel::CopyAbstractModelMetaData(AbstractModel* dst) const {
} else {
dst->hyperparameter_optimizer_logs_ = {};
}

if (feature_selection_logs_.has_value()) {
dst->feature_selection_logs_ = feature_selection_logs_;
} else {
dst->feature_selection_logs_ = {};
}
}

absl::Status AbstractModel::Validate() const {
Expand Down Expand Up @@ -1438,6 +1474,7 @@ absl::Status AbstractModel::MakePureServing() {
is_pure_model_ = true;
precomputed_variable_importances_.clear();
hyperparameter_optimizer_logs_ = {};
feature_selection_logs_ = {};
return Validate();
}

Expand Down
14 changes: 14 additions & 0 deletions yggdrasil_decision_forests/model/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,15 @@ class AbstractModel {
return &hyperparameter_optimizer_logs_;
}

// Feature selection logs.
//
// Read-only access to the logs stored in the model. The optional is empty when
// the model carries no feature selection logs (for instance, the logs are
// reset by "MakePureServing").
const std::optional<proto::FeatureSelectionLogs>& feature_selection_logs()
const {
return feature_selection_logs_;
}
// Mutable access to the feature selection logs. The returned pointer is the
// address of a member of this model; assign through it to set or clear the
// logs.
std::optional<proto::FeatureSelectionLogs>* mutable_feature_selection_logs() {
return &feature_selection_logs_;
}

// Clear the model from any information that is not required for model
// serving. This function is called when the model is trained with
// "pure_serving_model=true", or when using the "--pure_serving" operation in
Expand Down Expand Up @@ -494,6 +503,9 @@ class AbstractModel {
// Prints information about the hyper-parameter optimizer logs.
void AppendHyperparameterOptimizerLogs(std::string* description) const;

// Prints information about the feature selection logs.
void AppendFeatureSelectionLogs(std::string* description) const;

// Checks if the ModelIOOptions are sufficient to load the model.
//
// At this time, this function checks if a prefix is given.
Expand Down Expand Up @@ -544,6 +556,8 @@ class AbstractModel {
std::optional<proto::HyperparametersOptimizerLogs>
hyperparameter_optimizer_logs_;

std::optional<proto::FeatureSelectionLogs> feature_selection_logs_;

// Indicates if a model is pure for serving i.e. the model was stripped of all
// information not required for serving.
bool is_pure_model_ = false;
Expand Down
15 changes: 15 additions & 0 deletions yggdrasil_decision_forests/model/abstract_model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ message AbstractModel {
// Logs of the automated hyper-parameter tuning of the model.
optional HyperparametersOptimizerLogs hyperparameter_optimizer_logs = 11;

// Logs of the automated feature selection of the model.
optional FeatureSelectionLogs feature_selection_logs = 13;

// Indicates if a model is pure for serving i.e. the model was stripped of all
// information not required for serving.
optional bool is_pure_model = 12 [default = false];
Expand Down Expand Up @@ -155,6 +158,18 @@ message HyperparametersOptimizerLogs {
}
}

// Logs of a feature selection algorithm.
message FeatureSelectionLogs {
// One logged iteration of the feature selection algorithm.
message Iteration {
// Score recorded for this iteration.
// NOTE(review): whether higher or lower is better is not visible here --
// confirm against the feature selector implementation.
optional float score = 1;
// Features associated with this iteration (presumably the names of the
// input features evaluated at this step -- TODO confirm).
repeated string features = 2;
// Metric values recorded for this iteration, keyed by a string
// (presumably the metric name -- TODO confirm).
map<string, float> metrics = 3;
}

// Logged iterations. The model description lists them in this order, using
// the position in this field as the iteration index.
repeated Iteration iterations = 1;
// Presumably the index, in "iterations", of the iteration that was retained
// -- TODO confirm against the code that writes this field.
optional int32 best_iteration_idx = 2;
}

// Proto used to serialize / deserialize the model to / from string. See
// "SerializeModel" and "DeserializeModel".
//
Expand Down
Loading

0 comments on commit 075e567

Please sign in to comment.