Skip to content

Commit

Permalink
Add support for different prediction types
Browse files Browse the repository at this point in the history
  • Loading branch information
nunosilva800 committed Dec 12, 2024
1 parent fca59ef commit 1ef6902
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 20 deletions.
1 change: 1 addition & 0 deletions lib/lightgbm.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# modules
require_relative "lightgbm/utils"
require_relative "lightgbm/macros"
require_relative "lightgbm/booster"
require_relative "lightgbm/dataset"
require_relative "lightgbm/version"
Expand Down
58 changes: 54 additions & 4 deletions lib/lightgbm/booster.rb
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,21 @@ def num_trees
out.read_int
end

# TODO support different prediction types
def predict(input, start_iteration: nil, num_iteration: nil, **params)

# Make prediction for a new dataset.
# C-API: LGBM_BoosterPredictForMat.
#
# @param input [Array, Array<Array,Hash>, Hash{String => Numeric, String}, Daru::DataFrame, Rover::DataFrame] Input data
# @param start_iteration [Integer] Start index of the iteration to predict
# @param num_iteration [Integer] Number of iterations to use for prediction, <= 0 means no limit
# @param predict_type [Integer] What should be predicted
# - C_API_PREDICT_NORMAL: normal prediction, with transform (if needed);
# - C_API_PREDICT_RAW_SCORE: raw score;
# - C_API_PREDICT_LEAF_INDEX: leaf index;
# - C_API_PREDICT_CONTRIB: feature contributions (SHAP values)
# @param **params [Hash] Other parameters for prediction, e.g. early stopping for prediction
# @return [Float, Array<Float>] Prediction results
def predict(input, start_iteration: nil, num_iteration: nil, predict_type: C_API_PREDICT_NORMAL, **params)
input =
if daru?(input)
input[*cached_feature_name].map_rows(&:to_a)
Expand Down Expand Up @@ -170,14 +183,51 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
data.write_array_of_double(flat_input)

out_len = ::FFI::MemoryPointer.new(:int64)
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, 0, start_iteration, num_iteration, params_str(params), out_len, out_result)
case predict_type
when C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
when C_API_PREDICT_LEAF_INDEX
num_predict = num_preds(start_iteration:, num_iteration:, nrow: input.count, predict_type:)
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * num_predict)
singular = false
when C_API_PREDICT_CONTRIB
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * (num_feature + 1))
singular = false
end

check_result FFI.LGBM_BoosterPredictForMat(
handle_pointer,
data,
1,
input.count,
input.first.count,
1,
predict_type,
start_iteration,
num_iteration,
params_str(params),
out_len,
out_result
)
out = out_result.read_array_of_double(out_len.read_int64)
out = out.each_slice(num_class).to_a if num_class > 1

singular ? out.first : out
end

# Asks LightGBM how many values a prediction call would produce.
# C-API: LGBM_BoosterCalcNumPredict.
#
# @param start_iteration [Integer, nil] Start index of the iteration to predict; nil falls back to 0
# @param num_iteration [Integer, nil] Number of iterations for prediction; nil falls back to best_iteration
# @param nrow [Integer] Number of input rows (required)
# @param predict_type [Integer] One of the C_API_PREDICT_* codes
# @return [Integer] The count written by the C API into its int64 output
# @raise [ArgumentError] If nrow is not given
def num_preds(start_iteration: 0, num_iteration: best_iteration, nrow: nil, predict_type: C_API_PREDICT_NORMAL)
  # nrow is forwarded to an FFI int argument, so fail early with a readable message
  # instead of letting nil reach the native call.
  raise ArgumentError, "nrow is required" if nrow.nil?

  # An explicit nil means "use the default", the same way save_model treats num_iteration.
  start_iteration ||= 0
  num_iteration ||= best_iteration

  out_len = ::FFI::MemoryPointer.new(:int64)
  check_result FFI.LGBM_BoosterCalcNumPredict(
    handle_pointer,
    nrow,
    predict_type,
    start_iteration,
    num_iteration,
    out_len
  )
  out_len.read_int64
end

def save_model(filename, num_iteration: nil, start_iteration: 0)
num_iteration ||= best_iteration
feature_importance_type = 0 # TODO add
Expand Down
1 change: 1 addition & 0 deletions lib/lightgbm/ffi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ module FFI
attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int
attach_function :LGBM_BoosterFree, %i[pointer], :int
attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int
# Used by num_preds in booster.rb; arguments in order: booster handle, row count,
# prediction type, start iteration, number of iterations, int64 output pointer.
attach_function :LGBM_BoosterCalcNumPredict, %i[pointer int int int int pointer], :int
attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int
attach_function :LGBM_BoosterUpdateOneIter, %i[pointer pointer], :int
attach_function :LGBM_BoosterGetCurrentIteration, %i[pointer pointer], :int
Expand Down
7 changes: 7 additions & 0 deletions lib/lightgbm/macros.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module LightGBM
  # Prediction-type codes mirroring the macros of LightGBM's C API.
  # Values are assigned 0..3 in the order the names are listed:
  #   NORMAL (0), RAW_SCORE (1), LEAF_INDEX (2), CONTRIB (3)
  C_API_PREDICT_NORMAL,
    C_API_PREDICT_RAW_SCORE,
    C_API_PREDICT_LEAF_INDEX,
    C_API_PREDICT_CONTRIB = 0, 1, 2, 3
end
79 changes: 63 additions & 16 deletions test/booster_test.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
require_relative "test_helper"

class BoosterTest < Minitest::Test
def test_model_file
def test_predict
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
booster = LightGBM::Booster.new(model_file: "test/support/model.txt")
y_pred = booster.predict(x_test)
Expand All @@ -23,21 +23,6 @@ def test_model_from_string
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
end

# Checks the per-feature importance values reported by the shared booster.
def test_feature_importance
  expected_importance = [280, 285, 335, 148]
  actual_importance = booster.feature_importance
  assert_equal expected_importance, actual_importance
end

# Checks the feature names reported by the shared booster.
def test_feature_name
  expected_names = %w[x0 x1 x2 x3]
  assert_equal expected_names, booster.feature_name
end

# An unrecognised importance type must raise LightGBM::Error with a descriptive message.
def test_feature_importance_bad_importance_type
  raised = assert_raises(LightGBM::Error) { booster.feature_importance(importance_type: "bad") }
  assert_includes raised.message, "Unknown importance type"
end

def test_predict_hash
pred = booster.predict({x0: 3.7, x1: 1.2, x2: 7.2, x3: 9.0})
assert_in_delta 0.9823112229173586, pred
Expand Down Expand Up @@ -88,6 +73,68 @@ def test_predict_rover
end
end

# Leaf-index predictions: checks the result length and its first/last entries,
# once for a two-row input and once for a single row.
def test_predict_type_leaf_index
  leaf_index_type = LightGBM::C_API_PREDICT_LEAF_INDEX

  # two rows
  two_rows = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
  batch_leaves = booster.predict(two_rows, predict_type: leaf_index_type)
  assert_equal 200, batch_leaves.count
  assert_equal 9.0, batch_leaves.first
  assert_equal 7.0, batch_leaves.last

  # single row
  one_row = [3.7, 1.2, 7.2, 9.0]
  row_leaves = booster.predict(one_row, predict_type: leaf_index_type)
  assert_equal 100, row_leaves.count
  assert_equal 9.0, row_leaves.first
  assert_equal 10.0, row_leaves.last
end

# Feature-contribution predictions: every row yields one value per feature
# followed by one trailing value. Verified for a two-row input and a single row.
def test_predict_type_contrib
  contrib_type = LightGBM::C_API_PREDICT_CONTRIB
  # num_features + 1 values per row (4 features in this model)
  values_per_row = 5
  # Last slot of each row; the same value is expected for every row checked here.
  trailing_value = 0.9933333333834246
  first_row_shap = [
    -0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174
  ]
  second_row_shap = [
    0.1094902954684793, -0.2810485083947154, 0.26691627597706397, -0.13037702397316747
  ]

  x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
  results = booster.predict(x_test, predict_type: contrib_type)
  assert_equal 2 * values_per_row, results.count

  # split the flat result into one slice per input row
  first_row, second_row = results.each_slice(values_per_row).to_a
  assert_elements_in_delta first_row_shap, first_row[..-2]
  assert_in_delta trailing_value, first_row.last
  assert_elements_in_delta second_row_shap, second_row[..-2]
  assert_in_delta trailing_value, second_row.last

  # single row: same values as the first row of the batch above
  x_test = [3.7, 1.2, 7.2, 9.0]
  results = booster.predict(x_test, predict_type: contrib_type)
  assert_equal values_per_row, results.count
  assert_elements_in_delta first_row_shap, results[..-2]
  assert_in_delta trailing_value, results.last
end

# Checks the per-feature importance values reported by the shared booster.
def test_feature_importance
  expected_importance = [280, 285, 335, 148]
  actual_importance = booster.feature_importance
  assert_equal expected_importance, actual_importance
end

# Checks the feature names reported by the shared booster.
def test_feature_name
  expected_names = %w[x0 x1 x2 x3]
  assert_equal expected_names, booster.feature_name
end

# An unrecognised importance type must raise LightGBM::Error with a descriptive message.
def test_feature_importance_bad_importance_type
  raised = assert_raises(LightGBM::Error) { booster.feature_importance(importance_type: "bad") }
  assert_includes raised.message, "Unknown importance type"
end

# Serialising the model to a string must yield a truthy result.
def test_model_to_string
  serialized = booster.model_to_string
  assert serialized
end
Expand Down

0 comments on commit 1ef6902

Please sign in to comment.