Synced predict code
ankane committed Dec 16, 2024
1 parent 275fac6 commit f1d3ade
Showing 1 changed file with 59 additions and 40 deletions.
99 changes: 59 additions & 40 deletions lib/lightgbm/booster.rb
@@ -141,10 +141,16 @@ def num_trees
       out.read_int
     end
 
-    def predict(input, start_iteration: 0, num_iteration: nil, raw_score: false, pred_leaf: false, pred_contrib: false, **params)
-      num_iteration ||= best_iteration
+    def predict(data, start_iteration: 0, num_iteration: nil, raw_score: false, pred_leaf: false, pred_contrib: false, **kwargs)
+      if num_iteration.nil?
+        if start_iteration <= 0
+          num_iteration = best_iteration
+        else
+          num_iteration = -1
+        end
+      end
 
-      if input.is_a?(Dataset)
+      if data.is_a?(Dataset)
         raise TypeError, "Cannot use Dataset instance for prediction, please use raw data instead"
       end
 
@@ -159,47 +165,15 @@ def predict(input, start_iteration: 0, num_iteration: nil, raw_score: false, pred_leaf: false, pred_contrib: false, **params)
         predict_type = FFI::C_API_PREDICT_CONTRIB
       end
 
-      input =
-        if daru?(input)
-          input[*cached_feature_name].map_rows(&:to_a)
-        elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
-          sorted_feature_values(input)
-        elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
-          input.map(&method(:sorted_feature_values))
-        elsif rover?(input)
-          # TODO improve performance
-          input[cached_feature_name].to_numo.to_a
-        else
-          input.to_a
-        end
-
-      singular = !input.first.is_a?(Array)
-      input = [input] if singular
-
-      nrow = input.count
-      n_preds =
-        num_preds(
+      preds, nrow, singular =
+        preds_for_data(
+          data,
           start_iteration,
           num_iteration,
-          nrow,
-          predict_type
+          predict_type,
+          **kwargs
         )
 
-      flat_input = input.flatten
-      handle_missing(flat_input)
-      data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
-      data.write_array_of_double(flat_input)
-
-      out_len = ::FFI::MemoryPointer.new(:int64)
-      out_result = ::FFI::MemoryPointer.new(:double, n_preds)
-      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, predict_type, start_iteration, num_iteration, params_str(params), out_len, out_result)
-
-      if n_preds != out_len.read_int64
-        raise Error, "Wrong length for predict results"
-      end
-
-      preds = out_result.read_array_of_double(out_len.read_int64)
-
       if pred_leaf
         preds = preds.map(&:to_i)
       end
@@ -287,6 +261,51 @@ def num_class
       out.read_int
     end
 
+    def preds_for_data(input, start_iteration, num_iteration, predict_type, **params)
+      input =
+        if daru?(input)
+          input[*cached_feature_name].map_rows(&:to_a)
+        elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
+          sorted_feature_values(input)
+        elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are
+          input.map(&method(:sorted_feature_values))
+        elsif rover?(input)
+          # TODO improve performance
+          input[cached_feature_name].to_numo.to_a
+        else
+          input.to_a
+        end
+
+      singular = !input.first.is_a?(Array)
+      input = [input] if singular
+
+      nrow = input.count
+      n_preds =
+        num_preds(
+          start_iteration,
+          num_iteration,
+          nrow,
+          predict_type
+        )
+
+      flat_input = input.flatten
+      handle_missing(flat_input)
+      data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
+      data.write_array_of_double(flat_input)
+
+      out_len = ::FFI::MemoryPointer.new(:int64)
+      out_result = ::FFI::MemoryPointer.new(:double, n_preds)
+      check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, predict_type, start_iteration, num_iteration, params_str(params), out_len, out_result)
+
+      if n_preds != out_len.read_int64
+        raise Error, "Wrong length for predict results"
+      end
+
+      preds = out_result.read_array_of_double(out_len.read_int64)
+
+      [preds, nrow, singular]
+    end
+
     def num_preds(start_iteration, num_iteration, nrow, predict_type)
       out = ::FFI::MemoryPointer.new(:int64)
       check_result FFI.LGBM_BoosterCalcNumPredict(handle_pointer, nrow, predict_type, start_iteration, num_iteration, out)
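
A minimal usage sketch of the synced predict behavior (the toy data and parameter values below are illustrative, not from this commit): when num_iteration is not given and start_iteration <= 0, predict resolves num_iteration to best_iteration; with a positive start_iteration and no num_iteration, it falls back to -1 (all remaining iterations).

require "lightgbm"

# toy regression data (illustrative values only)
x = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [0.1, 0.4, 0.6, 0.9]
train_set = LightGBM::Dataset.new(x, label: y)
booster = LightGBM.train({objective: "regression"}, train_set)

# start_iteration defaults to 0, so num_iteration resolves to best_iteration
booster.predict([[1, 2], [3, 4]])  # => one prediction per row

# a single flat row takes the singular path and returns a single value
booster.predict([1, 2])

# skip the first 10 trees and score with the next 10
booster.predict([[1, 2]], start_iteration: 10, num_iteration: 10)

# leaf indices (integers) instead of scores
booster.predict([[1, 2]], pred_leaf: true)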