Skip to content

Commit

Permalink
Merge pull request #35 from billylanchantin/bl-refactor-training-logic
Browse files Browse the repository at this point in the history
Refactor training logic
  • Loading branch information
acalejos authored Jan 4, 2024
2 parents 3783796 + 2c42efb commit 9649af0
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 281 deletions.
359 changes: 155 additions & 204 deletions lib/exgboost/training.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,45 @@ defmodule EXGBoost.Training do

@spec train(DMatrix.t(), Keyword.t()) :: Booster.t()
def train(%DMatrix{} = dmat, opts \\ []) do
{opts, booster_params} =
Keyword.split(opts, [
:obj,
:num_boost_rounds,
:evals,
:verbose_eval,
:callbacks,
:learning_rates,
:early_stopping_rounds
])

opts =
Keyword.validate!(opts,
obj: nil,
num_boost_rounds: 10,
evals: [],
verbose_eval: true,
callbacks: [],
learning_rates: nil,
early_stopping_rounds: nil
)

learning_rates = Keyword.fetch!(opts, :learning_rates)

if not is_nil(learning_rates) and
not (is_function(learning_rates, 1) or is_list(learning_rates)) do
valid_opts = [
callbacks: [],
early_stopping_rounds: nil,
evals: [],
learning_rates: nil,
num_boost_rounds: 10,
obj: nil,
verbose_eval: true
]

{opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts))

[
callbacks: callbacks,
early_stopping_rounds: early_stopping_rounds,
evals: evals,
learning_rates: learning_rates,
num_boost_rounds: num_boost_rounds,
obj: objective,
verbose_eval: verbose_eval
] = opts |> Keyword.validate!(valid_opts) |> Enum.sort()

unless is_nil(learning_rates) or is_function(learning_rates, 1) or is_list(learning_rates) do
raise ArgumentError, "learning_rates must be a function/1 or a list"
end

objective = Keyword.fetch!(opts, :obj)
evals = Keyword.fetch!(opts, :evals)
if early_stopping_rounds && evals == [] do
raise ArgumentError, "early_stopping_rounds requires at least one evaluation set"
end

verbose_eval =
case verbose_eval do
true -> 1
false -> 0
value -> value
end

evals_dmats =
Enum.map(evals, fn {%Nx.Tensor{} = x, %Nx.Tensor{} = y, name} ->
Enum.map(evals, fn {x, y, name} ->
{DMatrix.from_tensor(x, y, format: :dense), name}
end)

Expand All @@ -49,200 +54,146 @@ defmodule EXGBoost.Training do
booster_params
)

verbose_eval =
case Keyword.fetch!(opts, :verbose_eval) do
true -> 1
false -> 0
value -> value
end

callbacks = Keyword.fetch!(opts, :callbacks) |> Enum.reverse()
defaults =
default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds)

callbacks =
unless is_nil(opts[:learning_rates]) do
[
%Callback{
event: :before_iteration,
fun: &Callback.lr_scheduler/1,
name: :lr_scheduler,
init_state: %{learning_rates: learning_rates}
}
| callbacks
]
else
callbacks
end

callbacks =
unless verbose_eval == 0 or evals_dmats == [] do
[
%Callback{
event: :after_iteration,
fun: &Callback.monitor_metrics/1,
name: :monitor_metrics,
init_state: %{period: verbose_eval, filter: fn {_, _} -> true end}
}
| callbacks
]
else
callbacks
end
Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback ->
%{callback | fun: fn state -> state |> fun.() |> State.validate!() end}
end)

callbacks =
unless is_nil(opts[:early_stopping_rounds]) do
unless evals_dmats == [] do
[{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats)

# Default to the last metric
[%{"name" => metric_name} | _tail] =
EXGBoost.dump_config(bst)
|> Jason.decode!()
|> get_in(["learner", "metrics"])
|> Enum.reverse()

[
%Callback{
event: :after_iteration,
fun: &Callback.early_stop/1,
name: :early_stop,
init_state: %{
patience: opts[:early_stopping_rounds],
best: nil,
since_last_improvement: 0,
mode: :min,
target_eval: target_eval,
target_metric: metric_name
}
}
| callbacks
]
else
raise ArgumentError, "early_stopping_rounds requires at least one evaluation set"
end
else
callbacks
end
# Validate callbacks and ensure all names are unique.
Enum.each(callbacks, &Callback.validate!/1)
name_counts = Enum.frequencies_by(callbacks, & &1.name)

callbacks =
unless evals_dmats == [] do
[
%Callback{
event: :after_iteration,
fun: &Callback.eval_metrics/1,
name: :eval_metrics,
init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end}
}
| callbacks
]
else
callbacks
end
if Enum.any?(name_counts, &(elem(&1, 1) > 1)) do
str = name_counts |> Enum.sort() |> Enum.map_join("\n\n", &" * #{inspect(&1)}")
raise ArgumentError, "Found duplicate callback names.\n\nName counts:\n\n#{str}\n"
end

default = %{
before_iteration: [],
after_iteration: [],
before_training: [],
after_training: [],
init_state: %{}
state = %State{
booster: bst,
iteration: 0,
max_iteration: num_boost_rounds,
meta_vars: Map.new(callbacks, &{&1.name, &1.init_state})
}

env =
callbacks
|> Enum.reverse()
|> Enum.reduce(default, fn %Callback{} = callback, acc ->
acc =
case callback.event do
:before_iteration ->
%{acc | before_iteration: [callback.fun | acc[:before_iteration]]}
callbacks = Enum.group_by(callbacks, & &1.event, & &1.fun)

:after_iteration ->
%{acc | after_iteration: [callback.fun | acc[:after_iteration]]}
state =
state
|> run_callbacks(callbacks, :before_training)
|> run_training(callbacks, dmat, objective)
|> run_callbacks(callbacks, :after_training)

:before_training ->
%{acc | before_training: [callback.fun | acc[:before_training]]}
state.booster
end

:after_training ->
%{acc | after_training: [callback.fun | acc[:after_training]]}
defp run_callbacks(%{status: :halt} = state, _callbacks, _event), do: state

_ ->
raise ArgumentError, "Invalid callback: #{inspect(callback)}"
end
defp run_callbacks(%{status: :cont} = state, callbacks, event) do
Enum.reduce_while(callbacks[event] || [], state, fn callback, state ->
state = callback.(state)
{state.status, state}
end)
end

case callback.name do
nil -> acc
name -> put_in(acc[:init_state][name], callback.init_state)
end
end)
defp run_training(%{status: :halt} = state, _callbacks, _dmat, _objective), do: state

start_iteration = 0
num_boost_rounds = Keyword.fetch!(opts, :num_boost_rounds)
defp run_training(%{status: :cont} = state, callbacks, dmat, objective) do
Enum.reduce_while(1..state.max_iteration, state, fn iter, state ->
state =
state
|> run_callbacks(callbacks, :before_iteration)
|> run_iteration(dmat, iter, objective)
|> run_callbacks(callbacks, :after_iteration)

init_state = %State{
booster: bst,
iteration: 0,
max_iteration: num_boost_rounds,
meta_vars: env[:init_state]
}
{state.status, state}
end)
end

defp run_iteration(%{status: :halt} = state, _dmat, _iter, _objective), do: state

defp run_iteration(%{status: :cont} = state, dmat, iter, objective) do
:ok = Booster.update(state.booster, dmat, iter, objective)
%{state | iteration: iter}
end

{status, state} =
case run_callbacks(env[:before_training], init_state) do
{:halt, state} ->
{:halted, state}

{:cont, state} ->
Enum.reduce_while(
start_iteration..(num_boost_rounds - 1),
{:cont, state},
fn iter, {_, iter_state} ->
case run_callbacks(env[:before_iteration], iter_state) do
{:halt, state} ->
{:halt, {:halted, state}}

{:cont, state} ->
Booster.update(state.booster, dmat, iter, objective)

case run_callbacks(env[:after_iteration], %{state | booster: bst}) do
{:halt, state} ->
{:halt, {:halted, state}}

{:cont, state} ->
{:cont, {:cont, %{state | iteration: state.iteration + 1}}}
end
end
end
)

_ ->
raise "invalid return value from before_training callback"
defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do
default_callbacks = []

default_callbacks =
if learning_rates do
lr_scheduler = %Callback{
event: :before_iteration,
fun: &Callback.lr_scheduler/1,
name: :lr_scheduler,
init_state: %{learning_rates: learning_rates}
}

[lr_scheduler | default_callbacks]
else
default_callbacks
end

case status do
:halted ->
state.booster
default_callbacks =
if verbose_eval != 0 and evals_dmats != [] do
monitor_metrics = %Callback{
event: :after_iteration,
fun: &Callback.monitor_metrics/1,
name: :monitor_metrics,
init_state: %{period: verbose_eval, filter: fn {_, _} -> true end}
}

:cont ->
{_status, final_state} = run_callbacks(env[:after_training], state)
final_state.booster
end
end
[monitor_metrics | default_callbacks]
else
default_callbacks
end

default_callbacks =
if early_stopping_rounds && evals_dmats != [] do
[{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats)

# Default to the last metric
[%{"name" => metric_name} | _tail] =
EXGBoost.dump_config(bst)
|> Jason.decode!()
|> get_in(["learner", "metrics"])
|> Enum.reverse()

early_stop = %Callback{
event: :after_iteration,
fun: &Callback.early_stop/1,
name: :early_stop,
init_state: %{
patience: early_stopping_rounds,
best: nil,
since_last_improvement: 0,
mode: :min,
target_eval: target_eval,
target_metric: metric_name
}
}

defp run_callbacks(callbacks, state) do
callbacks
|> Enum.reduce_while({:cont, state}, fn callback, {_, state} ->
case callback.(state) do
{:cont, %State{} = state} ->
{:cont, {:cont, state}}

{:halt, %State{} = state} ->
{:halt, {:halt, state}}

invalid ->
raise ArgumentError,
"invalid value #{inspect(invalid)} returned from callback" <>
" Callback handler must return" <>
" a tuple of {status, state} where status is one of :cont," <>
" or :halt and state is an updated State struct"
[early_stop | default_callbacks]
else
default_callbacks
end
end)

default_callbacks =
if evals_dmats != [] do
eval_metrics = %Callback{
event: :after_iteration,
fun: &Callback.eval_metrics/1,
name: :eval_metrics,
init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end}
}

[eval_metrics | default_callbacks]
else
default_callbacks
end

default_callbacks
end
end
Loading

0 comments on commit 9649af0

Please sign in to comment.