From f3c5aae972c1d120b0834ff9d395e73313c4b183 Mon Sep 17 00:00:00 2001
From: acalejos
Date: Fri, 19 Jan 2024 23:14:16 -0500
Subject: [PATCH] Merge with main - fix errors

---
 lib/exgboost/plotting.ex |  2 +-
 lib/exgboost/training.ex | 29 ++++++++++++++++++++++++-----
 test/exgboost_test.exs   | 24 ++++++++++++++----------
 3 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex
index e89e032..c25b1f3 100644
--- a/lib/exgboost/plotting.ex
+++ b/lib/exgboost/plotting.ex
@@ -387,7 +387,7 @@ defmodule EXGBoost.Plotting do
 
   def deep_merge_kw(a, b) do
     Keyword.merge(a, b, fn
-      key, val_a, val_b when is_list(val_a) and is_list(val_b) ->
+      _key, val_a, val_b when is_list(val_a) and is_list(val_b) ->
        deep_merge_kw(val_a, val_b)
 
      key, val_a, val_b ->
diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex
index 7bf942f..8c27b44 100644
--- a/lib/exgboost/training.ex
+++ b/lib/exgboost/training.ex
@@ -13,13 +13,15 @@ defmodule EXGBoost.Training do
       learning_rates: nil,
       num_boost_rounds: 10,
       obj: nil,
-      verbose_eval: true
+      verbose_eval: true,
+      disable_default_eval_metric: false
     ]
 
     {opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts))
 
     [
       callbacks: callbacks,
+      disable_default_eval_metric: disable_default_eval_metric,
       early_stopping_rounds: early_stopping_rounds,
       evals: evals,
       learning_rates: learning_rates,
@@ -55,7 +57,14 @@ defmodule EXGBoost.Training do
       )
 
     defaults =
-      default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds)
+      default_callbacks(
+        bst,
+        learning_rates,
+        verbose_eval,
+        evals_dmats,
+        early_stopping_rounds,
+        disable_default_eval_metric
+      )
 
     callbacks =
       Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback ->
@@ -119,7 +128,14 @@ defmodule EXGBoost.Training do
     %{state | iteration: iter}
   end
 
-  defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do
+  defp default_callbacks(
+         bst,
+         learning_rates,
+         verbose_eval,
+         evals_dmats,
+         early_stopping_rounds,
+         disable_default_eval_metric
+       ) do
     default_callbacks = []
 
     default_callbacks =
@@ -154,12 +170,15 @@ defmodule EXGBoost.Training do
     if early_stopping_rounds && evals_dmats != [] do
       [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats)
 
+      # This is still somewhat hacky and relies on a modification made to
+      # XGBoost in the Makefile to dump the config to JSON.
+      #
       %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} =
-        EXGBoost.dump_config(bst) |> Jason.decode!()
+        EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg()
 
       metric_name =
         cond do
-          Enum.empty?(metrics) && opts[:disable_default_eval_metric] ->
+          Enum.empty?(metrics) && disable_default_eval_metric ->
             raise ArgumentError,
                   "`:early_stopping_rounds` requires at least one evaluation metric. This means you have likely set `disable_default_eval_metric: true` and have not set any explicit evaluation metrics.
                   Please supply at least one metric in the `:eval_metric` option or set `disable_default_eval_metric: false` (default option)"
diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs
index e239ff6..cf96554 100644
--- a/test/exgboost_test.exs
+++ b/test/exgboost_test.exs
@@ -152,16 +152,20 @@ defmodule EXGBoostTest do
     refute is_nil(booster.best_iteration)
     refute is_nil(booster.best_score)
 
-    {booster, _} =
-      ExUnit.CaptureIO.with_io(fn ->
-        EXGBoost.train(x, y,
-          disable_default_eval_metric: true,
-          num_boost_rounds: 10,
-          early_stopping_rounds: 1,
-          evals: [{x, y, "validation"}],
-          tree_method: :hist
-        )
-      end)
+    # If no eval metric is provided, the default metric is used. If the default
+    # metric is disabled, an error is raised.
+    assert_raise ArgumentError,
+                 fn ->
+                   ExUnit.CaptureIO.with_io(fn ->
+                     EXGBoost.train(x, y,
+                       disable_default_eval_metric: true,
+                       num_boost_rounds: 10,
+                       early_stopping_rounds: 1,
+                       evals: [{x, y, "validation"}],
+                       tree_method: :hist
+                     )
+                   end)
+                 end
+
     refute is_nil(booster.best_iteration)
     refute is_nil(booster.best_score)
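A minimal usage sketch of the behavior exercised above: with `disable_default_eval_metric: true`, supplying an explicit metric keeps `:early_stopping_rounds` working, while omitting one raises the `ArgumentError` added in the training hunk. The data setup and the choice of `eval_metric: :rmse` below are assumptions for illustration, not part of the patch.

    # Sketch only: x and y are synthetic Nx tensors; :rmse is an assumed
    # metric passed through to the booster params.
    x = Nx.iota({50, 4}, type: :f32)
    y = Nx.sum(x, axes: [1])

    booster =
      EXGBoost.train(x, y,
        disable_default_eval_metric: true,
        eval_metric: :rmse,
        num_boost_rounds: 10,
        early_stopping_rounds: 1,
        evals: [{x, y, "validation"}],
        tree_method: :hist
      )

    # Dropping :eval_metric while keeping disable_default_eval_metric: true
    # would instead raise the ArgumentError shown in lib/exgboost/training.ex.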