Skip to content

Commit

Permalink
Add MAE loss test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574830522
  • Loading branch information
achoum authored and copybara-github committed Oct 19, 2023
1 parent 569a13c commit b1baa9f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Features

- Add support for monotonic constraints.
- Add support for Mean average error (MAE) loss.
- Add support for Poisson loss.

## 1.6.0 2023-09-27

Expand Down
43 changes: 43 additions & 0 deletions tensorflow_decision_forests/keras/keras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,13 @@ def test_gbt_loss(self):
)
mse_model.fit(x=x_train, y=y_train)

mae_model = keras.GradientBoostedTreesModel(
validation_ratio=0.0,
loss="MEAN_AVERAGE_ERROR",
task=keras.Task.REGRESSION,
)
mae_model.fit(x=x_train, y=y_train)

multinom_model = keras.GradientBoostedTreesModel(
validation_ratio=0.0, loss="MULTINOMIAL_LOG_LIKELIHOOD"
)
Expand Down Expand Up @@ -2993,6 +3000,42 @@ def test_monotonic_non_compatible_options(self):
):
model.fit(tf_dataset)

def test_abalone_mae_loss(self):
  """Trains a GBT regressor with the MAE loss on Abalone and sanity-checks it.

  Fits on the train split, checks the MSE metric on the test split stays
  below a loose bound, then runs inference to make sure predict() works.
  """
  raw_dataset = abalone_dataset()
  train_ds, test_ds = dataset_to_tf_dataset(raw_dataset)

  gbt = keras.GradientBoostedTreesModel(
      loss="MEAN_AVERAGE_ERROR", task=keras.Task.REGRESSION
  )
  # Track both squared and absolute error during evaluation.
  gbt.compile(metrics=["mse", "mae"])

  gbt.fit(train_ds)
  eval_metrics = gbt.evaluate(test_ds)

  logging.info("Evaluation: %s", eval_metrics)
  # eval_metrics[1] is the first compiled metric (mse); loose quality bound.
  self.assertLessEqual(eval_metrics[1], 6.0)

  preds = gbt.predict(test_ds)
  logging.info("Predictions: %s", preds)

def test_abalone_poisson_loss(self):
  """Trains a GBT regressor with the Poisson loss on Abalone and checks it.

  Fits on the train split, checks the MSE metric on the test split stays
  below a loose bound, then runs inference to make sure predict() works.
  Fix: the method name previously misspelled "poisson" as "poison".
  """
  dataset = abalone_dataset()
  tf_train, tf_test = dataset_to_tf_dataset(dataset)

  model = keras.GradientBoostedTreesModel(
      loss="POISSON", task=keras.Task.REGRESSION
  )
  # Track both squared and absolute error during evaluation.
  model.compile(metrics=["mse", "mae"])

  model.fit(tf_train)
  evaluation = model.evaluate(tf_test)

  logging.info("Evaluation: %s", evaluation)
  # evaluation[1] is the first compiled metric (mse); loose quality bound.
  self.assertLessEqual(evaluation[1], 6.0)

  predictions = model.predict(tf_test)
  logging.info("Predictions: %s", predictions)


if __name__ == "__main__":
tf.test.main()

0 comments on commit b1baa9f

Please sign in to comment.