Commit
Added support for callbacks - #11
Co-authored-by: Brett Shollenberger <brett.shollenberger@gmail.com>
ankane and brettshollenberger committed Oct 16, 2024
1 parent 066362c commit 73c4b4c
Showing 6 changed files with 289 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
## 0.8.1 (unreleased)

- Added support for callbacks

## 0.8.0 (2023-09-13)

- Updated XGBoost to 2.0.0
12 changes: 11 additions & 1 deletion lib/xgboost.rb
@@ -4,7 +4,9 @@
# modules
require_relative "xgboost/utils"
require_relative "xgboost/booster"
require_relative "xgboost/callback_container"
require_relative "xgboost/dmatrix"
require_relative "xgboost/training_callback"
require_relative "xgboost/version"

# scikit-learn API
@@ -44,8 +46,12 @@ class << self
autoload :FFI, "xgboost/ffi"

class << self
- def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds: nil, verbose_eval: true)
+ def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds: nil, verbose_eval: true, callbacks: nil)
callbacks ||= []
booster = Booster.new(params: params)
cb_container = CallbackContainer.new(callbacks)
booster = cb_container.before_training(booster)

num_feature = dtrain.num_col
booster.set_param("num_feature", num_feature)
booster.feature_names = dtrain.feature_names
@@ -59,6 +65,7 @@ def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds
end

num_boost_round.times do |iteration|
break if cb_container.before_iteration(booster, iteration, dtrain, evals)
booster.update(dtrain, iteration)

if evals.any?
@@ -80,11 +87,14 @@ def train(params, dtrain, num_boost_round: 10, evals: nil, early_stopping_rounds
best_message = message
elsif early_stopping_rounds && iteration - best_iter >= early_stopping_rounds
booster.best_iteration = best_iter
booster.best_score = best_score
puts "Stopping. Best iteration:\n#{best_message}" if verbose_eval
break
end
break if cb_container.after_iteration(booster, iteration, dtrain, evals)
end
end
booster = cb_container.after_training(booster)

booster
end
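As a usage sketch (not part of this commit's diff): the new callbacks: option takes instances of XGBoost::TrainingCallback. The EpochLogger name below is illustrative, and params and dtrain stand for a parameter hash and an XGBoost::DMatrix.

# Illustrative sketch of the callbacks: option added above.
class EpochLogger < XGBoost::TrainingCallback
  def after_iteration(model, epoch, evals_log)
    # evals_log is the history hash maintained by CallbackContainer,
    # e.g. {"train" => {"rmse" => [...]}}
    puts "epoch #{epoch}: #{evals_log.inspect}"
    false # return true to stop training early
  end
end

booster = XGBoost.train(
  params,
  dtrain,
  num_boost_round: 10,
  evals: [[dtrain, "train"]],
  callbacks: [EpochLogger.new],
  verbose_eval: false
)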
2 changes: 1 addition & 1 deletion lib/xgboost/booster.rb
@@ -1,6 +1,6 @@
module XGBoost
class Booster
- attr_accessor :best_iteration, :feature_names, :feature_types
+ attr_accessor :best_iteration, :feature_names, :feature_types, :best_score

def initialize(params: nil, model_file: nil)
@handle = ::FFI::MemoryPointer.new(:pointer)
84 changes: 84 additions & 0 deletions lib/xgboost/callback_container.rb
@@ -0,0 +1,84 @@
module XGBoost
class CallbackContainer
def initialize(callbacks)
@callbacks = callbacks
callbacks.each do |callback|
unless callback.is_a?(TrainingCallback)
raise TypeError, "callback must be an instance of XGBoost::TrainingCallback"
end
end

@history = {}
end

def before_training(model)
@callbacks.each do |callback|
model = callback.before_training(model)
unless model.is_a?(Booster)
raise TypeError, "before_training should return the model"
end
end
model
end

def after_training(model)
@callbacks.each do |callback|
model = callback.after_training(model)
unless model.is_a?(Booster)
raise TypeError, "after_training should return the model"
end
end
model
end

def before_iteration(model, epoch, dtrain, evals)
@callbacks.any? do |callback|
callback.before_iteration(model, epoch, @history)
end
end

def after_iteration(model, epoch, dtrain, evals)
evals ||= []
evals.each do |_, name|
if name.include?("-")
raise ArgumentError, "Dataset name should not contain `-`"
end
end
score = model.eval_set(evals, epoch)
metric_score = parse_eval_str(score)
update_history(metric_score, epoch)

@callbacks.any? do |callback|
callback.after_iteration(model, epoch, @history)
end
end

private

def update_history(score, epoch)
score.each do |name, value|
# `name` is like `train-rmse`; split into dataset name and metric name
split_names = name.split("-")
data_name = split_names[0]
metric_name = split_names[1..].join("-")
@history[data_name] ||= {}
data_history = @history[data_name]
data_history[metric_name] ||= []
data_history[metric_name] << value.to_f
end
end

# TODO move
def parse_eval_str(result)
# drop the leading iteration label (e.g. `[0]`)
split_result = result.split[1..]
# split up `test-error:0.1234`
metric_score_str = split_result.map { |s| s.split(":") }
# convert to float
metric_score_str.map { |n, s| [n, s.to_f] }
end
end
end
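For reference, update_history keys the history hash by dataset name and then metric name. Assuming an eval string such as "[0] train-rmse:0.48 eval-rmse:0.55" (tab-separated in practice), the history after one iteration would look roughly like this (values illustrative):

evals_log = {
  "train" => { "rmse" => [0.48] },
  "eval" => { "rmse" => [0.55] }
}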
23 changes: 23 additions & 0 deletions lib/xgboost/training_callback.rb
@@ -0,0 +1,23 @@
module XGBoost
class TrainingCallback
def before_training(model)
# Run before training starts
model
end

def after_training(model)
# Run after training is finished
model
end

def before_iteration(model, epoch, evals_log)
# Run before each iteration. Returns true when training should stop.
false
end

def after_iteration(model, epoch, evals_log)
# Run after each iteration. Returns true when training should stop.
false
end
end
end
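A minimal subclass sketch, assuming only the hooks defined above; the StopAfter name and its limit argument are illustrative. Hooks that receive the model must return it (CallbackContainer raises otherwise), and a truthy return from the iteration hooks stops training.

# Illustrative subclass: stop training after a fixed number of rounds.
class StopAfter < XGBoost::TrainingCallback
  def initialize(limit)
    @limit = limit
  end

  def before_training(model)
    model # must return the Booster
  end

  def after_iteration(model, epoch, evals_log)
    epoch + 1 >= @limit # truthy return stops training
  end
end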
166 changes: 166 additions & 0 deletions test/callbacks_test.rb
@@ -0,0 +1,166 @@
require_relative "test_helper"

class MockCallback < XGBoost::TrainingCallback
attr_reader :before_training_count, :after_training_count, :before_iteration_count,
:after_iteration_count, :before_iteration_args, :history

def initialize
@before_training_count = 0
@after_training_count = 0
@before_iteration_count = 0
@after_iteration_count = 0
@before_iteration_args = []
@history = {}
end

def before_training(model)
@before_training_count += 1
model
end

def after_training(model)
@after_training_count += 1
model
end

def before_iteration(model, epoch, evals_log)
@before_iteration_count += 1
@before_iteration_args << {epoch: epoch}
false
end

def after_iteration(model, epoch, evals_log)
@after_iteration_count += 1
@history = evals_log
false
end
end

class CallbacksTest < Minitest::Test
def test_callback_raises_when_not_training_callback
error = assert_raises(TypeError) do
XGBoost.train(regression_params, regression_train, callbacks: [Object.new])
end
assert_equal "callback must be an instance of XGBoost::TrainingCallback", error.message
end

def test_callback
callback = MockCallback.new
num_boost_round = 10

XGBoost.train(
regression_params,
regression_train,
num_boost_round: num_boost_round,
callbacks: [callback],
evals: [[regression_train, "train"], [regression_test, "eval"]],
verbose_eval: false
)

assert_equal 1, callback.before_training_count
assert_equal 1, callback.after_training_count
assert_equal num_boost_round, callback.before_iteration_count
assert_equal num_boost_round, callback.after_iteration_count

# Verify arguments
train_rmse = callback.history["train"]["rmse"]
assert_equal num_boost_round, train_rmse.size
train_rmse.each do |value|
assert_in_delta 0.00, value, 1.0
end
eval_rmse = callback.history["eval"]["rmse"]
assert_equal num_boost_round, eval_rmse.size
eval_rmse.each do |value|
assert_in_delta 0.00, value, 1.0
end

epochs = callback.before_iteration_args.map { |e| e[:epoch] }
assert_equal (0...num_boost_round).to_a, epochs
end

def test_callback_breaks_on_before_iteration
callback = MockCallback.new
def callback.before_iteration(model, epoch, evals_log)
@before_iteration_count += 1
@before_iteration_args << {epoch: epoch}
epoch.odd?
end

XGBoost.train(
regression_params,
regression_train,
callbacks: [callback],
evals: [[regression_train, "train"], [regression_test, "eval"]],
verbose_eval: false
)

assert_equal 1, callback.before_training_count
assert_equal 1, callback.after_training_count
assert_equal 2, callback.before_iteration_count
assert_equal 1, callback.after_iteration_count

# Verify arguments
train_rmse = callback.history["train"]["rmse"]
assert_equal 1, train_rmse.size
train_rmse.each do |value|
assert_in_delta 0.00, value, 1.0
end
eval_rmse = callback.history["eval"]["rmse"]
assert_equal 1, eval_rmse.size
eval_rmse.each do |value|
assert_in_delta 0.00, value, 1.0
end

epochs = callback.before_iteration_args.map { |e| e[:epoch] }
assert_equal (0...2).to_a, epochs
end

def test_callback_breaks_on_after_iteration
callback = MockCallback.new
def callback.after_iteration(model, epoch, evals_log)
@after_iteration_count += 1
@history = evals_log
epoch >= 7
end

XGBoost.train(
regression_params,
regression_train,
callbacks: [callback],
evals: [[regression_train, "train"], [regression_test, "eval"]],
verbose_eval: false
)

assert_equal 1, callback.before_training_count
assert_equal 1, callback.after_training_count
assert_equal 8, callback.before_iteration_count
assert_equal 8, callback.after_iteration_count

# Verify arguments
train_rmse = callback.history["train"]["rmse"]
assert_equal 8, train_rmse.size
train_rmse.each do |value|
assert_in_delta 0.00, value, 1.0
end
eval_rmse = callback.history["eval"]["rmse"]
assert_equal 8, eval_rmse.size
eval_rmse.each do |value|
assert_in_delta 0.00, value, 1.0
end

epochs = callback.before_iteration_args.map { |e| e[:epoch] }
assert_equal (0...8).to_a, epochs
end

def test_updates_model_before_training
callback = MockCallback.new
def callback.before_training(model)
model["device"] = "cuda:0"
model
end

model = XGBoost.train(regression_params, regression_train, callbacks: [callback])

assert_equal "cuda:0", model["device"]
end
end
