From 7b047c491fb885113172e864daba9e00a38457ae Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Wed, 13 Dec 2023 12:54:29 -0500 Subject: [PATCH 1/4] get tests passing --- lib/exgboost/training.ex | 349 +++++++++++++----------------- lib/exgboost/training/callback.ex | 136 ++++++------ lib/exgboost/training/state.ex | 13 +- 3 files changed, 219 insertions(+), 279 deletions(-) diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 624fa39..3f882e1 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,40 +6,49 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do - {opts, booster_params} = - Keyword.split(opts, [ - :obj, - :num_boost_rounds, - :evals, - :verbose_eval, - :callbacks, - :learning_rates, - :early_stopping_rounds - ]) - - opts = - Keyword.validate!(opts, - obj: nil, - num_boost_rounds: 10, - evals: [], - verbose_eval: true, - callbacks: [], - learning_rates: nil, - early_stopping_rounds: nil - ) + valid_opts = [ + callbacks: [], + early_stopping_rounds: nil, + evals: [], + learning_rates: nil, + num_boost_rounds: 10, + obj: nil, + verbose_eval: true + ] + + {opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts)) + + [ + callbacks: callbacks, + early_stopping_rounds: early_stopping_rounds, + evals: evals, + learning_rates: learning_rates, + num_boost_rounds: num_boost_rounds, + obj: objective, + verbose_eval: verbose_eval + ] = opts |> Keyword.validate!(valid_opts) |> Enum.sort() + + unless is_nil(learning_rates) or is_function(learning_rates, 1) or is_list(learning_rates) do + raise ArgumentError, "learning_rates must be a function/1 or a list" + end - learning_rates = Keyword.fetch!(opts, :learning_rates) + if early_stopping_rounds && evals == [] do + raise ArgumentError, "early_stopping_rounds requires at least one evaluation set" + end - if not is_nil(learning_rates) and - not (is_function(learning_rates, 1) or is_list(learning_rates)) do - raise ArgumentError, "learning_rates must be a function/1 or a list" + for callback <- callbacks do + Callback.validate!(callback) end - objective = Keyword.fetch!(opts, :obj) - evals = Keyword.fetch!(opts, :evals) + verbose_eval = + case verbose_eval do + true -> 1 + false -> 0 + value -> value + end evals_dmats = - Enum.map(evals, fn {%Nx.Tensor{} = x, %Nx.Tensor{} = y, name} -> + Enum.map(evals, fn {x, y, name} -> {DMatrix.from_tensor(x, y, format: :dense), name} end) @@ -49,200 +58,136 @@ defmodule EXGBoost.Training do booster_params ) - verbose_eval = - case Keyword.fetch!(opts, :verbose_eval) do - true -> 1 - false -> 0 - value -> value - end + callbacks = + callbacks ++ + default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) - callbacks = Keyword.fetch!(opts, :callbacks) |> Enum.reverse() + callbacks = Enum.map(callbacks, &wrap_callback/1) - callbacks = - unless is_nil(opts[:learning_rates]) do - [ - %Callback{ - event: :before_iteration, - fun: &Callback.lr_scheduler/1, - name: :lr_scheduler, - init_state: %{learning_rates: learning_rates} - } - | callbacks - ] - else - callbacks - end + state = %State{ + booster: bst, + iteration: 0, + max_iteration: num_boost_rounds, + meta_vars: Map.new(callbacks, &{&1.name, &1.init_state}) + } - callbacks = - unless verbose_eval == 0 or evals_dmats == [] do - [ - %Callback{ - event: :after_iteration, - fun: &Callback.monitor_metrics/1, - name: :monitor_metrics, - init_state: %{period: verbose_eval, filter: fn {_, _} -> true end} - } - | callbacks - ] - else - callbacks - end + callbacks_by_event = Enum.group_by(callbacks, & &1.event, & &1.fun) - callbacks = - unless is_nil(opts[:early_stopping_rounds]) do - unless evals_dmats == [] do - [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) - - # Default to the last metric - [%{"name" => metric_name} | _tail] = - EXGBoost.dump_config(bst) - |> Jason.decode!() - |> get_in(["learner", "metrics"]) - |> Enum.reverse() - - [ - %Callback{ - event: :after_iteration, - fun: &Callback.early_stop/1, - name: :early_stop, - init_state: %{ - patience: opts[:early_stopping_rounds], - best: nil, - since_last_improvement: 0, - mode: :min, - target_eval: target_eval, - target_metric: metric_name - } - } - | callbacks - ] - else - raise ArgumentError, "early_stopping_rounds requires at least one evaluation set" - end - else - callbacks - end + state = run_callbacks(state, callbacks_by_event, :before_training) - callbacks = - unless evals_dmats == [] do - [ - %Callback{ - event: :after_iteration, - fun: &Callback.eval_metrics/1, - name: :eval_metrics, - init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end} - } - | callbacks - ] + state = + if state.status == :halt do + state else - callbacks - end + Enum.reduce_while(1..state.max_iteration, state, fn iter, iter_state -> + iter_state = run_callbacks(iter_state, callbacks_by_event, :before_iteration) + + iter_state = + if iter_state.status == :halt do + iter_state + else + :ok = Booster.update(iter_state.booster, dmat, iter, objective) + run_callbacks(%{iter_state | iteration: iter}, callbacks_by_event, :after_iteration) + end - default = %{ - before_iteration: [], - after_iteration: [], - before_training: [], - after_training: [], - init_state: %{} - } + {iter_state.status, iter_state} + end) + end - env = - callbacks - |> Enum.reverse() - |> Enum.reduce(default, fn %Callback{} = callback, acc -> - acc = - case callback.event do - :before_iteration -> - %{acc | before_iteration: [callback.fun | acc[:before_iteration]]} + if state.status == :halt do + state.booster + else + final_state = run_callbacks(state, callbacks_by_event, :after_training) + final_state.booster + end + end - :after_iteration -> - %{acc | after_iteration: [callback.fun | acc[:after_iteration]]} + defp wrap_callback(%Callback{fun: fun} = callback) do + %{callback | fun: fn state -> state |> fun.() |> State.validate!() end} + end - :before_training -> - %{acc | before_training: [callback.fun | acc[:before_training]]} + defp run_callbacks(state, callbacks_by_event, event) do + Enum.reduce_while(callbacks_by_event[event] || [], state, fn callback, state -> + state = callback.(state) + {state.status, state} + end) + end - :after_training -> - %{acc | after_training: [callback.fun | acc[:after_training]]} + defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do + default_callbacks = [] - _ -> - raise ArgumentError, "Invalid callback: #{inspect(callback)}" - end + default_callbacks = + if learning_rates do + lr_scheduler = %Callback{ + event: :before_iteration, + fun: &Callback.lr_scheduler/1, + name: :lr_scheduler, + init_state: %{learning_rates: learning_rates} + } - case callback.name do - nil -> acc - name -> put_in(acc[:init_state][name], callback.init_state) - end - end) + [lr_scheduler | default_callbacks] + else + default_callbacks + end - start_iteration = 0 - num_boost_rounds = Keyword.fetch!(opts, :num_boost_rounds) + default_callbacks = + if verbose_eval != 0 and evals_dmats != [] do + monitor_metrics = %Callback{ + event: :after_iteration, + fun: &Callback.monitor_metrics/1, + name: :monitor_metrics, + init_state: %{period: verbose_eval, filter: fn {_, _} -> true end} + } - init_state = %State{ - booster: bst, - iteration: 0, - max_iteration: num_boost_rounds, - meta_vars: env[:init_state] - } + [monitor_metrics | default_callbacks] + else + default_callbacks + end - {status, state} = - case run_callbacks(env[:before_training], init_state) do - {:halt, state} -> - {:halted, state} - - {:cont, state} -> - Enum.reduce_while( - start_iteration..(num_boost_rounds - 1), - {:cont, state}, - fn iter, {_, iter_state} -> - case run_callbacks(env[:before_iteration], iter_state) do - {:halt, state} -> - {:halt, {:halted, state}} - - {:cont, state} -> - Booster.update(state.booster, dmat, iter, objective) - - case run_callbacks(env[:after_iteration], %{state | booster: bst}) do - {:halt, state} -> - {:halt, {:halted, state}} - - {:cont, state} -> - {:cont, {:cont, %{state | iteration: state.iteration + 1}}} - end - end - end - ) + default_callbacks = + if early_stopping_rounds && evals_dmats != [] do + [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) + + # Default to the last metric + [%{"name" => metric_name} | _tail] = + EXGBoost.dump_config(bst) + |> Jason.decode!() + |> get_in(["learner", "metrics"]) + |> Enum.reverse() + + early_stop = %Callback{ + event: :after_iteration, + fun: &Callback.early_stop/1, + name: :early_stop, + init_state: %{ + patience: early_stopping_rounds, + best: nil, + since_last_improvement: 0, + mode: :min, + target_eval: target_eval, + target_metric: metric_name + } + } - _ -> - raise "invalid return value from before_training callback" + [early_stop | default_callbacks] + else + default_callbacks end - case status do - :halted -> - state.booster - - :cont -> - {_status, final_state} = run_callbacks(env[:after_training], state) - final_state.booster - end - end + default_callbacks = + if evals_dmats != [] do + eval_metrics = %Callback{ + event: :after_iteration, + fun: &Callback.eval_metrics/1, + name: :eval_metrics, + init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end} + } - defp run_callbacks(callbacks, state) do - callbacks - |> Enum.reduce_while({:cont, state}, fn callback, {_, state} -> - case callback.(state) do - {:cont, %State{} = state} -> - {:cont, {:cont, state}} - - {:halt, %State{} = state} -> - {:halt, {:halt, state}} - - invalid -> - raise ArgumentError, - "invalid value #{inspect(invalid)} returned from callback" <> - " Callback handler must return" <> - " a tuple of {status, state} where status is one of :cont," <> - " or :halt and state is an updated State struct" + [eval_metrics | default_callbacks] + else + default_callbacks end - end) + + default_callbacks end end diff --git a/lib/exgboost/training/callback.ex b/lib/exgboost/training/callback.ex index 15b69ee..dadd0c0 100644 --- a/lib/exgboost/training/callback.ex +++ b/lib/exgboost/training/callback.ex @@ -46,30 +46,35 @@ defmodule EXGBoost.Training.Callback do @enforce_keys [:event, :fun] defstruct [:event, :fun, :name, :init_state] - @doc """ - Factory for a new callback without an initial state. See `EXGBoost.Callback.new/4` for more details. - """ - @spec new( - event :: :before_training | :after_training | :before_iteration | :after_iteration, - fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()}) - ) :: Callback.t() - def new(event, fun) do - new(event, fun, nil, %{}) - end + @type event :: :before_training | :after_training | :before_iteration | :after_iteration + @type fun :: (State.t() -> State.t()) + + @valid_events [:before_training, :after_training, :before_iteration, :after_iteration] @doc """ Factory for a new callback with an initial state. """ - @spec new( - event :: :before_training | :after_training | :before_iteration | :after_iteration, - fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()}), - name :: atom(), - init_state :: map() - ) :: Callback.t() - def new(event, fun, name, %{} = init_state) - when event in [:before_training, :after_training, :before_iteration, :after_iteration] and - is_atom(name) do + @spec new(event :: event(), fun :: fun(), name :: atom(), init_state :: any()) :: Callback.t() + def new(event, fun, name, init_state \\ %{}) + when event in @valid_events and is_function(fun, 1) and is_atom(name) and not is_nil(name) do %__MODULE__{event: event, fun: fun, name: name, init_state: init_state} + |> validate!() + end + + def validate!(%__MODULE__{} = callback) do + unless is_atom(callback.name) and not is_nil(callback.name) do + raise "A callback must have a non-`nil` atom for a name. Found: #{callback.name}." + end + + unless callback.event in @valid_events do + raise "Callback #{callback.name} must have an event in #{@valid_events}. Found: #{callback.event}." + end + + unless is_function(callback.fun, 1) do + raise "Callback #{callback.name} must have a 1-arity function. Found: #{callback.event}." + end + + callback end @doc """ @@ -83,12 +88,13 @@ defmodule EXGBoost.Training.Callback do %State{ booster: bst, meta_vars: %{lr_scheduler: %{learning_rates: learning_rates}}, - iteration: i + iteration: i, + status: :cont } = state ) do lr = if is_list(learning_rates), do: Enum.at(learning_rates, i), else: learning_rates.(i) boostr = EXGBoost.Booster.set_params(bst, learning_rate: lr) - {:cont, %{state | booster: boostr}} + %{state | booster: boostr} end # TODO: Ideally this would be generalized like it is in Axon to allow generic monitoring of metrics, @@ -110,20 +116,20 @@ defmodule EXGBoost.Training.Callback do def early_stop( %State{ booster: bst, - meta_vars: - %{ - early_stop: %{ - best: best, - patience: patience, - target_metric: target_metric, - target_eval: target_eval, - mode: mode, - since_last_improvement: since_last_improvement - } - } = meta_vars, - metrics: metrics + meta_vars: %{early_stop: early_stop} = meta_vars, + metrics: metrics, + status: :cont } = state ) do + %{ + best: best_score, + patience: patience, + target_metric: target_metric, + target_eval: target_eval, + mode: mode, + since_last_improvement: since_last_improvement + } = early_stop + unless Map.has_key?(metrics, target_eval) do raise ArgumentError, "target eval_set #{inspect(target_eval)} not found in metrics #{inspect(metrics)}" @@ -134,54 +140,34 @@ defmodule EXGBoost.Training.Callback do "target metric #{inspect(target_metric)} not found in metrics #{inspect(metrics)}" end - prev_criteria_value = best - - cur_criteria_value = metrics[target_eval][target_metric] + score = metrics[target_eval][target_metric] improved? = - case mode do - :min -> - prev_criteria_value == nil or - cur_criteria_value < prev_criteria_value - - :max -> - prev_criteria_value == nil or - cur_criteria_value > prev_criteria_value + cond do + best_score == nil -> true + mode == :min -> score < best_score + mode == :max -> score > best_score end - over_patience? = since_last_improvement >= patience - cond do improved? -> - updated_meta_vars = - meta_vars - |> put_in([:early_stop, :best], cur_criteria_value) - |> put_in([:early_stop, :since_last_improvement], 0) + early_stop = %{early_stop | best: score, since_last_improvement: 0} bst = bst - |> struct(best_iteration: state.iteration, best_score: cur_criteria_value) - |> EXGBoost.Booster.set_attr( - best_iteration: state.iteration, - best_score: cur_criteria_value - ) + |> struct(best_iteration: state.iteration, best_score: score) + |> EXGBoost.Booster.set_attr(best_iteration: state.iteration, best_score: score) - {:cont, %{state | meta_vars: updated_meta_vars, booster: bst}} + %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}} - not improved? and not over_patience? -> - updated_meta_vars = - meta_vars - |> put_in([:early_stop, :since_last_improvement], since_last_improvement + 1) - - {:cont, %{state | meta_vars: updated_meta_vars}} + since_last_improvement < patience -> + early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) + %{state | meta_vars: %{meta_vars | early_stop: early_stop}} true -> - updated_meta_vars = - meta_vars - |> put_in([:early_stop, :since_last_improvement], since_last_improvement + 1) - - bst = struct(bst, best_iteration: state.iteration, best_score: cur_criteria_value) - {:halt, %{state | meta_vars: updated_meta_vars, booster: bst}} + early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) + bst = struct(bst, best_iteration: state.iteration, best_score: score) + %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}, status: :halt} end end @@ -198,7 +184,8 @@ defmodule EXGBoost.Training.Callback do %State{ booster: bst, iteration: iter, - meta_vars: %{eval_metrics: %{evals: evals, filter: filter}} + meta_vars: %{eval_metrics: %{evals: evals, filter: filter}}, + status: :cont } = state ) do metrics = @@ -210,8 +197,7 @@ defmodule EXGBoost.Training.Callback do end) |> Map.filter(filter) - {:cont, %{state | metrics: metrics}} - {:cont, %{state | metrics: metrics}} + %{state | metrics: metrics} end @doc """ @@ -229,14 +215,14 @@ defmodule EXGBoost.Training.Callback do metrics: metrics, meta_vars: %{ monitor_metrics: %{period: period, filter: filter} - } + }, + status: :cont } = state ) do if period != 0 and rem(iteration, period) == 0 do - metrics = Map.filter(metrics, filter) - IO.puts("Iteration #{iteration}: #{inspect(metrics)}") + IO.puts("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}") end - {:cont, state} + state end end diff --git a/lib/exgboost/training/state.ex b/lib/exgboost/training/state.ex index 15445ae..233e9db 100644 --- a/lib/exgboost/training/state.ex +++ b/lib/exgboost/training/state.ex @@ -3,9 +3,18 @@ defmodule EXGBoost.Training.State do @enforce_keys [:booster] defstruct [ :booster, - meta_vars: %{}, iteration: 0, max_iteration: -1, - metrics: %{} + meta_vars: %{}, + metrics: %{}, + status: :cont ] + + def validate!(%__MODULE__{} = state) do + unless state.status in [:cont, :halt] do + raise ArgumentError, "`status` must be `:cont` or `:halt`, found: #{state.status}." + end + + state + end end From d28ec8b097a826f72f82dd94dbb3ce9a8cea15d1 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Wed, 13 Dec 2023 13:58:12 -0500 Subject: [PATCH 2/4] branch on status in function clauses --- lib/exgboost/training.ex | 71 ++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 3f882e1..44b3fd2 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -58,11 +58,13 @@ defmodule EXGBoost.Training do booster_params ) - callbacks = - callbacks ++ - default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) + defaults = + default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) - callbacks = Enum.map(callbacks, &wrap_callback/1) + callbacks = + Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback -> + %{callback | fun: fn state -> state |> fun.() |> State.validate!() end} + end) state = %State{ booster: bst, @@ -71,48 +73,47 @@ defmodule EXGBoost.Training do meta_vars: Map.new(callbacks, &{&1.name, &1.init_state}) } - callbacks_by_event = Enum.group_by(callbacks, & &1.event, & &1.fun) - - state = run_callbacks(state, callbacks_by_event, :before_training) + callbacks = Enum.group_by(callbacks, & &1.event, & &1.fun) state = - if state.status == :halt do - state - else - Enum.reduce_while(1..state.max_iteration, state, fn iter, iter_state -> - iter_state = run_callbacks(iter_state, callbacks_by_event, :before_iteration) - - iter_state = - if iter_state.status == :halt do - iter_state - else - :ok = Booster.update(iter_state.booster, dmat, iter, objective) - run_callbacks(%{iter_state | iteration: iter}, callbacks_by_event, :after_iteration) - end - - {iter_state.status, iter_state} - end) - end + state + |> run_callbacks(callbacks, :before_training) + |> run_training(callbacks, dmat, objective) + |> run_callbacks(callbacks, :after_training) - if state.status == :halt do - state.booster - else - final_state = run_callbacks(state, callbacks_by_event, :after_training) - final_state.booster - end + state.booster end - defp wrap_callback(%Callback{fun: fun} = callback) do - %{callback | fun: fn state -> state |> fun.() |> State.validate!() end} - end + defp run_callbacks(%{status: :halt} = state, _callbacks, _event), do: state - defp run_callbacks(state, callbacks_by_event, event) do - Enum.reduce_while(callbacks_by_event[event] || [], state, fn callback, state -> + defp run_callbacks(%{status: :cont} = state, callbacks, event) do + Enum.reduce_while(callbacks[event] || [], state, fn callback, state -> state = callback.(state) {state.status, state} end) end + defp run_training(%{status: :halt} = state, _callbacks, _dmat, _objective), do: state + + defp run_training(%{status: :cont} = state, callbacks, dmat, objective) do + Enum.reduce_while(1..state.max_iteration, state, fn iter, state -> + state = + state + |> run_callbacks(callbacks, :before_iteration) + |> run_iteration(dmat, iter, objective) + |> run_callbacks(callbacks, :after_iteration) + + {state.status, state} + end) + end + + defp run_iteration(%{status: :halt} = state, _dmat, _iter, _objective), do: state + + defp run_iteration(%{status: :cont} = state, dmat, iter, objective) do + :ok = Booster.update(state.booster, dmat, iter, objective) + %{state | iteration: iter} + end + defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do default_callbacks = [] From f1977bf9190816eb237505471befb9291b5c675b Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Thu, 14 Dec 2023 14:42:20 -0500 Subject: [PATCH 3/4] demonstrate potential duplicate name behavior --- lib/exgboost/training.ex | 13 +++++++++---- test/exgboost_test.exs | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 44b3fd2..484e53f 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -36,10 +36,6 @@ defmodule EXGBoost.Training do raise ArgumentError, "early_stopping_rounds requires at least one evaluation set" end - for callback <- callbacks do - Callback.validate!(callback) - end - verbose_eval = case verbose_eval do true -> 1 @@ -66,6 +62,15 @@ defmodule EXGBoost.Training do %{callback | fun: fn state -> state |> fun.() |> State.validate!() end} end) + # Validate callbacks and ensure all names are unique. + Enum.each(callbacks, &Callback.validate!/1) + name_counts = Enum.frequencies_by(callbacks, & &1.name) + + if Enum.any?(name_counts, &(elem(&1, 1) > 1)) do + str = name_counts |> Enum.sort() |> Enum.map_join("\n\n", &" * #{inspect(&1)}") + raise ArgumentError, "Found duplicate callback names.\n\nName counts:\n\n#{str}\n" + end + state = %State{ booster: bst, iteration: 0, diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs index 9b15aca..499fe0b 100644 --- a/test/exgboost_test.exs +++ b/test/exgboost_test.exs @@ -310,4 +310,36 @@ defmodule EXGBoostTest do assert EXGBoost.ArrayInterface.get_tensor(array_interface) == tensor end + + describe "errors" do + setup %{key: key0} do + {nrows, ncols} = {10, 10} + {x, key1} = Nx.Random.normal(key0, 0, 1, shape: {nrows, ncols}) + {y, _key2} = Nx.Random.normal(key1, 0, 1, shape: {nrows}) + %{x: x, y: y} + end + + test "duplicate callback names result in an error", %{x: x, y: y} do + # This callback's name is the same as one of the default callbacks. + custom_callback = EXGBoost.Training.Callback.new(:before_training, & &1, :monitor_metrics) + + assert_raise ArgumentError, + """ + Found duplicate callback names. + + Name counts: + + * {:eval_metrics, 1} + + * {:monitor_metrics, 2} + """, + fn -> + EXGBoost.train(x, y, + callbacks: [custom_callback], + eval_metric: [:rmse, :logloss], + evals: [{x, y, "validation"}] + ) + end + end + end end From 2c42efbc466a55aef3a9d74d975d865f5b39ca13 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Thu, 14 Dec 2023 14:50:22 -0500 Subject: [PATCH 4/4] test callback wrapping is helpful --- lib/exgboost/training/state.ex | 3 ++- test/exgboost_test.exs | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/exgboost/training/state.ex b/lib/exgboost/training/state.ex index 233e9db..06c8e09 100644 --- a/lib/exgboost/training/state.ex +++ b/lib/exgboost/training/state.ex @@ -12,7 +12,8 @@ defmodule EXGBoost.Training.State do def validate!(%__MODULE__{} = state) do unless state.status in [:cont, :halt] do - raise ArgumentError, "`status` must be `:cont` or `:halt`, found: #{state.status}." + raise ArgumentError, + "`status` must be `:cont` or `:halt`, found: `#{inspect(state.status)}`." end state diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs index 499fe0b..60ffadb 100644 --- a/test/exgboost_test.exs +++ b/test/exgboost_test.exs @@ -341,5 +341,14 @@ defmodule EXGBoostTest do ) end end + + test "callback with bad function results in helpful error", %{x: x, y: y} do + bad_fun = fn state -> %{state | status: :bad_status} end + bad_callback = EXGBoost.Training.Callback.new(:before_training, bad_fun, :bad_callback) + + assert_raise ArgumentError, + "`status` must be `:cont` or `:halt`, found: `:bad_status`.", + fn -> EXGBoost.train(x, y, callbacks: [bad_callback]) end + end end end