Skip to content

Commit

Permalink
Merge with main - fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
acalejos committed Jan 20, 2024
1 parent 975e285 commit f3c5aae
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 16 deletions.
2 changes: 1 addition & 1 deletion lib/exgboost/plotting.ex
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ defmodule EXGBoost.Plotting do

def deep_merge_kw(a, b) do
Keyword.merge(a, b, fn
key, val_a, val_b when is_list(val_a) and is_list(val_b) ->
_key, val_a, val_b when is_list(val_a) and is_list(val_b) ->
deep_merge_kw(val_a, val_b)

key, val_a, val_b ->
Expand Down
29 changes: 24 additions & 5 deletions lib/exgboost/training.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ defmodule EXGBoost.Training do
learning_rates: nil,
num_boost_rounds: 10,
obj: nil,
verbose_eval: true
verbose_eval: true,
disable_default_eval_metric: false
]

{opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts))

[
callbacks: callbacks,
disable_default_eval_metric: disable_default_eval_metric,
early_stopping_rounds: early_stopping_rounds,
evals: evals,
learning_rates: learning_rates,
Expand Down Expand Up @@ -55,7 +57,14 @@ defmodule EXGBoost.Training do
)

defaults =
default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds)
default_callbacks(
bst,
learning_rates,
verbose_eval,
evals_dmats,
early_stopping_rounds,
disable_default_eval_metric
)

callbacks =
Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback ->
Expand Down Expand Up @@ -119,7 +128,14 @@ defmodule EXGBoost.Training do
%{state | iteration: iter}
end

defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do
defp default_callbacks(
bst,
learning_rates,
verbose_eval,
evals_dmats,
early_stopping_rounds,
disable_default_eval_metric
) do
default_callbacks = []

default_callbacks =
Expand Down Expand Up @@ -154,12 +170,15 @@ defmodule EXGBoost.Training do
if early_stopping_rounds && evals_dmats != [] do
[{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats)

# This is still somewhat hacky and relies on a modification made to
# XGBoost in the Makefile to dump the config to JSON.
#
%{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} =
EXGBoost.dump_config(bst) |> Jason.decode!()
EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg()

metric_name =
cond do
Enum.empty?(metrics) && opts[:disable_default_eval_metric] ->
Enum.empty?(metrics) && disable_default_eval_metric ->
raise ArgumentError,
"`:early_stopping_rounds` requires at least one evaluation set. This means you have likely set `disable_default_eval_metric: true` and have not set any explicit evalutation metrics. Please supply at least one metric in the `:eval_metric` option or set `disable_default_eval_metric: false` (default option)"

Expand Down
24 changes: 14 additions & 10 deletions test/exgboost_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,20 @@ defmodule EXGBoostTest do
refute is_nil(booster.best_iteration)
refute is_nil(booster.best_score)

{booster, _} =
ExUnit.CaptureIO.with_io(fn ->
EXGBoost.train(x, y,
disable_default_eval_metric: true,
num_boost_rounds: 10,
early_stopping_rounds: 1,
evals: [{x, y, "validation"}],
tree_method: :hist
)
end)
# If no eval metric is provided, the default metric is used. If the default
# metric is disabled, an error is raised.
assert_raise ArgumentError,
fn ->
ExUnit.CaptureIO.with_io(fn ->
EXGBoost.train(x, y,
disable_default_eval_metric: true,
num_boost_rounds: 10,
early_stopping_rounds: 1,
evals: [{x, y, "validation"}],
tree_method: :hist
)
end)
end

refute is_nil(booster.best_iteration)
refute is_nil(booster.best_score)
Expand Down

0 comments on commit f3c5aae

Please sign in to comment.