Tabular transformer (part1).
PiperOrigin-RevId: 700998134
achoum authored and copybara-github committed Nov 28, 2024
1 parent a0226e5 commit dd988de
Showing 5 changed files with 160 additions and 103 deletions.
@@ -211,13 +211,7 @@ def run(
# and to reuse it at each iteration.

# List the input features
dataspec = learner._get_vertical_dataset(ds).data_spec() # pylint: disable=protected-access
non_input_feature_columns = set(learner._non_input_feature_columns()) # pylint: disable=protected-access
input_features = [
col.name
for col in dataspec.columns
if col.name not in non_input_feature_columns
]
input_features = learner.extract_input_feature_names(ds)
log.info(
"Run backward feature selection on %d features", len(input_features)
)
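The hunk above replaces a manual walk over the data spec with the `extract_input_feature_names` helper introduced in this commit. A minimal sketch of the simplified call site (assuming `ydf` and `pandas` are installed; the CSV path is illustrative):

```python
import pandas as pd
import ydf

# Any concrete learner works here; RandomForestLearner is just an example.
learner = ydf.RandomForestLearner(label="label")
ds = pd.read_csv("my_dataset.csv")  # illustrative dataset

# One call now replaces the data-spec walk deleted above.
input_features = learner.extract_input_feature_names(ds)
print(f"Found {len(input_features)} input features")
```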
241 changes: 152 additions & 89 deletions yggdrasil_decision_forests/port/python/ydf/learner/generic_learner.py
@@ -14,9 +14,9 @@

"""Definitions for Generic learners."""

import abc
import copy
import datetime
import os
import re
from typing import List, Optional, Sequence, Set, Union

@@ -47,7 +47,7 @@
_FRAMEWORK_NAME = "Python YDF"


class GenericLearner:
class GenericLearner(abc.ABC):
"""A generic YDF learner."""

def __init__(
@@ -82,6 +82,7 @@ def __init__(
self._deployment_config = deployment_config
self._tuner = tuner
self._feature_selector = feature_selector
self._explicit_learner_arguments = explicit_learner_arguments

if self._label is not None and not isinstance(label, str):
raise ValueError("The 'label' should be a string")
@@ -118,32 +119,28 @@ def __init__(
if tuner:
tuner.set_base_learner(learner_name)

if explicit_learner_arguments is not None:
self._hyperparameters = self._clean_up_hyperparameters(
explicit_learner_arguments
)

self.validate_hyperparameters()
self.post_init()

@property
def learner_name(self) -> str:
return self._learner_name
# === Following are the virtual methods that a learner should implement ===

@property
def hyperparameters(self) -> hp_lib.HyperParameters:
"""A (mutable) dictionary of this learner's hyperparameters.
@abc.abstractmethod
def post_init(self):
"""Called after __init__."""
raise NotImplementedError

This object can be used to inspect or modify hyperparameters after creating
the learner. Modifying hyperparameters after constructing the learner is
suitable for some advanced use cases. Since this approach bypasses some
feasibility checks for the given set of hyperparameters, it is generally better
to re-create the learner for each model. The current set of hyperparameters
can be validated manually with `validate_hyperparameters()`.
"""
return self._hyperparameters
@abc.abstractmethod
def train_imp(
self,
ds: dataset.InputDataset,
valid: Optional[dataset.InputDataset],
verbose: Optional[Union[int, bool]],
) -> generic_model.ModelType:
"""Trains a model."""
raise NotImplementedError

def validate_hyperparameters(self):
"""Returns None if the hyperparameters are valid, raises otherwise.
@abc.abstractmethod
def validate_hyperparameters(self) -> None:
"""Raises an exception if the hyperparameters are invalid.
This method is called automatically before training, but users may call it
to fail early. It makes sense to call this method when manually changing the
@@ -165,22 +162,72 @@ def validate_hyperparameters(self):
evaluation = model.evaluate(test_ds)
```
"""
return hp_lib.validate_hyperparameters(
self._hyperparameters,
self._get_training_config(),
self._deployment_config,
)
raise NotImplementedError

def _clean_up_hyperparameters(
self, explicit_parameters: Set[str]
) -> hp_lib.HyperParameters:
"""Returns the hyperparameters purged from the mutually exlusive ones."""
return hp_lib.fix_hyperparameters(
self._hyperparameters,
explicit_parameters,
self._get_training_config(),
self._deployment_config,
)
@abc.abstractmethod
def cross_validation(
self,
ds: dataset.InputDataset,
folds: int = 10,
bootstrapping: Union[bool, int] = False,
parallel_evaluations: int = 1,
) -> metric.Evaluation:
"""Cross-validates the learner and return the evaluation.
Usage example:
```python
import pandas as pd
import ydf
dataset = pd.read_csv("my_dataset.csv")
learner = ydf.RandomForestLearner(label="label")
evaluation = learner.cross_validation(dataset)
# In a notebook, display an interactive evaluation
evaluation
# Print the evaluation
print(evaluation)
# Look at specific metrics
print(evaluation.accuracy)
```
Args:
ds: Dataset for the cross-validation.
folds: Number of cross-validation folds.
bootstrapping: Controls whether bootstrapping is used to evaluate the
confidence intervals and statistical tests (i.e., all the metrics ending
with "[B]"). If set to false, bootstrapping is disabled. If set to true,
bootstrapping is enabled and 2000 bootstrapping samples are used. If set
to an integer, it specifies the number of bootstrapping samples to use.
In this case, if the number is less than 100, an error is raised as
bootstrapping will not yield useful results.
parallel_evaluations: Number of models to train and evaluate in parallel
using multi-threading. Note that each model is potentially already
trained with multi-threading (see the `num_threads` argument of the
learner constructor).
Returns:
The cross-validation evaluation.
"""
raise NotImplementedError

@classmethod
def capabilities(cls) -> abstract_learner_pb2.LearnerCapabilities:
raise NotImplementedError

@abc.abstractmethod
def extract_input_feature_names(self, ds: dataset.InputDataset) -> List[str]:
"""Extracts the input features available in a dataset."""
raise NotImplementedError

# === Following are the non-virtual, general methods for all learners ===

@property
def learner_name(self) -> str:
return self._learner_name

def train(
self,
@@ -274,14 +321,24 @@ def train(
# Training
saved_verbose = log.verbose(verbose) if verbose is not None else None
try:
if isinstance(ds, str):
return self._train_from_path(ds, valid)
else:
return self._train_from_dataset(ds, valid)
return self.train_imp(ds, valid, verbose)
finally:
if saved_verbose is not None:
log.verbose(saved_verbose)

@property
def hyperparameters(self) -> hp_lib.HyperParameters:
"""A (mutable) dictionary of this learner's hyperparameters.
This object can be used to inspect or modify hyperparameters after creating
the learner. Modifying hyperparameters after constructing the learner is
suitable for some advanced use cases. Since this approach bypasses some
feasibility checks for the given set of hyperparameters, it is generally better
to re-create the learner for each model. The current set of hyperparameters
can be validated manually with `validate_hyperparameters()`.
"""
return self._hyperparameters

def __str__(self) -> str:
return f"""\
Learner: {self._learner_name}
@@ -290,6 +347,46 @@ def __str__(self) -> str:
Hyper-parameters: ydf.{self._hyperparameters}
"""


class GenericCCLearner(GenericLearner):
"""A generic YDF learner using YDF C++ for training."""

def post_init(self):
if self._explicit_learner_arguments is not None:
self._hyperparameters = self._clean_up_hyperparameters(
self._explicit_learner_arguments
)
self.validate_hyperparameters()

def train_imp(
self,
ds: dataset.InputDataset,
valid: Optional[dataset.InputDataset],
verbose: Optional[Union[int, bool]],
) -> generic_model.ModelType:
if isinstance(ds, str):
return self._train_from_path(ds, valid)
else:
return self._train_from_dataset(ds, valid)

def validate_hyperparameters(self) -> None:
return hp_lib.validate_hyperparameters(
self._hyperparameters,
self._get_training_config(),
self._deployment_config,
)

def _clean_up_hyperparameters(
self, explicit_parameters: Set[str]
) -> hp_lib.HyperParameters:
"""Returns the hyperparameters purged from the mutually exlusive ones."""
return hp_lib.fix_hyperparameters(
self._hyperparameters,
explicit_parameters,
self._get_training_config(),
self._deployment_config,
)

def _train_from_path(
self, ds: str, valid: Optional[str]
) -> generic_model.ModelType:
@@ -366,7 +463,7 @@ def _get_training_config(self) -> abstract_learner_pb2.TrainingConfig:
weight_definition=self._build_weight_definition(),
ranking_group=self._ranking_group,
uplift_treatment=self._uplift_treatment,
task=self._task._to_proto_type(),
task=self._task._to_proto_type(), # pylint: disable=protected-access
metadata=abstract_model_pb2.Metadata(framework=_FRAMEWORK_NAME),
)

@@ -439,6 +536,8 @@ def _non_input_feature_columns(self) -> List[str]:
def _get_vertical_dataset(
self, ds: dataset.InputDataset
) -> dataset.VerticalDataset:
"""Gets the vertical dataset (i.e., dataset in raw) or a dataset."""

if isinstance(ds, dataset.VerticalDataset):
if self._data_spec is not None:
raise ValueError(
@@ -503,47 +602,6 @@ def cross_validation(
bootstrapping: Union[bool, int] = False,
parallel_evaluations: int = 1,
) -> metric.Evaluation:
"""Cross-validates the learner and return the evaluation.
Usage example:
```python
import pandas as pd
import ydf
dataset = pd.read_csv("my_dataset.csv")
learner = ydf.RandomForestLearner(label="label")
evaluation = learner.cross_validation(dataset)
# In a notebook, display an interactive evaluation
evaluation
# Print the evaluation
print(evaluation)
# Look at specific metrics
print(evaluation.accuracy)
```
Args:
ds: Dataset for the cross-validation.
folds: Number of cross-validation folds.
bootstrapping: Controls whether bootstrapping is used to evaluate the
confidence intervals and statistical tests (i.e., all the metrics ending
with "[B]"). If set to false, bootstrapping is disabled. If set to true,
bootstrapping is enabled and 2000 bootstrapping samples are used. If set
to an integer, it specifies the number of bootstrapping samples to use.
In this case, if the number is less than 100, an error is raised as
bootstrapping will not yield useful results.
parallel_evaluations: Number of models to train and evaluate in parallel
using multi-threading. Note that each model is potentially already
trained with multi-threading (see the `num_threads` argument of the
learner constructor).
Returns:
The cross-validation evaluation.
"""

fold_generator = fold_generator_pb2.FoldGenerator(
cross_validation=fold_generator_pb2.FoldGenerator.CrossValidation(
num_folds=folds,
@@ -562,7 +620,7 @@ def cross_validation(
)
evaluation_options = metric_pb2.EvaluationOptions(
bootstrapping_samples=bootstrapping_samples,
task=self._task._to_proto_type(),
task=self._task._to_proto_type(), # pylint: disable=protected-access
)

deployment_evaluation = abstract_learner_pb2.DeploymentConfig(
@@ -725,9 +783,14 @@ def _build_deployment_config(

return config

@classmethod
def capabilities(cls) -> abstract_learner_pb2.LearnerCapabilities:
raise NotImplementedError("Not implemented")
def extract_input_feature_names(self, ds: dataset.InputDataset) -> List[str]:
spec = self._get_vertical_dataset(ds).data_spec()
non_input_feature_columns = set(self._non_input_feature_columns())
return [
col.name
for col in spec.columns
if col.name not in non_input_feature_columns
]


def _feature_name_to_regex(name: str) -> str:
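With `GenericLearner` now abstract, the training backend becomes pluggable: `GenericCCLearner` is the C++-backed implementation, and its `train_imp` dispatches on type, training from disk when `ds` is a path string and from an in-memory `VerticalDataset` otherwise. A minimal sketch of the subclassing contract follows; `ToyLearner` and its trivial method bodies are hypothetical and not part of this commit:

```python
from typing import List

from ydf.learner import generic_learner  # module path per this repository


class ToyLearner(generic_learner.GenericLearner):
  """Hypothetical learner showing the abstract API; not a real backend."""

  def post_init(self):
    # Called at the end of GenericLearner.__init__.
    self.validate_hyperparameters()

  def train_imp(self, ds, valid, verbose):
    raise NotImplementedError("A real backend would train and return a model.")

  def validate_hyperparameters(self) -> None:
    pass  # A real backend would raise on an inconsistent configuration.

  def cross_validation(
      self, ds, folds=10, bootstrapping=False, parallel_evaluations=1
  ):
    raise NotImplementedError

  def extract_input_feature_names(self, ds) -> List[str]:
    return []
```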
@@ -49,7 +49,7 @@
from ydf.utils import func_helpers


class RandomForestLearner(generic_learner.GenericLearner):
class RandomForestLearner(generic_learner.GenericCCLearner):
r"""Random Forest learning algorithm.
A [Random Forest](https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf)
@@ -674,7 +674,7 @@ def hyperparameter_templates(
}


class IsolationForestLearner(generic_learner.GenericLearner):
class IsolationForestLearner(generic_learner.GenericCCLearner):
r"""Isolation Forest learning algorithm.
An [Isolation Forest](https://ieeexplore.ieee.org/abstract/document/4781136)
@@ -999,7 +999,7 @@ def hyperparameter_templates(
return {}


class GradientBoostedTreesLearner(generic_learner.GenericLearner):
class GradientBoostedTreesLearner(generic_learner.GenericCCLearner):
r"""Gradient Boosted Trees learning algorithm.
A [Gradient Boosted Trees](https://statweb.stanford.edu/~jhf/ftp/trebst.pdf)
@@ -1767,7 +1767,7 @@ def hyperparameter_templates(
}


class DistributedGradientBoostedTreesLearner(generic_learner.GenericLearner):
class DistributedGradientBoostedTreesLearner(generic_learner.GenericCCLearner):
r"""Distributed Gradient Boosted Trees learning algorithm.
Exact distributed version of the Gradient Boosted Tree learning algorithm. See
@@ -2126,7 +2126,7 @@ def hyperparameter_templates(
return {}


class CartLearner(generic_learner.GenericLearner):
class CartLearner(generic_learner.GenericCCLearner):
r"""Cart learning algorithm.
A CART (Classification and Regression Trees) is a decision tree. The non-leaf
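The remaining hunks only re-parent the concrete learners from `GenericLearner` to the new `GenericCCLearner`; their public APIs are unchanged. A quick sanity check of the new hierarchy (assuming `ydf` is installed and the internal `generic_learner` module is importable):

```python
import ydf
from ydf.learner import generic_learner  # internal path; an assumption

for learner_class in (
    ydf.RandomForestLearner,
    ydf.GradientBoostedTreesLearner,
    ydf.CartLearner,
    ydf.IsolationForestLearner,
):
  assert issubclass(learner_class, generic_learner.GenericCCLearner)
  assert issubclass(learner_class, generic_learner.GenericLearner)
```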
(2 more changed files not shown.)
