-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Brett Shollenberger <brett.shollenberger@gmail.com>
- Loading branch information
1 parent
066362c
commit 73c4b4c
Showing
6 changed files
with
289 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
## 0.8.1 (unreleased)

- Added support for callbacks

## 0.8.0 (2023-09-13)

- Updated XGBoost to 2.0.0
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
module XGBoost
  # Dispatches training lifecycle events to a list of callbacks and keeps the
  # per-iteration evaluation history that is passed to the iteration hooks.
  #
  # The history is a nested Hash of the form
  # `{dataset_name => {metric_name => [score, ...]}}`, with one score appended
  # per call to #after_iteration.
  class CallbackContainer
    # @param callbacks [Array<TrainingCallback>] callbacks to run, in order
    # @raise [TypeError] if any element is not a TrainingCallback
    def initialize(callbacks)
      callbacks.each do |callback|
        unless callback.is_a?(TrainingCallback)
          raise TypeError, "callback must be an instance of XGBoost::TrainingCallback"
        end
      end

      @callbacks = callbacks
      @history = {}
    end

    # Runs every callback's `before_training` hook, feeding the model returned
    # by one callback into the next.
    #
    # @param model [Booster]
    # @return [Booster] the model returned by the last callback
    # @raise [TypeError] if a hook returns something that is not a Booster
    def before_training(model)
      run_training_hook(:before_training, model)
    end

    # Runs every callback's `after_training` hook, feeding the model returned
    # by one callback into the next.
    #
    # @param model [Booster]
    # @return [Booster] the model returned by the last callback
    # @raise [TypeError] if a hook returns something that is not a Booster
    def after_training(model)
      run_training_hook(:after_training, model)
    end

    # Asks the callbacks whether training should stop before this iteration.
    #
    # `any?` stops at the first truthy answer, so callbacks listed after the
    # one that requested the stop are not invoked for this iteration.
    # `dtrain` and `evals` are accepted but not used here.
    #
    # @return [Boolean] true when some callback asked to stop
    def before_iteration(model, epoch, dtrain, evals)
      @callbacks.any? { |callback| callback.before_iteration(model, epoch, @history) }
    end

    # Evaluates the model on `evals`, appends the scores to the history and
    # asks the callbacks whether training should stop.
    #
    # `any?` stops at the first truthy answer, so callbacks listed after the
    # one that requested the stop are not invoked for this iteration.
    # `dtrain` is accepted but not used here.
    #
    # @param evals [Array<Array>, nil] `[dataset, name]` pairs; nil means none
    # @return [Boolean] true when some callback asked to stop
    # @raise [ArgumentError] if a dataset name contains `-`
    def after_iteration(model, epoch, dtrain, evals)
      evals ||= []
      # `-` separates dataset name from metric name in the evaluation string,
      # so it cannot be allowed inside the dataset name itself.
      evals.each do |_, name|
        raise ArgumentError, "Dataset name should not contain `-`" if name.include?("-")
      end

      score = model.eval_set(evals, epoch)
      update_history(parse_eval_str(score), epoch)

      @callbacks.any? { |callback| callback.after_iteration(model, epoch, @history) }
    end

    private

    # Threads the model through the given hook of every callback and checks
    # that each hook hands a Booster back.
    def run_training_hook(hook, model)
      @callbacks.reduce(model) do |current, callback|
        updated = callback.public_send(hook, current)
        raise TypeError, "#{hook} should return the model" unless updated.is_a?(Booster)

        updated
      end
    end

    # Appends each `["dataset-metric", score]` pair to the history.
    # `epoch` is accepted but not used.
    def update_history(score, epoch)
      score.each do |name, value|
        # Split on the first `-` only: the dataset name cannot contain one
        # (checked in #after_iteration) but the metric name may.
        data_name, _, metric_name = name.partition("-")
        ((@history[data_name] ||= {})[metric_name] ||= []) << value
      end
    end

    # TODO: move
    # Turns an evaluation string such as `[0]\ttest-error:0.1234` into
    # `[["test-error", 0.1234]]`.
    def parse_eval_str(result)
      # The first token is the iteration marker (e.g. `[0]`); `drop(1)` also
      # copes with an empty string, where `split[1..]` would be nil.
      result.split.drop(1).map do |pair|
        name, value = pair.split(":")
        [name, parse_score(value)]
      end
    end

    # Converts a score token to a Float. `String#to_f` returns 0.0 for `nan`
    # and `inf`, which would record a diverged metric as a perfect score, so
    # those spellings are mapped explicitly.
    def parse_score(text)
      case text
      when /\A[+-]?nan\z/i then Float::NAN
      when /\A\+?inf(?:inity)?\z/i then Float::INFINITY
      when /\A-inf(?:inity)?\z/i then -Float::INFINITY
      else text.to_f
      end
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
module XGBoost
  # Base class for training callbacks. Every hook has a do-nothing default, so
  # subclasses override only the ones they need.
  class TrainingCallback
    # Run before training starts.
    #
    # @return the model; CallbackContainer raises TypeError unless the value
    #   returned is a Booster
    def before_training(model)
      model
    end

    # Run after training is finished.
    #
    # @return the model; CallbackContainer raises TypeError unless the value
    #   returned is a Booster
    def after_training(model)
      model
    end

    # Run before each iteration.
    #
    # @return [Boolean] true when training should stop
    def before_iteration(model, epoch, evals_log)
      false
    end

    # Run after each iteration.
    #
    # @return [Boolean] true when training should stop
    def after_iteration(model, epoch, evals_log)
      false
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
require_relative "test_helper" | ||
|
||
# Test double that counts how often each hook is invoked and keeps the
# arguments the assertions need. It never asks training to stop.
class MockCallback < XGBoost::TrainingCallback
  attr_reader :before_training_count, :after_training_count, :before_iteration_count,
    :after_iteration_count, :before_iteration_args, :history

  def initialize
    super()
    @before_training_count = 0
    @after_training_count = 0
    @before_iteration_count = 0
    @after_iteration_count = 0
    @before_iteration_args = []
    @history = {}
  end

  def before_training(model)
    @before_training_count += 1
    model
  end

  def after_training(model)
    @after_training_count += 1
    model
  end

  # Records the epoch of every call so tests can check the sequence.
  def before_iteration(model, epoch, evals_log)
    @before_iteration_count += 1
    @before_iteration_args << {epoch: epoch}
    false
  end

  # Keeps the most recent evaluation log it was handed.
  def after_iteration(model, epoch, evals_log)
    @after_iteration_count += 1
    @history = evals_log
    false
  end
end
|
||
# Exercises the `callbacks:` option of XGBoost.train using MockCallback.
class CallbacksTest < Minitest::Test
  def test_callback_raises_when_not_training_callback
    error = assert_raises(TypeError) do
      XGBoost.train(regression_params, regression_train, callbacks: [Object.new])
    end
    assert_equal "callback must be an instance of XGBoost::TrainingCallback", error.message
  end

  def test_callback
    callback = MockCallback.new
    num_boost_round = 10

    train_with_evals(callback, num_boost_round: num_boost_round)

    assert_hook_counts callback, before_iteration: num_boost_round, after_iteration: num_boost_round
    assert_rmse_history callback.history, num_boost_round
    assert_equal((0...num_boost_round).to_a, recorded_epochs(callback))
  end

  def test_callback_breaks_on_before_iteration
    callback = MockCallback.new
    # Epoch 0 runs normally; the stop is requested before epoch 1.
    def callback.before_iteration(model, epoch, evals_log)
      @before_iteration_count += 1
      @before_iteration_args << {epoch: epoch}
      epoch.odd?
    end

    train_with_evals(callback)

    assert_hook_counts callback, before_iteration: 2, after_iteration: 1
    assert_rmse_history callback.history, 1
    assert_equal((0...2).to_a, recorded_epochs(callback))
  end

  def test_callback_breaks_on_after_iteration
    callback = MockCallback.new
    # Epochs 0 through 7 run; the stop is requested after epoch 7.
    def callback.after_iteration(model, epoch, evals_log)
      @after_iteration_count += 1
      @history = evals_log
      epoch >= 7
    end

    train_with_evals(callback)

    assert_hook_counts callback, before_iteration: 8, after_iteration: 8
    assert_rmse_history callback.history, 8
    assert_equal((0...8).to_a, recorded_epochs(callback))
  end

  def test_updates_model_before_training
    callback = MockCallback.new
    def callback.before_training(model)
      model["device"] = "cuda:0"
      model
    end

    model = XGBoost.train(regression_params, regression_train, callbacks: [callback])

    assert_equal "cuda:0", model["device"]
  end

  private

  # Trains on the regression fixture with "train" and "eval" evaluation sets.
  # Extra keyword options are forwarded to XGBoost.train.
  def train_with_evals(callback, **options)
    XGBoost.train(
      regression_params,
      regression_train,
      callbacks: [callback],
      evals: [[regression_train, "train"], [regression_test, "eval"]],
      verbose_eval: false,
      **options
    )
  end

  # The training hooks must each fire once; the iteration hooks fire the
  # given number of times.
  def assert_hook_counts(callback, before_iteration:, after_iteration:)
    assert_equal 1, callback.before_training_count
    assert_equal 1, callback.after_training_count
    assert_equal before_iteration, callback.before_iteration_count
    assert_equal after_iteration, callback.after_iteration_count
  end

  # Both evaluation sets must have one rmse entry per completed iteration,
  # each within 1.0 of zero.
  def assert_rmse_history(history, expected_size)
    %w[train eval].each do |dataset|
      rmse = history[dataset]["rmse"]
      assert_equal expected_size, rmse.size
      rmse.each { |value| assert_in_delta 0.00, value, 1.0 }
    end
  end

  def recorded_epochs(callback)
    callback.before_iteration_args.map { |args| args[:epoch] }
  end
end