-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor training logic #35
Refactor training logic #35
Conversation
bst = struct(bst, best_iteration: state.iteration, best_score: cur_criteria_value) | ||
{:halt, %{state | meta_vars: updated_meta_vars, booster: bst}} | ||
early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) | ||
bst = struct(bst, best_iteration: state.iteration, best_score: score) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this possibly a bug? It doesn't match the other cond
branch where:
bst =
bst
|> struct(best_iteration: state.iteration, best_score: score)
|> EXGBoost.Booster.set_attr(best_iteration: state.iteration, best_score: score)
I haven't had time time to dig into the particulars of what's going on in Booster
yet, so I wasn't sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah it should probably either do bot add to struct and update the attirbute or just update the struct. The Booster attributes are held in the actual booster so are actually mutable. The implementation to use attributes exists but I dont think should really be used outside of maybe seeing any attributes from a serialized model that maybe were set elsewhere
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also Im glad you cleaned up a lot of this code because IDK why I was using put_in
so much
Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback -> | ||
%{callback | fun: fn state -> state |> fun.() |> State.validate!() end} | ||
end) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a tiny bit of overhead, but I think it'll result in better error messages.
} = state | ||
) do | ||
if period != 0 and rem(iteration, period) == 0 do | ||
metrics = Map.filter(metrics, filter) | ||
IO.puts("Iteration #{iteration}: #{inspect(metrics)}") | ||
IO.puts("Iteration #{iteration}: #{inspect(Map.filter(metrics, filter))}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me know if you have any opinions about improving this logging. It's pretty bare-bones right now. Not necessarily for this PR though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think bare-bones is fine. And yeah, not for this PR.
However, an example in the docs that overrides this callback to, say, use Logger
might be quite instructive.
(Also this was such I nit-picky change on my part, lol. The code before was completely fine.)
iteration: 0, | ||
max_iteration: -1, | ||
metrics: %{} | ||
meta_vars: %{}, | ||
metrics: %{}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Im trying to remember why I didnt just use the meta_vars
to hold the metrics in...
|> validate!() | ||
end | ||
|
||
def validate!(%__MODULE__{} = callback) do |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My only hangup here is that the only reason we need a name is to look up associated meta_vars
, so there should still be the ability to create a callback that has no state and no name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I totally forgot about this part. Sorry I meant to call it out.
My thinking was this: we can't use name
as the key in meta_vars
unless all names are unique. Multiple callbacks with name: nil
break that. Since I couldn't find an instance where name: nil
was actually used, I figured I'd just make it required.
If we do want to allow name: nil
, then we need to alter the creation of meta_vars
to account for duplicates. Come to think of it, we need to do that anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My thought was that you just check for nil
name and dont even look at meta vars if its nil
and instead just run the func. And then the name
s would have to be unique still
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I addressed the duplicate name issue and wrote a test.
This doesn't address the question about how to handle name: nil
, but it prevents us clobbering callbacks when we build meta_vars
. We can easily filter name: nil
callbacks out or handle them with some other convention.
Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, I think the other way to handle this is to do something like:
meta_vars =
callbacks
|> Enum.sort_by(& &1.name) # for determinism
|> Enum.with_index()
|> Map.new(fn {callback, i} -> {i, callback.init_state} end)
That side-steps the name issue altogether. But then if duplicate names are present, it's not possible to provide an error message like "Callback X did the wrong thing.".
Adds:
%Training.State{status: :cont | :halt}
This way the
fun
on%Training.Callback{}
can have the spec:fun :: State.t() -> State.t()
instead of wrapping the output in a tuple. It also means we can write code like this:
since now
run_callbacks/3
andrun_training/4
(new) can branch onstate.status
.I did some other code shuffling while I was in there, but none of it is crucial.