Skip to content

Commit

Permalink
Added save_config method to Booster - #11
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 16, 2024
1 parent 14e9731 commit c0542ef
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Updated XGBoost to 2.1.1
- Added support for callbacks
- Added `save_config` method to `Booster`
- Dropped support for Ruby < 3.1

## 0.8.0 (2023-09-13)
Expand Down
7 changes: 7 additions & 0 deletions lib/xgboost/booster.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def []=(key_name, raw_value)
set_attr(**{key_name => raw_value})
end

def save_config
length = ::FFI::MemoryPointer.new(:int)
json_string = ::FFI::MemoryPointer.new(:pointer)
check_result FFI.XGBoosterSaveJsonConfig(handle_pointer, length, json_string)
json_string.read_pointer.read_string(length.read_int).force_encoding(Encoding::UTF_8)
end

def attr(key_name)
key = string_pointer(key_name.to_s)
success = ::FFI::MemoryPointer.new(:int)
Expand Down
1 change: 1 addition & 0 deletions lib/xgboost/ffi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module FFI
attach_function :XGBoosterPredict, %i[pointer pointer int int int pointer pointer], :int
attach_function :XGBoosterLoadModel, %i[pointer string], :int
attach_function :XGBoosterSaveModel, %i[pointer string], :int
attach_function :XGBoosterSaveJsonConfig, %i[pointer pointer pointer], :int
attach_function :XGBoosterDumpModelExWithFeatures, %i[pointer int pointer pointer int string pointer pointer], :int
attach_function :XGBoosterGetAttr, %i[pointer pointer pointer pointer], :int
attach_function :XGBoosterSetAttr, %i[pointer pointer pointer], :int
Expand Down
6 changes: 6 additions & 0 deletions test/booster_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def test_dump_model_json
assert JSON.parse(File.read(tempfile))
end

def test_save_config
config = booster.save_config
assert_kind_of Hash, JSON.parse(config)
assert_equal Encoding::UTF_8, config.encoding
end

def test_score
expected = {"f0" => 118, "f2" => 93, "f1" => 104, "f3" => 43}
assert_equal expected.values.sort, booster.score.values.sort
Expand Down

0 comments on commit c0542ef

Please sign in to comment.