Commit

Merge branch 'main' into plotting
acalejos committed Jan 20, 2024
2 parents 9610d51 + 9649af0 commit 975e285
Showing 8 changed files with 327 additions and 370 deletions.
386 changes: 164 additions & 222 deletions lib/exgboost/training.ex

Large diffs are not rendered by default.

135 changes: 61 additions & 74 deletions lib/exgboost/training/callback.ex
@@ -46,30 +46,35 @@ defmodule EXGBoost.Training.Callback do
   @enforce_keys [:event, :fun]
   defstruct [:event, :fun, :name, :init_state]

-  @doc """
-  Factory for a new callback without an initial state. See `EXGBoost.Callback.new/4` for more details.
-  """
-  @spec new(
-          event :: :before_training | :after_training | :before_iteration | :after_iteration,
-          fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()})
-        ) :: Callback.t()
-  def new(event, fun) do
-    new(event, fun, nil, %{})
-  end
+  @type event :: :before_training | :after_training | :before_iteration | :after_iteration
+  @type fun :: (State.t() -> State.t())
+
+  @valid_events [:before_training, :after_training, :before_iteration, :after_iteration]

   @doc """
   Factory for a new callback with an initial state.
   """
-  @spec new(
-          event :: :before_training | :after_training | :before_iteration | :after_iteration,
-          fun :: (State.t() -> {:cont, State.t()} | {:halt, State.t()}),
-          name :: atom(),
-          init_state :: map()
-        ) :: Callback.t()
-  def new(event, fun, name, %{} = init_state)
-      when event in [:before_training, :after_training, :before_iteration, :after_iteration] and
-             is_atom(name) do
+  @spec new(event :: event(), fun :: fun(), name :: atom(), init_state :: any()) :: Callback.t()
+  def new(event, fun, name, init_state \\ %{})
+      when event in @valid_events and is_function(fun, 1) and is_atom(name) and not is_nil(name) do
     %__MODULE__{event: event, fun: fun, name: name, init_state: init_state}
+    |> validate!()
   end
+
+  def validate!(%__MODULE__{} = callback) do
+    unless is_atom(callback.name) and not is_nil(callback.name) do
+      raise "A callback must have a non-`nil` atom for a name. Found: #{inspect(callback.name)}."
+    end
+
+    unless callback.event in @valid_events do
+      raise "Callback #{callback.name} must have an event in #{inspect(@valid_events)}. Found: #{inspect(callback.event)}."
+    end
+
+    unless is_function(callback.fun, 1) do
+      raise "Callback #{callback.name} must have a 1-arity function. Found: #{inspect(callback.fun)}."
+    end
+
+    callback
+  end

   @doc """
@@ -83,12 +88,13 @@ defmodule EXGBoost.Training.Callback do
         %State{
           booster: bst,
           meta_vars: %{lr_scheduler: %{learning_rates: learning_rates}},
-          iteration: i
+          iteration: i,
+          status: :cont
         } = state
       ) do
    lr = if is_list(learning_rates), do: Enum.at(learning_rates, i), else: learning_rates.(i)
    boostr = EXGBoost.Booster.set_params(bst, learning_rate: lr)
-    {:cont, %{state | booster: boostr}}
+    %{state | booster: boostr}
   end
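
For orientation, here is a minimal sketch of attaching `lr_scheduler/1` as a callback. It assumes the trainer exposes each callback's `init_state` under the callback's name in `meta_vars` (which is how the clause above finds `learning_rates`); the decay schedule itself is made up for illustration:

```elixir
# Hypothetical wiring for lr_scheduler/1 above. The exponential decay
# schedule is illustrative, and we assume init_state is surfaced to the
# callback as meta_vars.lr_scheduler.
lr_callback =
  EXGBoost.Training.Callback.new(
    :before_iteration,
    &EXGBoost.Training.Callback.lr_scheduler/1,
    :lr_scheduler,
    %{learning_rates: fn i -> 0.3 * :math.pow(0.99, i) end}
  )
```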

# TODO: Ideally this would be generalized like it is in Axon to allow generic monitoring of metrics,
@@ -110,20 +116,20 @@ defmodule EXGBoost.Training.Callback do
   def early_stop(
         %State{
           booster: bst,
-          meta_vars:
-            %{
-              early_stop: %{
-                best: best,
-                patience: patience,
-                target_metric: target_metric,
-                target_eval: target_eval,
-                mode: mode,
-                since_last_improvement: since_last_improvement
-              }
-            } = meta_vars,
-          metrics: metrics
+          meta_vars: %{early_stop: early_stop} = meta_vars,
+          metrics: metrics,
+          status: :cont
         } = state
       ) do
+    %{
+      best: best_score,
+      patience: patience,
+      target_metric: target_metric,
+      target_eval: target_eval,
+      mode: mode,
+      since_last_improvement: since_last_improvement
+    } = early_stop
+
    unless Map.has_key?(metrics, target_eval) do
      raise ArgumentError,
            "target eval_set #{inspect(target_eval)} not found in metrics #{inspect(metrics)}"
@@ -134,54 +140,34 @@ defmodule EXGBoost.Training.Callback do
            "target metric #{inspect(target_metric)} not found in metrics #{inspect(metrics)}"
    end

-    prev_criteria_value = best
-
-    cur_criteria_value = metrics[target_eval][target_metric]
+    score = metrics[target_eval][target_metric]

    improved? =
-      case mode do
-        :min ->
-          prev_criteria_value == nil or
-            cur_criteria_value < prev_criteria_value
-
-        :max ->
-          prev_criteria_value == nil or
-            cur_criteria_value > prev_criteria_value
+      cond do
+        best_score == nil -> true
+        mode == :min -> score < best_score
+        mode == :max -> score > best_score
      end

-    over_patience? = since_last_improvement >= patience
-
    cond do
      improved? ->
-        updated_meta_vars =
-          meta_vars
-          |> put_in([:early_stop, :best], cur_criteria_value)
-          |> put_in([:early_stop, :since_last_improvement], 0)
+        early_stop = %{early_stop | best: score, since_last_improvement: 0}

        bst =
          bst
-          |> struct(best_iteration: state.iteration, best_score: cur_criteria_value)
-          |> EXGBoost.Booster.set_attr(
-            best_iteration: state.iteration,
-            best_score: cur_criteria_value
-          )
+          |> struct(best_iteration: state.iteration, best_score: score)
+          |> EXGBoost.Booster.set_attr(best_iteration: state.iteration, best_score: score)

-        {:cont, %{state | meta_vars: updated_meta_vars, booster: bst}}
+        %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}}

-      not improved? and not over_patience? ->
-        updated_meta_vars =
-          meta_vars
-          |> put_in([:early_stop, :since_last_improvement], since_last_improvement + 1)
-
-        {:cont, %{state | meta_vars: updated_meta_vars}}
+      since_last_improvement < patience ->
+        early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1))
+        %{state | meta_vars: %{meta_vars | early_stop: early_stop}}

      true ->
-        updated_meta_vars =
-          meta_vars
-          |> put_in([:early_stop, :since_last_improvement], since_last_improvement + 1)
-
-        bst = struct(bst, best_iteration: state.iteration, best_score: cur_criteria_value)
-        {:halt, %{state | meta_vars: updated_meta_vars, booster: bst}}
+        early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1))
+        bst = struct(bst, best_iteration: state.iteration, best_score: score)
+        %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}, status: :halt}
    end
  end
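
The function above expects `meta_vars.early_stop` to carry the tracking keys it destructures. Here is a sketch of constructing the callback with a matching `init_state`, assuming the trainer nests it under the callback's name; the eval-set and metric names are hypothetical and must match what lands in `state.metrics`:

```elixir
# Hypothetical setup for early_stop/1 above. "validation" and "rmse" are
# made-up names; they must match an eval set and metric in state.metrics.
early_stopping =
  EXGBoost.Training.Callback.new(
    :after_iteration,
    &EXGBoost.Training.Callback.early_stop/1,
    :early_stop,
    %{
      best: nil,
      patience: 5,
      target_eval: "validation",
      target_metric: "rmse",
      mode: :min,
      since_last_improvement: 0
    }
  )
```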

@@ -198,7 +184,8 @@ defmodule EXGBoost.Training.Callback do
         %State{
           booster: bst,
           iteration: iter,
-          meta_vars: %{eval_metrics: %{evals: evals, filter: filter}}
+          meta_vars: %{eval_metrics: %{evals: evals, filter: filter}},
+          status: :cont
         } = state
       ) do
    metrics =
@@ -210,7 +197,7 @@ defmodule EXGBoost.Training.Callback do
        end)
      |> Map.filter(filter)

-    {:cont, %{state | metrics: metrics}}
+    %{state | metrics: metrics}
  end

   @doc """
@@ -228,14 +215,14 @@ defmodule EXGBoost.Training.Callback do
         metrics: metrics,
         meta_vars: %{
           monitor_metrics: %{period: period, filter: filter}
-        }
+        },
+        status: :cont
       } = state
     ) do
    if period != 0 and rem(iteration, period) == 0 do
-      metrics = Map.filter(metrics, filter)
-      IO.puts("Iteration #{iteration}: #{inspect(metrics)}")
+      IO.puts("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}")
    end

-    {:cont, state}
+    state
  end
end
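
Under the new contract, a user-defined callback returns the `%State{}` directly and signals an early exit by setting `status: :halt`, instead of wrapping the state in `{:cont, _}` or `{:halt, _}` tuples. A minimal sketch (the iteration cap is arbitrary):

```elixir
alias EXGBoost.Training.{Callback, State}

# Sketch of a custom callback under the new contract: return the state
# itself, flipping status to :halt to stop training after ten iterations.
stop_at_ten =
  Callback.new(:after_iteration, fn %State{iteration: i} = state ->
    if i >= 10, do: %{state | status: :halt}, else: state
  end, :stop_at_ten)
```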
14 changes: 12 additions & 2 deletions lib/exgboost/training/state.ex
@@ -3,9 +3,19 @@ defmodule EXGBoost.Training.State do
   @enforce_keys [:booster]
   defstruct [
     :booster,
-    meta_vars: %{},
     iteration: 0,
     max_iteration: -1,
-    metrics: %{}
+    meta_vars: %{},
+    metrics: %{},
+    status: :cont
   ]
+
+  def validate!(%__MODULE__{} = state) do
+    unless state.status in [:cont, :halt] do
+      raise ArgumentError,
+            "`status` must be `:cont` or `:halt`, found: `#{inspect(state.status)}`."
+    end
+
+    state
+  end
 end
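
With the halt signal living on the struct, a training loop can validate the state a callback returns and branch on `status`; a rough sketch of that check, not EXGBoost's actual loop:

```elixir
# Illustrative only: how a loop might consume the new status field.
state = EXGBoost.Training.State.validate!(state)

case state.status do
  :cont -> :run_next_iteration  # hypothetical continuation step
  :halt -> :finalize_booster    # hypothetical cleanup step
end
```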
8 changes: 8 additions & 0 deletions mix.exs
@@ -77,6 +77,14 @@ defmodule EXGBoost.MixProject do
   defp docs do
     [
       main: "EXGBoost",
+      extras: [
+        "notebooks/compiled_benchmarks.livemd",
+        "notebooks/iris_classification.livemd",
+        "notebooks/quantile_prediction_interval.livemd"
+      ],
+      groups_for_extras: [
+        Notebooks: Path.wildcard("notebooks/*.livemd")
+      ],
       before_closing_body_tag: &before_closing_body_tag/1
     ]
   end
9 changes: 4 additions & 5 deletions notebooks/compiled_benchmarks.livemd
@@ -7,6 +7,7 @@ Mix.install([
   {:mockingjay, github: "acalejos/mockingjay"},
   {:nx, "~> 0.5", override: true},
   {:exla, "~> 0.5"},
+  {:scholar, "~> 0.2"},
   {:benchee, "~> 1.0"}
])
```
@@ -71,10 +72,9 @@ funcs = %{
## Run Time Benchmarks

```elixir
-benches = Enum.reduce(funcs, %{}, fn {k, v}, acc -> Map.put(acc, k, fn -> v.(x_train) end) end)
+# Benchee expects each scenario to be a zero-arity function, so keep the
+# prediction calls wrapped in closures rather than evaluating them eagerly.
+benches = Map.new(funcs, fn {k, v} -> {k, fn -> v.(x_train) end} end)

-Benchee.run(
-  benches,
+Benchee.run(benches,
   time: 10,
   memory_time: 2,
   warmup: 5
@@ -89,11 +89,10 @@ Nx.default_backend(Nx.BinaryBackend)

accuracies =
  Enum.reduce(funcs, %{}, fn {name, pred_fn}, acc ->
-
    accuracy =
      pred_fn.(x_test)
      |> Nx.argmax(axis: -1)
-      |> then(&Scholar.Metrics.accuracy(y_test, &1))
+      |> then(&Scholar.Metrics.Classification.accuracy(y_test, &1))
      |> Nx.to_number()

    Map.put(acc, name, accuracy)
2 changes: 1 addition & 1 deletion notebooks/iris_classification.livemd
@@ -71,7 +71,7 @@ To get predictions from a trained booster, we can just call `EXGBoost.predict/2`

```elixir
preds = EXGBoost.predict(booster, x_test) |> Nx.argmax(axis: -1)
-Scholar.Metrics.accuracy(y_test, preds)
+Scholar.Metrics.Classification.accuracy(y_test, preds)
```

And that's it! We've successfully trained a booster on the Iris dataset with `EXGBoost`.
