Skip to content

Commit

Permalink
Improved error message for invalid arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Dec 16, 2024
1 parent a15c129 commit 923528e
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/lightgbm/dataset.rb
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def construct
ncol = data.first.count
flat_data = data.flat_map { |v| v.fetch_values(*keys) }
else
data = data.to_a
check_2d_array(data)
nrow = data.count
ncol = data.first.count
flat_data = data.flatten
Expand Down
1 change: 1 addition & 0 deletions lib/lightgbm/inner_predictor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def predict(data, start_iteration: 0, num_iteration: -1, raw_score: false, pred_
data = data.to_a
singular = !data.first.is_a?(Array)
data = [data] if singular
check_2d_array(data)
end

preds, nrow =
Expand Down
7 changes: 7 additions & 0 deletions lib/lightgbm/utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def set_verbosity(params)
end
end

# Ensures every row in data reports the same size as the first row.
#
# The expected size is taken from data.first (0 when data.first is nil).
# Raises ArgumentError as soon as a row with a different size is found;
# returns nil otherwise.
def check_2d_array(data)
  expected = data.first&.size || 0
  ragged = data.any? { |row| row.size != expected }
  raise ArgumentError, "Rows have different sizes" if ragged
end

# for categorical, NaN and negative values are the same
def handle_missing(data)
data.map! { |v| v.nil? ? Float::NAN : v }
Expand Down
8 changes: 8 additions & 0 deletions test/booster_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ def test_predict_rover
end
end

# Predicting on rows of unequal length must raise an ArgumentError
# with the "Rows have different sizes" message.
def test_predict_array_different_sizes
  ragged_rows = [[1, 2], [3, 4, 5]]
  raised = assert_raises(ArgumentError) { booster.predict(ragged_rows) }
  assert_equal "Rows have different sizes", raised.message
end

def test_predict_raw_score
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
expected = [0.9823112229173586, 0.9583143724610858]
Expand Down
8 changes: 8 additions & 0 deletions test/dataset_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def test_rover
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_name
end

# Building a Dataset from rows of unequal length must raise an ArgumentError
# with the "Rows have different sizes" message.
def test_array_different_sizes
  ragged_rows = [[1, 2], [3, 4, 5]]
  raised = assert_raises(ArgumentError) { LightGBM::Dataset.new(ragged_rows) }
  assert_equal "Rows have different sizes", raised.message
end

def test_copy
regression_train.dup
regression_train.clone
Expand Down

0 comments on commit 923528e

Please sign in to comment.