Refactor training logic #35

Merged
merged 4 commits into from
Jan 4, 2024

billylanchantin
Contributor

Adds:

  • %Training.State{status: :cont | :halt}

This way the fun on %Training.Callback{} can have the spec:

  • fun :: State.t() -> State.t()

instead of wrapping the output in a tuple. It also means we can write code like this:

state =
  state
  |> run_callbacks(callbacks, :before_training)
  |> run_training(callbacks, dmat, objective)
  |> run_callbacks(callbacks, :after_training)

since now run_callbacks/3 and run_training/4 (new) can branch on state.status.
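A minimal sketch of how that branching might look. `Sketch.State`, `Sketch.Training`, the `:event` and `:log` fields, and the map-based callbacks are hypothetical stand-ins rather than the PR's actual code, and `run_training/4` is left out for brevity:

```elixir
defmodule Sketch.State do
  # Hypothetical stand-in for %Training.State{}; only the fields used here.
  defstruct status: :cont, iteration: 0, log: []

  def new, do: %__MODULE__{}
end

defmodule Sketch.Training do
  alias Sketch.State

  # Once any step has halted, later steps pass the state through untouched.
  def run_callbacks(%State{status: :halt} = state, _callbacks, _event), do: state

  def run_callbacks(%State{status: :cont} = state, callbacks, event) do
    callbacks
    |> Enum.filter(&(&1.event == event))
    |> Enum.reduce_while(state, fn callback, acc ->
      # Each fun has the shape State.t() -> State.t(); no tuple wrapping.
      case callback.fun.(acc) do
        %State{status: :halt} = halted -> {:halt, halted}
        %State{status: :cont} = next -> {:cont, next}
      end
    end)
  end
end

callbacks = [
  %{event: :before_training, fun: fn s -> %{s | log: [:first | s.log]} end},
  %{event: :before_training, fun: fn s -> %{s | status: :halt} end},
  %{event: :before_training, fun: fn s -> %{s | log: [:never_runs | s.log]} end},
  %{event: :after_training, fun: fn s -> %{s | log: [:after | s.log]} end}
]

final =
  Sketch.State.new()
  |> Sketch.Training.run_callbacks(callbacks, :before_training)
  |> Sketch.Training.run_callbacks(callbacks, :after_training)
```

Because every stage accepts and returns a state, a halt in an early stage simply flows through the rest of the pipeline without any `with` or `case` plumbing at the call site.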

I did some other code shuffling while I was in there, but none of it is crucial.

bst = struct(bst, best_iteration: state.iteration, best_score: cur_criteria_value)
{:halt, %{state | meta_vars: updated_meta_vars, booster: bst}}
early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1))
bst = struct(bst, best_iteration: state.iteration, best_score: score)
Contributor Author

Is this possibly a bug? It doesn't match the other cond branch where:

bst =
  bst
  |> struct(best_iteration: state.iteration, best_score: score)
  |> EXGBoost.Booster.set_attr(best_iteration: state.iteration, best_score: score)

I haven't had time to dig into the particulars of what's going on in Booster yet, so I wasn't sure.

Owner

Yeah, it should probably either do both (add to the struct and update the attribute) or just update the struct. The Booster attributes are held in the actual booster, so they are actually mutable. The implementation to use attributes exists, but I don't think it should really be used outside of maybe inspecting attributes from a serialized model that were set elsewhere.

Owner

Also, I'm glad you cleaned up a lot of this code, because IDK why I was using put_in so much.

Comment on lines +65 to +67
Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback ->
%{callback | fun: fn state -> state |> fun.() |> State.validate!() end}
end)
Contributor Author

This is a tiny bit of overhead, but I think it'll result in better error messages.
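To illustrate the kind of error this can surface, here is a rough sketch. `Sketch.Validate` and its checks are assumptions; the real `State.validate!/1` is not shown in this thread and may check different things:

```elixir
defmodule Sketch.Validate do
  # Hypothetical validator: accept any map whose status is :cont or :halt.
  def validate!(%{status: status} = state) when status in [:cont, :halt], do: state

  def validate!(other) do
    raise ArgumentError,
          "callback must return a state with status :cont or :halt, got: #{inspect(other)}"
  end

  # Wrap each callback's fun so its return value is validated immediately.
  def wrap(callbacks) do
    Enum.map(callbacks, fn %{fun: fun} = callback ->
      %{callback | fun: fn state -> state |> fun.() |> validate!() end}
    end)
  end
end

[good, bad] =
  Sketch.Validate.wrap([
    %{name: :good, fun: fn state -> %{state | iteration: state.iteration + 1} end},
    # Returns a tuple, as under the old contract; the wrapper rejects it.
    %{name: :bad, fun: fn state -> {:cont, state} end}
  ])

state = %{status: :cont, iteration: 0}
valid_result = good.fun.(state)

error_message =
  try do
    bad.fun.(state)
  rescue
    e in ArgumentError -> e.message
  end
```

The failure is raised right where the offending callback returns, rather than later as a confusing match error somewhere in the training loop.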

} = state
) do
if period != 0 and rem(iteration, period) == 0 do
metrics = Map.filter(metrics, filter)
IO.puts("Iteration #{iteration}: #{inspect(metrics)}")
IO.puts("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}")
Owner

Let me know if you have any opinions about improving this logging. It's pretty bare-bones right now. Not necessarily for this PR, though.

Contributor Author

I think bare-bones is fine. And yeah, not for this PR.

However, an example in the docs that overrides this callback to, say, use Logger might be quite instructive.
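Such a docs example might look roughly like the following. This is only a sketch: `Sketch.LogCallback.log_metrics/3` is a hypothetical function operating on a plain map, and how it would be registered as a callback in place of the default is not shown in this thread:

```elixir
defmodule Sketch.LogCallback do
  require Logger

  # Hypothetical Logger-based variant of the IO.puts callback above.
  # `period` and `filter` mirror the names used in the diff.
  def log_metrics(%{iteration: iteration, metrics: metrics} = state, period, filter) do
    if period != 0 and rem(iteration, period) == 0 do
      Logger.info("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}")
    end

    # Return the state unchanged so the callback composes with the others.
    state
  end
end

state = %{iteration: 10, metrics: %{"train" => 0.1, "eval" => 0.2}}

result =
  Sketch.LogCallback.log_metrics(state, 5, fn {name, _value} -> name == "eval" end)
```

Routing through Logger would let users control the level and destination with their usual Logger configuration instead of writing straight to stdout.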

(Also, this was such a nit-picky change on my part, lol. The code before was completely fine.)

iteration: 0,
max_iteration: -1,
metrics: %{}
meta_vars: %{},
metrics: %{},
Owner

I'm trying to remember why I didn't just use the meta_vars to hold the metrics in...

|> validate!()
end

def validate!(%__MODULE__{} = callback) do
Owner

My only hangup here is that the only reason we need a name is to look up the associated meta_vars, so there should still be the ability to create a callback that has no state and no name.

Contributor Author

Oh I totally forgot about this part. Sorry I meant to call it out.

My thinking was this: we can't use name as the key in meta_vars unless all names are unique. Multiple callbacks with name: nil break that. Since I couldn't find an instance where name: nil was actually used, I figured I'd just make it required.

If we do want to allow name: nil, then we need to alter the creation of meta_vars to account for duplicates. Come to think of it, we need to do that anyway.

Owner

My thought was that you just check for a nil name and, if it's nil, don't even look at meta_vars and instead just run the func. The non-nil names would still have to be unique.

Contributor Author

I addressed the duplicate name issue and wrote a test.

f1977bf

This doesn't address the question about how to handle name: nil, but it prevents us clobbering callbacks when we build meta_vars. We can easily filter name: nil callbacks out or handle them with some other convention.
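One way the two ideas could be combined is sketched below: skip `name: nil` callbacks when building `meta_vars`, and raise on duplicate names among the rest. `Sketch.MetaVars.build/1` is a hypothetical helper, not the code from the commit above:

```elixir
defmodule Sketch.MetaVars do
  # Callbacks with name: nil are treated as stateless and get no entry.
  def build(callbacks) do
    named = Enum.reject(callbacks, &is_nil(&1.name))
    names = Enum.map(named, & &1.name)

    # Subtracting one copy of each unique name leaves only the repeats.
    duplicates = Enum.uniq(names -- Enum.uniq(names))

    if duplicates != [] do
      raise ArgumentError, "duplicate callback names: #{inspect(duplicates)}"
    end

    Map.new(named, &{&1.name, &1.init_state})
  end
end

meta_vars =
  Sketch.MetaVars.build([
    %{name: :early_stop, init_state: %{since_last_improvement: 0}},
    %{name: nil, init_state: nil},
    %{name: :monitor, init_state: %{}}
  ])

duplicate_error =
  try do
    Sketch.MetaVars.build([
      %{name: :monitor, init_state: %{}},
      %{name: :monitor, init_state: %{}}
    ])
  rescue
    e in ArgumentError -> e.message
  end
```

Keeping names as keys preserves the ability to say which callback misbehaved, while unnamed callbacks stay out of `meta_vars` entirely.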

Thoughts?

Contributor Author

@billylanchantin billylanchantin Dec 14, 2023

FYI, I think the other way to handle this is to do something like:

meta_vars =
  callbacks
  |> Enum.sort_by(& &1.name) # for determinism
  |> Enum.with_index()
  |> Map.new(fn {callback, i} -> {i, callback.init_state} end)

That side-steps the name issue altogether. But then if duplicate names are present, it's not possible to provide an error message like "Callback X did the wrong thing.".

@acalejos acalejos merged commit 9649af0 into acalejos:main Jan 4, 2024