diff --git a/lib/lightgbm.rb b/lib/lightgbm.rb index da9617c..70efa63 100644 --- a/lib/lightgbm.rb +++ b/lib/lightgbm.rb @@ -3,6 +3,7 @@ # modules require_relative "lightgbm/utils" +require_relative "lightgbm/macros" require_relative "lightgbm/booster" require_relative "lightgbm/dataset" require_relative "lightgbm/version" diff --git a/lib/lightgbm/booster.rb b/lib/lightgbm/booster.rb index 29a0ef2..8d42bb8 100644 --- a/lib/lightgbm/booster.rb +++ b/lib/lightgbm/booster.rb @@ -141,8 +141,21 @@ def num_trees out.read_int end - # TODO support different prediction types - def predict(input, start_iteration: nil, num_iteration: nil, **params) + + # Make prediction for a new dataset. + # C-API: LGBM_BoosterPredictForMat. + # + # @param input [Array, Array, Hash{String => Numeric, String}, Daru::DataFrame, Rover::DataFrame] Input data + # @param start_iteration [Integer] Start index of the iteration to predict + # @param num_iteration [Integer] Number of iteration for prediction, <= 0 means no limit + # @param predict_type [Integer] What should be predicted + # - C_API_PREDICT_NORMAL: normal prediction, with transform (if needed); + # - C_API_PREDICT_RAW_SCORE: raw score; + # - C_API_PREDICT_LEAF_INDEX: leaf index; + # - C_API_PREDICT_CONTRIB: feature contributions (SHAP values) + # @param **params [Hash] Other parameters for prediction, e.g. early stopping for prediction + # @return [Float, Array] Prediction results + def predict(input, start_iteration: nil, num_iteration: nil, predict_type: C_API_PREDICT_NORMAL, **params) input = if daru?(input) input[*cached_feature_name].map_rows(&:to_a) @@ -170,14 +183,51 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params) data.write_array_of_double(flat_input) out_len = ::FFI::MemoryPointer.new(:int64) - out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count) - check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, 0, start_iteration, num_iteration, params_str(params), out_len, out_result) + case predict_type + when C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE + out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count) + when C_API_PREDICT_LEAF_INDEX + num_predict = num_preds(start_iteration:, num_iteration:, nrow: input.count, predict_type:) + out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * num_predict) + singular = false + when C_API_PREDICT_CONTRIB + out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * (num_feature + 1)) + singular = false + end + + check_result FFI.LGBM_BoosterPredictForMat( + handle_pointer, + data, + 1, + input.count, + input.first.count, + 1, + predict_type, + start_iteration, + num_iteration, + params_str(params), + out_len, + out_result + ) out = out_result.read_array_of_double(out_len.read_int64) out = out.each_slice(num_class).to_a if num_class > 1 singular ? out.first : out end + def num_preds(start_iteration: 0, num_iteration: best_iteration, nrow: nil, predict_type: C_API_PREDICT_NORMAL) + out_len = ::FFI::MemoryPointer.new(:int64) + check_result FFI.LGBM_BoosterCalcNumPredict( + handle_pointer, + nrow, + predict_type, + start_iteration, + num_iteration, + out_len + ) + out_len.read_int64 + end + def save_model(filename, num_iteration: nil, start_iteration: 0) num_iteration ||= best_iteration feature_importance_type = 0 # TODO add diff --git a/lib/lightgbm/ffi.rb b/lib/lightgbm/ffi.rb index b2a1bec..5f4dfe8 100644 --- a/lib/lightgbm/ffi.rb +++ b/lib/lightgbm/ffi.rb @@ -38,6 +38,7 @@ module FFI attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int attach_function :LGBM_BoosterFree, %i[pointer], :int attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int + attach_function :LGBM_BoosterCalcNumPredict, %i[pointer int int int int pointer], :int attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int attach_function :LGBM_BoosterUpdateOneIter, %i[pointer pointer], :int attach_function :LGBM_BoosterGetCurrentIteration, %i[pointer pointer], :int diff --git a/lib/lightgbm/macros.rb b/lib/lightgbm/macros.rb new file mode 100644 index 0000000..1b498cc --- /dev/null +++ b/lib/lightgbm/macros.rb @@ -0,0 +1,7 @@ +module LightGBM + # Macro definition of prediction type in C API of LightGBM + C_API_PREDICT_NORMAL = 0 + C_API_PREDICT_RAW_SCORE = 1 + C_API_PREDICT_LEAF_INDEX = 2 + C_API_PREDICT_CONTRIB = 3 +end diff --git a/test/booster_test.rb b/test/booster_test.rb index d1d93b4..ba4a41e 100644 --- a/test/booster_test.rb +++ b/test/booster_test.rb @@ -1,7 +1,7 @@ require_relative "test_helper" class BoosterTest < Minitest::Test - def test_model_file + def test_predict x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]] booster = LightGBM::Booster.new(model_file: "test/support/model.txt") y_pred = booster.predict(x_test) @@ -23,21 +23,6 @@ def test_model_from_string assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2) end - def test_feature_importance - assert_equal [280, 285, 335, 148], booster.feature_importance - end - - def test_feature_name - assert_equal ["x0", "x1", "x2", "x3"], booster.feature_name - end - - def test_feature_importance_bad_importance_type - error = assert_raises(LightGBM::Error) do - booster.feature_importance(importance_type: "bad") - end - assert_includes error.message, "Unknown importance type" - end - def test_predict_hash pred = booster.predict({x0: 3.7, x1: 1.2, x2: 7.2, x3: 9.0}) assert_in_delta 0.9823112229173586, pred @@ -88,6 +73,68 @@ def test_predict_rover end end + def test_predict_type_leaf_index + x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]] + leaf_indexes = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_LEAF_INDEX) + assert_equal 200, leaf_indexes.count + assert_equal 9.0, leaf_indexes.first + assert_equal 7.0, leaf_indexes.last + + x_test = [3.7, 1.2, 7.2, 9.0] + leaf_indexes = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_LEAF_INDEX) + assert_equal 100, leaf_indexes.count + assert_equal 9.0, leaf_indexes.first + assert_equal 10.0, leaf_indexes.last + end + + def test_predict_type_contrib + x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]] + results = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_CONTRIB) + assert_equal 10, results.count + + # split results on num_features + 1 + predictions = results.each_slice(5).to_a + shap_values_1 = predictions.first[0..-2] + ypred_1 = predictions.first[-1] + assert_elements_in_delta [ + -0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174 + ], shap_values_1 + assert_in_delta (0.9933333333834246), ypred_1 + + shap_values_2 = predictions.last[0..-2] + ypred_2 = predictions.last[-1] + assert_elements_in_delta [ + 0.1094902954684793, -0.2810485083947154, 0.26691627597706397, -0.13037702397316747 + ], shap_values_2 + assert_in_delta (0.9933333333834246), ypred_2 + + # single row + x_test = [3.7, 1.2, 7.2, 9.0] + results = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_CONTRIB) + assert_equal 5, results.count + shap_values = results[0..-2] + ypred = results[-1] + assert_elements_in_delta [ + -0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174 + ], shap_values + assert_in_delta (0.9933333333834246), ypred + end + + def test_feature_importance + assert_equal [280, 285, 335, 148], booster.feature_importance + end + + def test_feature_name + assert_equal ["x0", "x1", "x2", "x3"], booster.feature_name + end + + def test_feature_importance_bad_importance_type + error = assert_raises(LightGBM::Error) do + booster.feature_importance(importance_type: "bad") + end + assert_includes error.message, "Unknown importance type" + end + def test_model_to_string assert booster.model_to_string end