diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 624fa39..484e53f 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,40 +6,45 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do - {opts, booster_params} = - Keyword.split(opts, [ - :obj, - :num_boost_rounds, - :evals, - :verbose_eval, - :callbacks, - :learning_rates, - :early_stopping_rounds - ]) - - opts = - Keyword.validate!(opts, - obj: nil, - num_boost_rounds: 10, - evals: [], - verbose_eval: true, - callbacks: [], - learning_rates: nil, - early_stopping_rounds: nil - ) - - learning_rates = Keyword.fetch!(opts, :learning_rates) - - if not is_nil(learning_rates) and - not (is_function(learning_rates, 1) or is_list(learning_rates)) do + valid_opts = [ + callbacks: [], + early_stopping_rounds: nil, + evals: [], + learning_rates: nil, + num_boost_rounds: 10, + obj: nil, + verbose_eval: true + ] + + {opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts)) + + [ + callbacks: callbacks, + early_stopping_rounds: early_stopping_rounds, + evals: evals, + learning_rates: learning_rates, + num_boost_rounds: num_boost_rounds, + obj: objective, + verbose_eval: verbose_eval + ] = opts |> Keyword.validate!(valid_opts) |> Enum.sort() + + unless is_nil(learning_rates) or is_function(learning_rates, 1) or is_list(learning_rates) do raise ArgumentError, "learning_rates must be a function/1 or a list" end - objective = Keyword.fetch!(opts, :obj) - evals = Keyword.fetch!(opts, :evals) + if early_stopping_rounds && evals == [] do + raise ArgumentError, "early_stopping_rounds requires at least one evaluation set" + end + + verbose_eval = + case verbose_eval do + true -> 1 + false -> 0 + value -> value + end evals_dmats = - Enum.map(evals, fn {%Nx.Tensor{} = x, %Nx.Tensor{} = y, name} -> + Enum.map(evals, fn {x, y, name} -> {DMatrix.from_tensor(x, y, format: :dense), name} end) @@ -49,200 +54,146 @@ defmodule EXGBoost.Training do booster_params ) - verbose_eval = - case Keyword.fetch!(opts, :verbose_eval) do - true -> 1 - false -> 0 - value -> value - end - - callbacks = Keyword.fetch!(opts, :callbacks) |> Enum.reverse() + defaults = + default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) callbacks = - unless is_nil(opts[:learning_rates]) do - [ - %Callback{ - event: :before_iteration, - fun: &Callback.lr_scheduler/1, - name: :lr_scheduler, - init_state: %{learning_rates: learning_rates} - } - | callbacks - ] - else - callbacks - end - - callbacks = - unless verbose_eval == 0 or evals_dmats == [] do - [ - %Callback{ - event: :after_iteration, - fun: &Callback.monitor_metrics/1, - name: :monitor_metrics, - init_state: %{period: verbose_eval, filter: fn {_, _} -> true end} - } - | callbacks - ] - else - callbacks - end + Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback -> + %{callback | fun: fn state -> state |> fun.() |> State.validate!() end} + end) - callbacks = - unless is_nil(opts[:early_stopping_rounds]) do - unless evals_dmats == [] do - [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) - - # Default to the last metric - [%{"name" => metric_name} | _tail] = - EXGBoost.dump_config(bst) - |> Jason.decode!() - |> get_in(["learner", "metrics"]) - |> Enum.reverse() - - [ - %Callback{ - event: :after_iteration, - fun: &Callback.early_stop/1, - name: :early_stop, - init_state: %{ - patience: opts[:early_stopping_rounds], - best: nil, - since_last_improvement: 0, - mode: :min, - target_eval: target_eval, - target_metric: metric_name - } - } - | callbacks - ] - else - raise ArgumentError, "early_stopping_rounds requires at least one evaluation set" - end - else - callbacks - end + # Validate callbacks and ensure all names are unique. + Enum.each(callbacks, &Callback.validate!/1) + name_counts = Enum.frequencies_by(callbacks, & &1.name) - callbacks = - unless evals_dmats == [] do - [ - %Callback{ - event: :after_iteration, - fun: &Callback.eval_metrics/1, - name: :eval_metrics, - init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end} - } - | callbacks - ] - else - callbacks - end + if Enum.any?(name_counts, &(elem(&1, 1) > 1)) do + str = name_counts |> Enum.sort() |> Enum.map_join("\n\n", &" * #{inspect(&1)}") + raise ArgumentError, "Found duplicate callback names.\n\nName counts:\n\n#{str}\n" + end - default = %{ - before_iteration: [], - after_iteration: [], - before_training: [], - after_training: [], - init_state: %{} + state = %State{ + booster: bst, + iteration: 0, + max_iteration: num_boost_rounds, + meta_vars: Map.new(callbacks, &{&1.name, &1.init_state}) } - env = - callbacks - |> Enum.reverse() - |> Enum.reduce(default, fn %Callback{} = callback, acc -> - acc = - case callback.event do - :before_iteration -> - %{acc | before_iteration: [callback.fun | acc[:before_iteration]]} + callbacks = Enum.group_by(callbacks, & &1.event, & &1.fun) - :after_iteration -> - %{acc | after_iteration: [callback.fun | acc[:after_iteration]]} + state = + state + |> run_callbacks(callbacks, :before_training) + |> run_training(callbacks, dmat, objective) + |> run_callbacks(callbacks, :after_training) - :before_training -> - %{acc | before_training: [callback.fun | acc[:before_training]]} + state.booster + end - :after_training -> - %{acc | after_training: [callback.fun | acc[:after_training]]} + defp run_callbacks(%{status: :halt} = state, _callbacks, _event), do: state - _ -> - raise ArgumentError, "Invalid callback: #{inspect(callback)}" - end + defp run_callbacks(%{status: :cont} = state, callbacks, event) do + Enum.reduce_while(callbacks[event] || [], state, fn callback, state -> + state = callback.(state) + {state.status, state} + end) + end - case callback.name do - nil -> acc - name -> put_in(acc[:init_state][name], callback.init_state) - end - end) + defp run_training(%{status: :halt} = state, _callbacks, _dmat, _objective), do: state - start_iteration = 0 - num_boost_rounds = Keyword.fetch!(opts, :num_boost_rounds) + defp run_training(%{status: :cont} = state, callbacks, dmat, objective) do + Enum.reduce_while(1..state.max_iteration, state, fn iter, state -> + state = + state + |> run_callbacks(callbacks, :before_iteration) + |> run_iteration(dmat, iter, objective) + |> run_callbacks(callbacks, :after_iteration) - init_state = %State{ - booster: bst, - iteration: 0, - max_iteration: num_boost_rounds, - meta_vars: env[:init_state] - } + {state.status, state} + end) + end + + defp run_iteration(%{status: :halt} = state, _dmat, _iter, _objective), do: state + + defp run_iteration(%{status: :cont} = state, dmat, iter, objective) do + :ok = Booster.update(state.booster, dmat, iter, objective) + %{state | iteration: iter} + end - {status, state} = - case run_callbacks(env[:before_training], init_state) do - {:halt, state} -> - {:halted, state} - - {:cont, state} -> - Enum.reduce_while( - start_iteration..(num_boost_rounds - 1), - {:cont, state}, - fn iter, {_, iter_state} -> - case run_callbacks(env[:before_iteration], iter_state) do - {:halt, state} -> - {:halt, {:halted, state}} - - {:cont, state} -> - Booster.update(state.booster, dmat, iter, objective) - - case run_callbacks(env[:after_iteration], %{state | booster: bst}) do - {:halt, state} -> - {:halt, {:halted, state}} - - {:cont, state} -> - {:cont, {:cont, %{state | iteration: state.iteration + 1}}} - end - end - end - ) - - _ -> - raise "invalid return value from before_training callback" + defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do + default_callbacks = [] + + default_callbacks = + if learning_rates do + lr_scheduler = %Callback{ + event: :before_iteration, + fun: &Callback.lr_scheduler/1, + name: :lr_scheduler, + init_state: %{learning_rates: learning_rates} + } + + [lr_scheduler | default_callbacks] + else + default_callbacks end - case status do - :halted -> - state.booster + default_callbacks = + if verbose_eval != 0 and evals_dmats != [] do + monitor_metrics = %Callback{ + event: :after_iteration, + fun: &Callback.monitor_metrics/1, + name: :monitor_metrics, + init_state: %{period: verbose_eval, filter: fn {_, _} -> true end} + } - :cont -> - {_status, final_state} = run_callbacks(env[:after_training], state) - final_state.booster - end - end + [monitor_metrics | default_callbacks] + else + default_callbacks + end + + default_callbacks = + if early_stopping_rounds && evals_dmats != [] do + [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) + + # Default to the last metric + [%{"name" => metric_name} | _tail] = + EXGBoost.dump_config(bst) + |> Jason.decode!() + |> get_in(["learner", "metrics"]) + |> Enum.reverse() + + early_stop = %Callback{ + event: :after_iteration, + fun: &Callback.early_stop/1, + name: :early_stop, + init_state: %{ + patience: early_stopping_rounds, + best: nil, + since_last_improvement: 0, + mode: :min, + target_eval: target_eval, + target_metric: metric_name + } + } - defp run_callbacks(callbacks, state) do - callbacks - |> Enum.reduce_while({:cont, state}, fn callback, {_, state} -> - case callback.(state) do - {:cont, %State{} = state} -> - {:cont, {:cont, state}} - - {:halt, %State{} = state} -> - {:halt, {:halt, state}} - - invalid -> - raise ArgumentError, - "invalid value #{inspect(invalid)} returned from callback" <> - " Callback handler must return" <> - " a tuple of {status, state} where status is one of :cont," <> - " or :halt and state is an updated State struct" + [early_stop | default_callbacks] + else + default_callbacks end - end) + + default_callbacks = + if evals_dmats != [] do + eval_metrics = %Callback{ + event: :after_iteration, + fun: &Callback.eval_metrics/1, + name: :eval_metrics, + init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end} + } + + [eval_metrics | default_callbacks] + else + default_callbacks + end + + default_callbacks end end diff --git a/lib/exgboost/training/callback.ex b/lib/exgboost/training/callback.ex index 15b69ee..dadd0c0 100644 --- a/lib/exgboost/training/callback.ex +++ b/lib/exgboost/training/callback.ex @@ -46,30 +46,35 @@ defmodule EXGBoost.Training.Callback do @enforce_keys [:event, :fun] defstruct [:event, :fun, :name, :init_state] - @doc """ - Factory for a new callback without an initial state. See `EXGBoost.Callback.new/4` for more details. - """ - @spec new( - event :: :before_training | :after_training | :before_iteration | :after_iteration, - fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()}) - ) :: Callback.t() - def new(event, fun) do - new(event, fun, nil, %{}) - end + @type event :: :before_training | :after_training | :before_iteration | :after_iteration + @type fun :: (State.t() -> State.t()) + + @valid_events [:before_training, :after_training, :before_iteration, :after_iteration] @doc """ Factory for a new callback with an initial state. """ - @spec new( - event :: :before_training | :after_training | :before_iteration | :after_iteration, - fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()}), - name :: atom(), - init_state :: map() - ) :: Callback.t() - def new(event, fun, name, %{} = init_state) - when event in [:before_training, :after_training, :before_iteration, :after_iteration] and - is_atom(name) do + @spec new(event :: event(), fun :: fun(), name :: atom(), init_state :: any()) :: Callback.t() + def new(event, fun, name, init_state \\ %{}) + when event in @valid_events and is_function(fun, 1) and is_atom(name) and not is_nil(name) do %__MODULE__{event: event, fun: fun, name: name, init_state: init_state} + |> validate!() + end + + def validate!(%__MODULE__{} = callback) do + unless is_atom(callback.name) and not is_nil(callback.name) do + raise "A callback must have a non-`nil` atom for a name. Found: #{callback.name}." + end + + unless callback.event in @valid_events do + raise "Callback #{callback.name} must have an event in #{@valid_events}. Found: #{callback.event}." + end + + unless is_function(callback.fun, 1) do + raise "Callback #{callback.name} must have a 1-arity function. Found: #{callback.event}." + end + + callback end @doc """ @@ -83,12 +88,13 @@ defmodule EXGBoost.Training.Callback do %State{ booster: bst, meta_vars: %{lr_scheduler: %{learning_rates: learning_rates}}, - iteration: i + iteration: i, + status: :cont } = state ) do lr = if is_list(learning_rates), do: Enum.at(learning_rates, i), else: learning_rates.(i) boostr = EXGBoost.Booster.set_params(bst, learning_rate: lr) - {:cont, %{state | booster: boostr}} + %{state | booster: boostr} end # TODO: Ideally this would be generalized like it is in Axon to allow generic monitoring of metrics, @@ -110,20 +116,20 @@ defmodule EXGBoost.Training.Callback do def early_stop( %State{ booster: bst, - meta_vars: - %{ - early_stop: %{ - best: best, - patience: patience, - target_metric: target_metric, - target_eval: target_eval, - mode: mode, - since_last_improvement: since_last_improvement - } - } = meta_vars, - metrics: metrics + meta_vars: %{early_stop: early_stop} = meta_vars, + metrics: metrics, + status: :cont } = state ) do + %{ + best: best_score, + patience: patience, + target_metric: target_metric, + target_eval: target_eval, + mode: mode, + since_last_improvement: since_last_improvement + } = early_stop + unless Map.has_key?(metrics, target_eval) do raise ArgumentError, "target eval_set #{inspect(target_eval)} not found in metrics #{inspect(metrics)}" @@ -134,54 +140,34 @@ defmodule EXGBoost.Training.Callback do "target metric #{inspect(target_metric)} not found in metrics #{inspect(metrics)}" end - prev_criteria_value = best - - cur_criteria_value = metrics[target_eval][target_metric] + score = metrics[target_eval][target_metric] improved? = - case mode do - :min -> - prev_criteria_value == nil or - cur_criteria_value < prev_criteria_value - - :max -> - prev_criteria_value == nil or - cur_criteria_value > prev_criteria_value + cond do + best_score == nil -> true + mode == :min -> score < best_score + mode == :max -> score > best_score end - over_patience? = since_last_improvement >= patience - cond do improved? -> - updated_meta_vars = - meta_vars - |> put_in([:early_stop, :best], cur_criteria_value) - |> put_in([:early_stop, :since_last_improvement], 0) + early_stop = %{early_stop | best: score, since_last_improvement: 0} bst = bst - |> struct(best_iteration: state.iteration, best_score: cur_criteria_value) - |> EXGBoost.Booster.set_attr( - best_iteration: state.iteration, - best_score: cur_criteria_value - ) + |> struct(best_iteration: state.iteration, best_score: score) + |> EXGBoost.Booster.set_attr(best_iteration: state.iteration, best_score: score) - {:cont, %{state | meta_vars: updated_meta_vars, booster: bst}} + %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}} - not improved? and not over_patience? -> - updated_meta_vars = - meta_vars - |> put_in([:early_stop, :since_last_improvement], since_last_improvement + 1) - - {:cont, %{state | meta_vars: updated_meta_vars}} + since_last_improvement < patience -> + early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) + %{state | meta_vars: %{meta_vars | early_stop: early_stop}} true -> - updated_meta_vars = - meta_vars - |> put_in([:early_stop, :since_last_improvement], since_last_improvement + 1) - - bst = struct(bst, best_iteration: state.iteration, best_score: cur_criteria_value) - {:halt, %{state | meta_vars: updated_meta_vars, booster: bst}} + early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) + bst = struct(bst, best_iteration: state.iteration, best_score: score) + %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}, status: :halt} end end @@ -198,7 +184,8 @@ defmodule EXGBoost.Training.Callback do %State{ booster: bst, iteration: iter, - meta_vars: %{eval_metrics: %{evals: evals, filter: filter}} + meta_vars: %{eval_metrics: %{evals: evals, filter: filter}}, + status: :cont } = state ) do metrics = @@ -210,8 +197,7 @@ defmodule EXGBoost.Training.Callback do end) |> Map.filter(filter) - {:cont, %{state | metrics: metrics}} - {:cont, %{state | metrics: metrics}} + %{state | metrics: metrics} end @doc """ @@ -229,14 +215,14 @@ defmodule EXGBoost.Training.Callback do metrics: metrics, meta_vars: %{ monitor_metrics: %{period: period, filter: filter} - } + }, + status: :cont } = state ) do if period != 0 and rem(iteration, period) == 0 do - metrics = Map.filter(metrics, filter) - IO.puts("Iteration #{iteration}: #{inspect(metrics)}") + IO.puts("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}") end - {:cont, state} + state end end diff --git a/lib/exgboost/training/state.ex b/lib/exgboost/training/state.ex index 15445ae..06c8e09 100644 --- a/lib/exgboost/training/state.ex +++ b/lib/exgboost/training/state.ex @@ -3,9 +3,19 @@ defmodule EXGBoost.Training.State do @enforce_keys [:booster] defstruct [ :booster, - meta_vars: %{}, iteration: 0, max_iteration: -1, - metrics: %{} + meta_vars: %{}, + metrics: %{}, + status: :cont ] + + def validate!(%__MODULE__{} = state) do + unless state.status in [:cont, :halt] do + raise ArgumentError, + "`status` must be `:cont` or `:halt`, found: `#{inspect(state.status)}`." + end + + state + end end diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs index 9b15aca..60ffadb 100644 --- a/test/exgboost_test.exs +++ b/test/exgboost_test.exs @@ -310,4 +310,45 @@ defmodule EXGBoostTest do assert EXGBoost.ArrayInterface.get_tensor(array_interface) == tensor end + + describe "errors" do + setup %{key: key0} do + {nrows, ncols} = {10, 10} + {x, key1} = Nx.Random.normal(key0, 0, 1, shape: {nrows, ncols}) + {y, _key2} = Nx.Random.normal(key1, 0, 1, shape: {nrows}) + %{x: x, y: y} + end + + test "duplicate callback names result in an error", %{x: x, y: y} do + # This callback's name is the same as one of the default callbacks. + custom_callback = EXGBoost.Training.Callback.new(:before_training, & &1, :monitor_metrics) + + assert_raise ArgumentError, + """ + Found duplicate callback names. + + Name counts: + + * {:eval_metrics, 1} + + * {:monitor_metrics, 2} + """, + fn -> + EXGBoost.train(x, y, + callbacks: [custom_callback], + eval_metric: [:rmse, :logloss], + evals: [{x, y, "validation"}] + ) + end + end + + test "callback with bad function results in helpful error", %{x: x, y: y} do + bad_fun = fn state -> %{state | status: :bad_status} end + bad_callback = EXGBoost.Training.Callback.new(:before_training, bad_fun, :bad_callback) + + assert_raise ArgumentError, + "`status` must be `:cont` or `:halt`, found: `:bad_status`.", + fn -> EXGBoost.train(x, y, callbacks: [bad_callback]) end + end + end end