Skip to content

Commit

Permalink
Prefer feature_name over feature_names to match Python
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Nov 12, 2024
1 parent e32683a commit 211c205
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
20 changes: 11 additions & 9 deletions lib/lightgbm/dataset.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module LightGBM
class Dataset
attr_reader :data, :params

def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto", feature_names: nil)
def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto", feature_name: nil, feature_names: nil)
@data = data
@label = label
@weight = weight
Expand All @@ -11,7 +11,7 @@ def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference
@reference = reference
@used_indices = used_indices
@categorical_feature = categorical_feature
@feature_names = feature_names
@feature_name = feature_name || feature_names

construct
end
Expand All @@ -24,7 +24,7 @@ def weight
field("weight")
end

def feature_names
def feature_name
# must preallocate space
num_feature_names = ::FFI::MemoryPointer.new(:int)
out_buffer_len = ::FFI::MemoryPointer.new(:size_t)
Expand All @@ -48,6 +48,7 @@ def feature_names
# from most recent call (instead of num_features)
str_ptrs[0, num_feature_names.read_int].map(&:read_string)
end
alias_method :feature_names, :feature_name

def label=(label)
@label = label
Expand All @@ -64,12 +65,13 @@ def group=(group)
set_field("group", group, type: :int32)
end

def feature_names=(feature_names)
def feature_name=(feature_names)
@feature_names = feature_names
c_feature_names = ::FFI::MemoryPointer.new(:pointer, feature_names.size)
c_feature_names.write_array_of_pointer(feature_names.map { |v| ::FFI::MemoryPointer.from_string(v) })
check_result FFI.LGBM_DatasetSetFeatureNames(handle_pointer, c_feature_names, feature_names.size)
end
alias_method :feature_names=, :feature_name=

# TODO only update reference if not in chain
def reference=(reference)
Expand Down Expand Up @@ -142,16 +144,16 @@ def construct
ncol = data.column_count
flat_data = data.to_a.flatten
elsif daru?(data)
if @feature_names == "auto"
@feature_names = data.vectors.to_a
if @feature_name == "auto"
@feature_name = data.vectors.to_a
end
nrow, ncol = data.shape
flat_data = data.map_rows(&:to_a).flatten
elsif numo?(data)
nrow, ncol = data.shape
elsif rover?(data)
if @feature_names == "auto"
@feature_names = data.keys
if @feature_name == "auto"
@feature_name = data.keys
end
data = data.to_numo
nrow, ncol = data.shape
Expand All @@ -176,7 +178,7 @@ def construct
self.label = @label if @label
self.weight = @weight if @weight
self.group = @group if @group
self.feature_names = @feature_names if @feature_names
self.feature_name = @feature_name if @feature_name
end

def dump_text(filename)
Expand Down
18 changes: 9 additions & 9 deletions test/dataset_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def test_weight
assert weight, dataset.weight
end

def test_feature_names
def test_feature_name
data = [[1, 2], [3, 4]]
dataset = LightGBM::Dataset.new(data, feature_names: ["a", "b"])
assert_equal ["a", "b"], dataset.feature_names
dataset = LightGBM::Dataset.new(data, feature_name: ["a", "b"])
assert_equal ["a", "b"], dataset.feature_name
end

def test_num_data
Expand Down Expand Up @@ -61,9 +61,9 @@ def test_daru
label = data["y"]
data = data.delete_vector("y")
dataset = LightGBM::Dataset.new(data, label: label)
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_names
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_name

dataset = LightGBM::Dataset.new(data, label: label, feature_names: "auto")
dataset = LightGBM::Dataset.new(data, label: label, feature_name: "auto")
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_names
end

Expand All @@ -74,7 +74,7 @@ def test_numo
data = Numo::DFloat.new(3, 5).seq
label = Numo::DFloat.new(3).seq
dataset = LightGBM::Dataset.new(data, label: label)
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"], dataset.feature_names
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"], dataset.feature_name
end

def test_rover
Expand All @@ -84,10 +84,10 @@ def test_rover
data = Rover.read_csv(data_path)
label = data.delete("y")
dataset = LightGBM::Dataset.new(data, label: label)
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_names
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_name

dataset = LightGBM::Dataset.new(data, label: label, feature_names: "auto")
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_names
dataset = LightGBM::Dataset.new(data, label: label, feature_name: "auto")
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_name
end

def test_copy
Expand Down

0 comments on commit 211c205

Please sign in to comment.