-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for categorical feature auto-encoding
- Loading branch information
1 parent
3f601b0
commit fa23e11
Showing
5 changed files
with
2,084 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
require "json" | ||
|
||
module LightGBM | ||
# Converts LightGBM categorical featulres to Float, using label encoding. | ||
# The categorical and mappings are extracted from the LightGBM model file. | ||
class CategoricalFeatureEncoder | ||
# Initializes a new CategoricalFeatureEncoder instance. | ||
# | ||
# @param model_enumerable [Enumerable] Enumerable with each line of LightGBM model file. | ||
def initialize(model_enumerable) | ||
@categorical_feature = [] | ||
@pandas_categorical = [] | ||
|
||
load_categorical_features(model_enumerable) | ||
end | ||
|
||
# Returns a new array with categorical features converted to Float, using label encoding. | ||
def apply(feature_values) | ||
return feature_values if @categorical_feature.empty? | ||
|
||
transformed_features = feature_values.dup | ||
|
||
@categorical_feature.each_with_index do |feature_index, pandas_categorical_index| | ||
value = feature_values[feature_index] | ||
|
||
pandas_categorical_entry = @pandas_categorical[pandas_categorical_index] | ||
transformed_value = pandas_categorical_entry[value] | ||
transformed_features[feature_index] = transformed_value.nil? ? Float::NAN : transformed_value.to_f | ||
end | ||
|
||
transformed_features | ||
end | ||
|
||
private | ||
|
||
def load_categorical_features(model_enumerable) | ||
categorical_found = false | ||
pandas_found = false | ||
|
||
model_enumerable.each_entry do |line| | ||
# Format: "[categorical_feature: 0,1,2,3,4,5]" | ||
if line.start_with?("[categorical_feature:") | ||
parts = line.split("categorical_feature:") | ||
last_part = parts.last | ||
next if last_part.nil? | ||
|
||
values = last_part.strip[0...-1] | ||
next if values.nil? | ||
|
||
@categorical_feature = values.split(",").map(&:to_i) | ||
categorical_found = true | ||
end | ||
|
||
# Format: "pandas_categorical:[[-1.0, 0.0, 1.0], ["", "shop_pay"], [false, true]]" | ||
if line.start_with?("pandas_categorical:") | ||
parts = line.split("pandas_categorical:") | ||
values = parts[1] | ||
next if values.nil? | ||
|
||
@pandas_categorical = JSON.parse(values).map do |array| | ||
array.each_with_index.to_h | ||
end | ||
pandas_found = true | ||
end | ||
|
||
# Break the loop if both lines are found | ||
break if categorical_found && pandas_found | ||
end | ||
|
||
if @categorical_feature.size != @pandas_categorical.size | ||
raise "categorical_feature and pandas_categorical mismatch" | ||
end | ||
end | ||
|
||
def transform_categorical_value(value, index) | ||
pandas_categorical_entry = @pandas_categorical[index] | ||
transformed_value = pandas_categorical_entry.find_index(value) | ||
|
||
transformed_value.nil? ? Float::NAN : transformed_value.to_f | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
require_relative "test_helper" | ||
|
||
class CategoricalFeatureEncoder < Minitest::Test | ||
def setup | ||
model = <<~MODEL | ||
[categorical_feature: 1,2,3] | ||
pandas_categorical:[[-1.0, 0.0, 1.0], ["red", "green", "blue"], [false, true]] | ||
MODEL | ||
|
||
@encoder = LightGBM::CategoricalFeatureEncoder.new(model.each_line) | ||
end | ||
|
||
def test_apply_with_categorical_features | ||
input = [42.0, 0.0, "green", true] | ||
expected = [42.0, 1.0, 1.0, 1.0] | ||
|
||
assert_equal(expected, @encoder.apply(input)) | ||
end | ||
|
||
def test_apply_with_non_categorical_features | ||
input = [42.0, "non_categorical", 39.0, false] | ||
expected = [42.0, Float::NAN, Float::NAN, 0] | ||
|
||
assert_equal(expected, @encoder.apply(input)) | ||
end | ||
|
||
def test_apply_with_missing_values | ||
input = [42.0, nil, "red", nil] | ||
expected = [42.0, Float::NAN, 0.0, Float::NAN] | ||
result = @encoder.apply(input) | ||
|
||
assert_equal(expected, result) | ||
end | ||
|
||
def test_apply_with_boolean_values | ||
input = [42.0, -1.0, "green", false] | ||
expected = [42.0, 0.0, 1.0, 0.0] | ||
|
||
assert_equal(expected, @encoder.apply(input)) | ||
end | ||
end |
Oops, something went wrong.