Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different prediction types #10

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/lightgbm.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# modules
require_relative "lightgbm/utils"
require_relative "lightgbm/macros"
require_relative "lightgbm/booster"
require_relative "lightgbm/dataset"
require_relative "lightgbm/version"
Expand Down
58 changes: 54 additions & 4 deletions lib/lightgbm/booster.rb
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,21 @@ def num_trees
out.read_int
end

# TODO support different prediction types
def predict(input, start_iteration: nil, num_iteration: nil, **params)

# Make prediction for a new dataset.
# C-API: LGBM_BoosterPredictForMat.
#
# @param input [Array, Array<Array,Hash>, Hash{String => Numeric, String}, Daru::DataFrame, Rover::DataFrame] Input data
# @param start_iteration [Integer] Start index of the iteration to predict
# @param num_iteration [Integer] Number of iteration for prediction, <= 0 means no limit
# @param predict_type [Integer] What should be predicted
# - C_API_PREDICT_NORMAL: normal prediction, with transform (if needed);
# - C_API_PREDICT_RAW_SCORE: raw score;
# - C_API_PREDICT_LEAF_INDEX: leaf index;
# - C_API_PREDICT_CONTRIB: feature contributions (SHAP values)
Comment on lines +151 to +155
Copy link
Contributor Author

@nunosilva800 nunosilva800 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a single argument instead of 3 boolean args, like Python API does.

https://lightgbm.readthedocs.io/en/stable/pythonapi/lightgbm.Booster.html#lightgbm.Booster.predict

https://github.com/microsoft/LightGBM/blob/e057ae08e6bf6c6c84f276a127423fb145ca5fdb/python-package/lightgbm/basic.py#L1079-L1081

@ankane
Let me know which you prefer. I don't really see the point of 3 booleans, since they are exclusive...

# @param **params [Hash] Other parameters for prediction, e.g. early stopping for prediction
# @return [Float, Array<Float>] Prediction results
def predict(input, start_iteration: nil, num_iteration: nil, predict_type: C_API_PREDICT_NORMAL, **params)
input =
if daru?(input)
input[*cached_feature_name].map_rows(&:to_a)
Expand Down Expand Up @@ -170,14 +183,51 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
data.write_array_of_double(flat_input)

out_len = ::FFI::MemoryPointer.new(:int64)
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, 0, start_iteration, num_iteration, params_str(params), out_len, out_result)
case predict_type
when C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
when C_API_PREDICT_LEAF_INDEX
num_predict = num_preds(start_iteration:, num_iteration:, nrow: input.count, predict_type:)
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * num_predict)
singular = false
when C_API_PREDICT_CONTRIB
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * (num_feature + 1))
singular = false
end

check_result FFI.LGBM_BoosterPredictForMat(
handle_pointer,
data,
1,
input.count,
input.first.count,
1,
predict_type,
start_iteration,
num_iteration,
params_str(params),
out_len,
out_result
)
out = out_result.read_array_of_double(out_len.read_int64)
out = out.each_slice(num_class).to_a if num_class > 1

singular ? out.first : out
end

def num_preds(start_iteration: 0, num_iteration: best_iteration, nrow: nil, predict_type: C_API_PREDICT_NORMAL)
out_len = ::FFI::MemoryPointer.new(:int64)
check_result FFI.LGBM_BoosterCalcNumPredict(
handle_pointer,
nrow,
predict_type,
start_iteration,
num_iteration,
out_len
)
out_len.read_int64
end

def save_model(filename, num_iteration: nil, start_iteration: 0)
num_iteration ||= best_iteration
feature_importance_type = 0 # TODO add
Expand Down
1 change: 1 addition & 0 deletions lib/lightgbm/ffi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ module FFI
attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int
attach_function :LGBM_BoosterFree, %i[pointer], :int
attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int
attach_function :LGBM_BoosterCalcNumPredict, %i[pointer int int int int pointer], :int
attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int
attach_function :LGBM_BoosterUpdateOneIter, %i[pointer pointer], :int
attach_function :LGBM_BoosterGetCurrentIteration, %i[pointer pointer], :int
Expand Down
7 changes: 7 additions & 0 deletions lib/lightgbm/macros.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module LightGBM
# Macro definition of prediction type in C API of LightGBM
C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2
C_API_PREDICT_CONTRIB = 3
end
79 changes: 63 additions & 16 deletions test/booster_test.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
require_relative "test_helper"

class BoosterTest < Minitest::Test
def test_model_file
def test_predict
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
booster = LightGBM::Booster.new(model_file: "test/support/model.txt")
y_pred = booster.predict(x_test)
Expand All @@ -23,21 +23,6 @@ def test_model_from_string
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
end

def test_feature_importance
assert_equal [280, 285, 335, 148], booster.feature_importance
end

def test_feature_name
assert_equal ["x0", "x1", "x2", "x3"], booster.feature_name
end

def test_feature_importance_bad_importance_type
error = assert_raises(LightGBM::Error) do
booster.feature_importance(importance_type: "bad")
end
assert_includes error.message, "Unknown importance type"
end

def test_predict_hash
pred = booster.predict({x0: 3.7, x1: 1.2, x2: 7.2, x3: 9.0})
assert_in_delta 0.9823112229173586, pred
Expand Down Expand Up @@ -88,6 +73,68 @@ def test_predict_rover
end
end

def test_predict_type_leaf_index
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
leaf_indexes = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_LEAF_INDEX)
assert_equal 200, leaf_indexes.count
assert_equal 9.0, leaf_indexes.first
assert_equal 7.0, leaf_indexes.last

x_test = [3.7, 1.2, 7.2, 9.0]
leaf_indexes = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_LEAF_INDEX)
assert_equal 100, leaf_indexes.count
assert_equal 9.0, leaf_indexes.first
assert_equal 10.0, leaf_indexes.last
end

def test_predict_type_contrib
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
results = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_CONTRIB)
assert_equal 10, results.count

# split results on num_features + 1
predictions = results.each_slice(5).to_a
shap_values_1 = predictions.first[0..-2]
ypred_1 = predictions.first[-1]
assert_elements_in_delta [
-0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174
], shap_values_1
assert_in_delta (0.9933333333834246), ypred_1

shap_values_2 = predictions.last[0..-2]
ypred_2 = predictions.last[-1]
assert_elements_in_delta [
0.1094902954684793, -0.2810485083947154, 0.26691627597706397, -0.13037702397316747
], shap_values_2
assert_in_delta (0.9933333333834246), ypred_2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the last value (supposed to be the prediction result) is the same on each row, and why different from expected ypred values in other tests.


# single row
x_test = [3.7, 1.2, 7.2, 9.0]
results = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_CONTRIB)
assert_equal 5, results.count
shap_values = results[0..-2]
ypred = results[-1]
assert_elements_in_delta [
-0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174
], shap_values
assert_in_delta (0.9933333333834246), ypred
end

def test_feature_importance
assert_equal [280, 285, 335, 148], booster.feature_importance
end

def test_feature_name
assert_equal ["x0", "x1", "x2", "x3"], booster.feature_name
end

def test_feature_importance_bad_importance_type
error = assert_raises(LightGBM::Error) do
booster.feature_importance(importance_type: "bad")
end
assert_includes error.message, "Unknown importance type"
end

def test_model_to_string
assert booster.model_to_string
end
Expand Down
Loading