From 211c2059d377d910a4d5b92975d29d1d7adb2293 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Tue, 12 Nov 2024 11:55:05 -0800 Subject: [PATCH] Prefer feature_name over feature_names to match Python --- lib/lightgbm/dataset.rb | 20 +++++++++++--------- test/dataset_test.rb | 18 +++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/lib/lightgbm/dataset.rb b/lib/lightgbm/dataset.rb index 72036ba..9426e19 100644 --- a/lib/lightgbm/dataset.rb +++ b/lib/lightgbm/dataset.rb @@ -2,7 +2,7 @@ module LightGBM class Dataset attr_reader :data, :params - def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto", feature_names: nil) + def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto", feature_name: nil, feature_names: nil) @data = data @label = label @weight = weight @@ -11,7 +11,7 @@ def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference @reference = reference @used_indices = used_indices @categorical_feature = categorical_feature - @feature_names = feature_names + @feature_name = feature_name || feature_names construct end @@ -24,7 +24,7 @@ def weight field("weight") end - def feature_names + def feature_name # must preallocate space num_feature_names = ::FFI::MemoryPointer.new(:int) out_buffer_len = ::FFI::MemoryPointer.new(:size_t) @@ -48,6 +48,7 @@ def feature_names # from most recent call (instead of num_features) str_ptrs[0, num_feature_names.read_int].map(&:read_string) end + alias_method :feature_names, :feature_name def label=(label) @label = label @@ -64,12 +65,13 @@ def group=(group) set_field("group", group, type: :int32) end - def feature_names=(feature_names) + def feature_name=(feature_names) @feature_names = feature_names c_feature_names = ::FFI::MemoryPointer.new(:pointer, feature_names.size) c_feature_names.write_array_of_pointer(feature_names.map { |v| ::FFI::MemoryPointer.from_string(v) }) check_result FFI.LGBM_DatasetSetFeatureNames(handle_pointer, c_feature_names, feature_names.size) end + alias_method :feature_names=, :feature_name= # TODO only update reference if not in chain def reference=(reference) @@ -142,16 +144,16 @@ def construct ncol = data.column_count flat_data = data.to_a.flatten elsif daru?(data) - if @feature_names == "auto" - @feature_names = data.vectors.to_a + if @feature_name == "auto" + @feature_name = data.vectors.to_a end nrow, ncol = data.shape flat_data = data.map_rows(&:to_a).flatten elsif numo?(data) nrow, ncol = data.shape elsif rover?(data) - if @feature_names == "auto" - @feature_names = data.keys + if @feature_name == "auto" + @feature_name = data.keys end data = data.to_numo nrow, ncol = data.shape @@ -176,7 +178,7 @@ def construct self.label = @label if @label self.weight = @weight if @weight self.group = @group if @group - self.feature_names = @feature_names if @feature_names + self.feature_name = @feature_name if @feature_name end def dump_text(filename) diff --git a/test/dataset_test.rb b/test/dataset_test.rb index f417b23..28d4be6 100644 --- a/test/dataset_test.rb +++ b/test/dataset_test.rb @@ -22,10 +22,10 @@ def test_weight assert weight, dataset.weight end - def test_feature_names + def test_feature_name data = [[1, 2], [3, 4]] - dataset = LightGBM::Dataset.new(data, feature_names: ["a", "b"]) - assert_equal ["a", "b"], dataset.feature_names + dataset = LightGBM::Dataset.new(data, feature_name: ["a", "b"]) + assert_equal ["a", "b"], dataset.feature_name end def test_num_data @@ -61,9 +61,9 @@ def test_daru label = data["y"] data = data.delete_vector("y") dataset = LightGBM::Dataset.new(data, label: label) - assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_names + assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_name - dataset = LightGBM::Dataset.new(data, label: label, feature_names: "auto") + dataset = LightGBM::Dataset.new(data, label: label, feature_name: "auto") assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_names end @@ -74,7 +74,7 @@ def test_numo data = Numo::DFloat.new(3, 5).seq label = Numo::DFloat.new(3).seq dataset = LightGBM::Dataset.new(data, label: label) - assert_equal ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"], dataset.feature_names + assert_equal ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"], dataset.feature_name end def test_rover @@ -84,10 +84,10 @@ def test_rover data = Rover.read_csv(data_path) label = data.delete("y") dataset = LightGBM::Dataset.new(data, label: label) - assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_names + assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_name - dataset = LightGBM::Dataset.new(data, label: label, feature_names: "auto") - assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_names + dataset = LightGBM::Dataset.new(data, label: label, feature_name: "auto") + assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_name end def test_copy