Skip to content

Commit

Permalink
Renamed handle_pointer to handle
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 16, 2024
1 parent f6053b1 commit 4c35b78
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 31 deletions.
38 changes: 19 additions & 19 deletions lib/xgboost/booster.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def initialize(params: nil, cache: nil, model_file: nil)
end
end

dmats = array_of_pointers(cache.map { |d| d.handle_pointer })
dmats = array_of_pointers(cache.map { |d| d.handle })
out = ::FFI::MemoryPointer.new(:pointer)
check_call FFI.XGBoosterCreate(dmats, cache.length, out)
@handle = ::FFI::AutoPointer.new(out.read_pointer, FFI.method(:XGBoosterFree))
Expand All @@ -20,7 +20,7 @@ def initialize(params: nil, cache: nil, model_file: nil)
end

if model_file
check_call FFI.XGBoosterLoadModel(handle_pointer, model_file)
check_call FFI.XGBoosterLoadModel(handle, model_file)
end

set_param(params)
Expand All @@ -38,7 +38,7 @@ def []=(key_name, raw_value)
def save_config
length = ::FFI::MemoryPointer.new(:uint64)
json_string = ::FFI::MemoryPointer.new(:pointer)
check_call FFI.XGBoosterSaveJsonConfig(handle_pointer, length, json_string)
check_call FFI.XGBoosterSaveJsonConfig(handle, length, json_string)
json_string.read_pointer.read_string(read_uint64(length)).force_encoding(Encoding::UTF_8)
end

Expand All @@ -47,15 +47,15 @@ def attr(key_name)
success = ::FFI::MemoryPointer.new(:int)
out_result = ::FFI::MemoryPointer.new(:pointer)

check_call FFI.XGBoosterGetAttr(handle_pointer, key, out_result, success)
check_call FFI.XGBoosterGetAttr(handle, key, out_result, success)

success.read_int == 1 ? out_result.read_pointer.read_string : nil
end

def attributes
out_len = ::FFI::MemoryPointer.new(:uint64)
out_result = ::FFI::MemoryPointer.new(:pointer)
check_call FFI.XGBoosterGetAttrNames(handle_pointer, out_len, out_result)
check_call FFI.XGBoosterGetAttrNames(handle, out_len, out_result)

len = read_uint64(out_len)
key_names = len.zero? ? [] : out_result.read_pointer.get_array_of_string(0, len)
Expand All @@ -68,7 +68,7 @@ def set_attr(**kwargs)
key = string_pointer(key_name)
value = raw_value.nil? ? nil : string_pointer(raw_value.to_s)

check_call FFI.XGBoosterSetAttr(handle_pointer, key, value)
check_call FFI.XGBoosterSetAttr(handle, key, value)
end
end

Expand All @@ -91,24 +91,24 @@ def feature_names=(features)
def set_param(params, value = nil)
if params.is_a?(Enumerable)
params.each do |k, v|
check_call FFI.XGBoosterSetParam(handle_pointer, k.to_s, v.to_s)
check_call FFI.XGBoosterSetParam(handle, k.to_s, v.to_s)
end
else
check_call FFI.XGBoosterSetParam(handle_pointer, params.to_s, value.to_s)
check_call FFI.XGBoosterSetParam(handle, params.to_s, value.to_s)
end
end

def update(dtrain, iteration)
check_call FFI.XGBoosterUpdateOneIter(handle_pointer, iteration, dtrain.handle_pointer)
check_call FFI.XGBoosterUpdateOneIter(handle, iteration, dtrain.handle)
end

def eval_set(evals, iteration)
dmats = array_of_pointers(evals.map { |v| v[0].handle_pointer })
dmats = array_of_pointers(evals.map { |v| v[0].handle })
evnames = array_of_pointers(evals.map { |v| string_pointer(v[1]) })

out_result = ::FFI::MemoryPointer.new(:pointer)

check_call FFI.XGBoosterEvalOneIter(handle_pointer, iteration, dmats, evnames, evals.size, out_result)
check_call FFI.XGBoosterEvalOneIter(handle, iteration, dmats, evnames, evals.size, out_result)

out_result.read_pointer.read_string
end
Expand All @@ -117,15 +117,15 @@ def predict(data, ntree_limit: nil)
ntree_limit ||= 0
out_len = ::FFI::MemoryPointer.new(:uint64)
out_result = ::FFI::MemoryPointer.new(:pointer)
check_call FFI.XGBoosterPredict(handle_pointer, data.handle_pointer, 0, ntree_limit, 0, out_len, out_result)
check_call FFI.XGBoosterPredict(handle, data.handle, 0, ntree_limit, 0, out_len, out_result)
out = out_result.read_pointer.read_array_of_float(read_uint64(out_len))
num_class = out.size / data.num_row
out = out.each_slice(num_class).to_a if num_class > 1
out
end

def save_model(fname)
check_call FFI.XGBoosterSaveModel(handle_pointer, fname)
check_call FFI.XGBoosterSaveModel(handle, fname)
end

def best_iteration
Expand All @@ -146,7 +146,7 @@ def best_score=(score)

def num_boosted_rounds
rounds = ::FFI::MemoryPointer.new(:int)
check_call FFI.XGBoosterBoostedRounds(handle_pointer, rounds)
check_call FFI.XGBoosterBoostedRounds(handle, rounds)
rounds.read_int
end

Expand Down Expand Up @@ -178,7 +178,7 @@ def dump(fmap: "", with_stats: false, dump_format: "text")
fnames = array_of_pointers(names.map { |fname| string_pointer(fname) })
ftypes = array_of_pointers(feature_types || Array.new(names.size, string_pointer("float")))

check_call FFI.XGBoosterDumpModelExWithFeatures(handle_pointer, names.size, fnames, ftypes, with_stats ? 1 : 0, dump_format, out_len, out_result)
check_call FFI.XGBoosterDumpModelExWithFeatures(handle, names.size, fnames, ftypes, with_stats ? 1 : 0, dump_format, out_len, out_result)

out_result.read_pointer.get_array_of_string(0, read_uint64(out_len))
end
Expand Down Expand Up @@ -249,7 +249,7 @@ def score(fmap: "", importance_type: "weight")

private

def handle_pointer
def handle
@handle
end

Expand Down Expand Up @@ -289,7 +289,7 @@ def get_feature_info(field)
end
check_call(
FFI.XGBoosterGetStrFeatureInfo(
handle_pointer,
handle,
field,
length,
sarr
Expand All @@ -315,7 +315,7 @@ def set_feature_info(features, field)
c_feature_info = array_of_pointers(features.map { |f| string_pointer(f) })
check_call(
FFI.XGBoosterSetStrFeatureInfo(
handle_pointer,
handle,
field,
c_feature_info,
features.length
Expand All @@ -324,7 +324,7 @@ def set_feature_info(features, field)
else
check_call(
FFI.XGBoosterSetStrFeatureInfo(
handle_pointer, field, nil, 0
handle, field, nil, 0
)
)
end
Expand Down
20 changes: 8 additions & 12 deletions lib/xgboost/dmatrix.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module XGBoost
class DMatrix
include Utils

attr_reader :data, :feature_names, :feature_types
attr_reader :data, :feature_names, :feature_types, :handle

def initialize(data, label: nil, weight: nil, missing: Float::NAN)
@data = data
Expand Down Expand Up @@ -86,37 +86,33 @@ def weight=(weight)
def group=(group)
c_data = ::FFI::MemoryPointer.new(:int, group.size)
c_data.write_array_of_int(group)
check_call FFI.XGDMatrixSetUIntInfo(handle_pointer, "group", c_data, group.size)
check_call FFI.XGDMatrixSetUIntInfo(handle, "group", c_data, group.size)
end

def num_row
out = ::FFI::MemoryPointer.new(:uint64)
check_call FFI.XGDMatrixNumRow(handle_pointer, out)
check_call FFI.XGDMatrixNumRow(handle, out)
read_uint64(out)
end

def num_col
out = ::FFI::MemoryPointer.new(:uint64)
check_call FFI.XGDMatrixNumCol(handle_pointer, out)
check_call FFI.XGDMatrixNumCol(handle, out)
read_uint64(out)
end

def slice(rindex)
idxset = ::FFI::MemoryPointer.new(:int, rindex.count)
idxset.write_array_of_int(rindex)
out = ::FFI::MemoryPointer.new(:pointer)
check_call FFI.XGDMatrixSliceDMatrix(handle_pointer, idxset, rindex.size, out)
check_call FFI.XGDMatrixSliceDMatrix(handle, idxset, rindex.size, out)

handle = ::FFI::AutoPointer.new(out.read_pointer, FFI.method(:XGDMatrixFree))
DMatrix.new(handle)
end

def save_binary(fname, silent: true)
check_call FFI.XGDMatrixSaveBinary(handle_pointer, fname, silent ? 1 : 0)
end

def handle_pointer
@handle
check_call FFI.XGDMatrixSaveBinary(handle, fname, silent ? 1 : 0)
end

private
Expand All @@ -125,14 +121,14 @@ def set_float_info(field, data)
data = data.to_a unless data.is_a?(Array)
c_data = ::FFI::MemoryPointer.new(:float, data.size)
c_data.write_array_of_float(data)
check_call FFI.XGDMatrixSetFloatInfo(handle_pointer, field.to_s, c_data, data.size)
check_call FFI.XGDMatrixSetFloatInfo(handle, field.to_s, c_data, data.size)
end

def float_info(field)
num_row ||= num_row()
out_len = ::FFI::MemoryPointer.new(:int)
out_dptr = ::FFI::MemoryPointer.new(:float, num_row)
check_call FFI.XGDMatrixGetFloatInfo(handle_pointer, field, out_len, out_dptr)
check_call FFI.XGDMatrixGetFloatInfo(handle, field, out_len, out_dptr)
out_dptr.read_pointer.read_array_of_float(num_row)
end

Expand Down

0 comments on commit 4c35b78

Please sign in to comment.