Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor training logic #35

Merged
merged 4 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 155 additions & 204 deletions lib/exgboost/training.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,45 @@ defmodule EXGBoost.Training do

@spec train(DMatrix.t(), Keyword.t()) :: Booster.t()
def train(%DMatrix{} = dmat, opts \\ []) do
{opts, booster_params} =
Keyword.split(opts, [
:obj,
:num_boost_rounds,
:evals,
:verbose_eval,
:callbacks,
:learning_rates,
:early_stopping_rounds
])

opts =
Keyword.validate!(opts,
obj: nil,
num_boost_rounds: 10,
evals: [],
verbose_eval: true,
callbacks: [],
learning_rates: nil,
early_stopping_rounds: nil
)

learning_rates = Keyword.fetch!(opts, :learning_rates)

if not is_nil(learning_rates) and
not (is_function(learning_rates, 1) or is_list(learning_rates)) do
valid_opts = [
callbacks: [],
early_stopping_rounds: nil,
evals: [],
learning_rates: nil,
num_boost_rounds: 10,
obj: nil,
verbose_eval: true
]

{opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts))

[
callbacks: callbacks,
early_stopping_rounds: early_stopping_rounds,
evals: evals,
learning_rates: learning_rates,
num_boost_rounds: num_boost_rounds,
obj: objective,
verbose_eval: verbose_eval
] = opts |> Keyword.validate!(valid_opts) |> Enum.sort()

unless is_nil(learning_rates) or is_function(learning_rates, 1) or is_list(learning_rates) do
raise ArgumentError, "learning_rates must be a function/1 or a list"
end

objective = Keyword.fetch!(opts, :obj)
evals = Keyword.fetch!(opts, :evals)
if early_stopping_rounds && evals == [] do
raise ArgumentError, "early_stopping_rounds requires at least one evaluation set"
end

verbose_eval =
case verbose_eval do
true -> 1
false -> 0
value -> value
end

evals_dmats =
Enum.map(evals, fn {%Nx.Tensor{} = x, %Nx.Tensor{} = y, name} ->
Enum.map(evals, fn {x, y, name} ->
{DMatrix.from_tensor(x, y, format: :dense), name}
end)

Expand All @@ -49,200 +54,146 @@ defmodule EXGBoost.Training do
booster_params
)

verbose_eval =
case Keyword.fetch!(opts, :verbose_eval) do
true -> 1
false -> 0
value -> value
end

callbacks = Keyword.fetch!(opts, :callbacks) |> Enum.reverse()
defaults =
default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds)

callbacks =
unless is_nil(opts[:learning_rates]) do
[
%Callback{
event: :before_iteration,
fun: &Callback.lr_scheduler/1,
name: :lr_scheduler,
init_state: %{learning_rates: learning_rates}
}
| callbacks
]
else
callbacks
end

callbacks =
unless verbose_eval == 0 or evals_dmats == [] do
[
%Callback{
event: :after_iteration,
fun: &Callback.monitor_metrics/1,
name: :monitor_metrics,
init_state: %{period: verbose_eval, filter: fn {_, _} -> true end}
}
| callbacks
]
else
callbacks
end
Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback ->
%{callback | fun: fn state -> state |> fun.() |> State.validate!() end}
end)
Comment on lines +61 to +63
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a tiny bit of overhead, but I think it'll result in better error messages.


callbacks =
unless is_nil(opts[:early_stopping_rounds]) do
unless evals_dmats == [] do
[{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats)

# Default to the last metric
[%{"name" => metric_name} | _tail] =
EXGBoost.dump_config(bst)
|> Jason.decode!()
|> get_in(["learner", "metrics"])
|> Enum.reverse()

[
%Callback{
event: :after_iteration,
fun: &Callback.early_stop/1,
name: :early_stop,
init_state: %{
patience: opts[:early_stopping_rounds],
best: nil,
since_last_improvement: 0,
mode: :min,
target_eval: target_eval,
target_metric: metric_name
}
}
| callbacks
]
else
raise ArgumentError, "early_stopping_rounds requires at least one evaluation set"
end
else
callbacks
end
# Validate callbacks and ensure all names are unique.
Enum.each(callbacks, &Callback.validate!/1)
name_counts = Enum.frequencies_by(callbacks, & &1.name)

callbacks =
unless evals_dmats == [] do
[
%Callback{
event: :after_iteration,
fun: &Callback.eval_metrics/1,
name: :eval_metrics,
init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end}
}
| callbacks
]
else
callbacks
end
if Enum.any?(name_counts, &(elem(&1, 1) > 1)) do
str = name_counts |> Enum.sort() |> Enum.map_join("\n\n", &" * #{inspect(&1)}")
raise ArgumentError, "Found duplicate callback names.\n\nName counts:\n\n#{str}\n"
end

default = %{
before_iteration: [],
after_iteration: [],
before_training: [],
after_training: [],
init_state: %{}
state = %State{
booster: bst,
iteration: 0,
max_iteration: num_boost_rounds,
meta_vars: Map.new(callbacks, &{&1.name, &1.init_state})
}

env =
callbacks
|> Enum.reverse()
|> Enum.reduce(default, fn %Callback{} = callback, acc ->
acc =
case callback.event do
:before_iteration ->
%{acc | before_iteration: [callback.fun | acc[:before_iteration]]}
callbacks = Enum.group_by(callbacks, & &1.event, & &1.fun)

:after_iteration ->
%{acc | after_iteration: [callback.fun | acc[:after_iteration]]}
state =
state
|> run_callbacks(callbacks, :before_training)
|> run_training(callbacks, dmat, objective)
|> run_callbacks(callbacks, :after_training)

:before_training ->
%{acc | before_training: [callback.fun | acc[:before_training]]}
state.booster
end

:after_training ->
%{acc | after_training: [callback.fun | acc[:after_training]]}
defp run_callbacks(%{status: :halt} = state, _callbacks, _event), do: state

_ ->
raise ArgumentError, "Invalid callback: #{inspect(callback)}"
end
defp run_callbacks(%{status: :cont} = state, callbacks, event) do
Enum.reduce_while(callbacks[event] || [], state, fn callback, state ->
state = callback.(state)
{state.status, state}
end)
end

case callback.name do
nil -> acc
name -> put_in(acc[:init_state][name], callback.init_state)
end
end)
defp run_training(%{status: :halt} = state, _callbacks, _dmat, _objective), do: state

start_iteration = 0
num_boost_rounds = Keyword.fetch!(opts, :num_boost_rounds)
defp run_training(%{status: :cont} = state, callbacks, dmat, objective) do
Enum.reduce_while(1..state.max_iteration, state, fn iter, state ->
state =
state
|> run_callbacks(callbacks, :before_iteration)
|> run_iteration(dmat, iter, objective)
|> run_callbacks(callbacks, :after_iteration)

init_state = %State{
booster: bst,
iteration: 0,
max_iteration: num_boost_rounds,
meta_vars: env[:init_state]
}
{state.status, state}
end)
end

defp run_iteration(%{status: :halt} = state, _dmat, _iter, _objective), do: state

defp run_iteration(%{status: :cont} = state, dmat, iter, objective) do
:ok = Booster.update(state.booster, dmat, iter, objective)
%{state | iteration: iter}
end

{status, state} =
case run_callbacks(env[:before_training], init_state) do
{:halt, state} ->
{:halted, state}

{:cont, state} ->
Enum.reduce_while(
start_iteration..(num_boost_rounds - 1),
{:cont, state},
fn iter, {_, iter_state} ->
case run_callbacks(env[:before_iteration], iter_state) do
{:halt, state} ->
{:halt, {:halted, state}}

{:cont, state} ->
Booster.update(state.booster, dmat, iter, objective)

case run_callbacks(env[:after_iteration], %{state | booster: bst}) do
{:halt, state} ->
{:halt, {:halted, state}}

{:cont, state} ->
{:cont, {:cont, %{state | iteration: state.iteration + 1}}}
end
end
end
)

_ ->
raise "invalid return value from before_training callback"
defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do
default_callbacks = []

default_callbacks =
if learning_rates do
lr_scheduler = %Callback{
event: :before_iteration,
fun: &Callback.lr_scheduler/1,
name: :lr_scheduler,
init_state: %{learning_rates: learning_rates}
}

[lr_scheduler | default_callbacks]
else
default_callbacks
end

case status do
:halted ->
state.booster
default_callbacks =
if verbose_eval != 0 and evals_dmats != [] do
monitor_metrics = %Callback{
event: :after_iteration,
fun: &Callback.monitor_metrics/1,
name: :monitor_metrics,
init_state: %{period: verbose_eval, filter: fn {_, _} -> true end}
}

:cont ->
{_status, final_state} = run_callbacks(env[:after_training], state)
final_state.booster
end
end
[monitor_metrics | default_callbacks]
else
default_callbacks
end

default_callbacks =
if early_stopping_rounds && evals_dmats != [] do
[{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats)

# Default to the last metric
[%{"name" => metric_name} | _tail] =
EXGBoost.dump_config(bst)
|> Jason.decode!()
|> get_in(["learner", "metrics"])
|> Enum.reverse()

early_stop = %Callback{
event: :after_iteration,
fun: &Callback.early_stop/1,
name: :early_stop,
init_state: %{
patience: early_stopping_rounds,
best: nil,
since_last_improvement: 0,
mode: :min,
target_eval: target_eval,
target_metric: metric_name
}
}

defp run_callbacks(callbacks, state) do
callbacks
|> Enum.reduce_while({:cont, state}, fn callback, {_, state} ->
case callback.(state) do
{:cont, %State{} = state} ->
{:cont, {:cont, state}}

{:halt, %State{} = state} ->
{:halt, {:halt, state}}

invalid ->
raise ArgumentError,
"invalid value #{inspect(invalid)} returned from callback" <>
" Callback handler must return" <>
" a tuple of {status, state} where status is one of :cont," <>
" or :halt and state is an updated State struct"
[early_stop | default_callbacks]
else
default_callbacks
end
end)

default_callbacks =
if evals_dmats != [] do
eval_metrics = %Callback{
event: :after_iteration,
fun: &Callback.eval_metrics/1,
name: :eval_metrics,
init_state: %{evals: evals_dmats, filter: fn {_, _} -> true end}
}

[eval_metrics | default_callbacks]
else
default_callbacks
end

default_callbacks
end
end
Loading