Skip to content

Commit

Permalink
Add support for categorical feature auto-encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
nunosilva800 committed Oct 18, 2024
1 parent 3f601b0 commit fa23e11
Show file tree
Hide file tree
Showing 5 changed files with 2,084 additions and 1 deletion.
11 changes: 10 additions & 1 deletion lib/lightgbm/booster.rb
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
require_relative "categorical_feature_encoder"

module LightGBM
class Booster
attr_accessor :best_iteration, :train_data_name
Expand All @@ -6,9 +8,11 @@ def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
@handle = ::FFI::MemoryPointer.new(:pointer)
if model_str
model_from_string(model_str)
@categorical_feature_encoder = CategoricalFeatureEncoder.new(model_str.each_line)
elsif model_file
out_num_iterations = ::FFI::MemoryPointer.new(:int)
check_result FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, @handle)
@categorical_feature_encoder = CategoricalFeatureEncoder.new(File.foreach(model_file))
else
params ||= {}
set_verbosity(params)
Expand Down Expand Up @@ -152,7 +156,12 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
num_iteration ||= best_iteration
num_class ||= num_class()

flat_input = input.flatten
flat_input = if @categorical_feature_encoder
input.flat_map { |row| @categorical_feature_encoder.apply(row) }
else
input.flatten
end

handle_missing(flat_input)
data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
data.write_array_of_double(flat_input)
Expand Down
82 changes: 82 additions & 0 deletions lib/lightgbm/categorical_feature_encoder.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
require "json"

module LightGBM
  # Converts LightGBM categorical features to Float, using label encoding.
  # The categorical feature indexes and the category-to-code mappings are
  # extracted from the LightGBM model file.
  class CategoricalFeatureEncoder
    # Line prefix announcing the categorical feature indexes,
    # e.g. "[categorical_feature: 0,1,2,3,4,5]".
    CATEGORICAL_FEATURE_PREFIX = "[categorical_feature:"
    # Line prefix announcing the JSON category lists,
    # e.g. "pandas_categorical:[[-1.0, 0.0, 1.0], [false, true]]".
    PANDAS_CATEGORICAL_PREFIX = "pandas_categorical:"
    private_constant :CATEGORICAL_FEATURE_PREFIX, :PANDAS_CATEGORICAL_PREFIX

    # Initializes a new CategoricalFeatureEncoder instance.
    #
    # @param model_enumerable [Enumerable] Enumerable with each line of LightGBM model file.
    # @raise [RuntimeError] if the number of categorical feature indexes differs
    #   from the number of category lists found in the model.
    def initialize(model_enumerable)
      @categorical_feature = []
      @pandas_categorical = []

      load_categorical_features(model_enumerable)
    end

    # Returns a new array with categorical features converted to Float, using label encoding.
    # Values that are not present in the corresponding category list (including nil)
    # become Float::NAN. When the model declares no categorical mappings, the input
    # is returned as is.
    #
    # @param feature_values [Array] one row of raw feature values.
    # @return [Array] the row with categorical positions replaced by Float codes.
    def apply(feature_values)
      return feature_values if @categorical_feature.empty?

      transformed_features = feature_values.dup

      # The n-th categorical feature index is paired with the n-th category list.
      @categorical_feature.each_with_index do |feature_index, pandas_categorical_index|
        code = @pandas_categorical[pandas_categorical_index][feature_values[feature_index]]
        # `code` is an Integer position (0 is truthy in Ruby) or nil when unknown.
        transformed_features[feature_index] = code ? code.to_f : Float::NAN
      end

      transformed_features
    end

    private

    # Scans the model lines for the categorical feature indexes and the category
    # lists, stopping as soon as both have been found.
    def load_categorical_features(model_enumerable)
      categorical_found = false
      pandas_found = false
      mappings_absent = false

      model_enumerable.each_entry do |line|
        if line.start_with?(CATEGORICAL_FEATURE_PREFIX)
          @categorical_feature = parse_categorical_feature(line)
          categorical_found = true
        elsif line.start_with?(PANDAS_CATEGORICAL_PREFIX)
          # Slice off the prefix instead of splitting on it, so a category value
          # that happens to contain the marker text cannot truncate the payload.
          payload = line[PANDAS_CATEGORICAL_PREFIX.length..-1].strip
          next if payload.empty?

          category_lists = JSON.parse(payload)
          # "pandas_categorical:null" parses to nil: the model carries no category lists.
          mappings_absent = category_lists.nil?
          @pandas_categorical = (category_lists || []).map do |categories|
            categories.each_with_index.to_h
          end
          pandas_found = true
        end

        # Break the loop if both lines are found
        break if categorical_found && pandas_found
      end

      # With an explicit null there is nothing to encode, so the encoder becomes a
      # no-op rather than failing on a model that has no category mappings.
      @categorical_feature = [] if mappings_absent

      if @categorical_feature.size != @pandas_categorical.size
        raise "categorical_feature and pandas_categorical mismatch"
      end
    end

    # Extracts the integer indexes from a "[categorical_feature: 0,1,2]" line.
    def parse_categorical_feature(line)
      line[CATEGORICAL_FEATURE_PREFIX.length..-1].strip.chomp("]").split(",").map(&:to_i)
    end
  end
end
14 changes: 14 additions & 0 deletions test/booster_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,27 @@ def test_model_file
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
end

# Loads the categorical-feature model from a file path and checks that rows
# holding raw booleans/strings in the categorical columns predict the expected scores.
def test_model_file_with_categorical_features
  model_path = "test/support/model_with_categorical_features.txt"
  rows = [[false, "green", 7.2, 9.0], [true, "blue", 7.9, 0.0]]
  predictions = LightGBM::Booster.new(model_file: model_path).predict(rows)
  assert_elements_in_delta [0.9948804305465, 0.792909968121466], predictions.first(2)
end

# Builds a booster from the model file's contents passed as a string and checks
# the first two predictions for purely numeric rows.
def test_model_str
  model_text = File.read("test/support/model.txt")
  rows = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
  predictions = LightGBM::Booster.new(model_str: model_text).predict(rows)
  assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], predictions.first(2)
end

# Same scenario as the file-based categorical test, but the model is supplied
# as a string read from disk.
def test_model_str_with_categorical_features
  model_text = File.read("test/support/model_with_categorical_features.txt")
  rows = [[false, "green", 7.2, 9.0], [true, "blue", 7.9, 0.0]]
  predictions = LightGBM::Booster.new(model_str: model_text).predict(rows)
  assert_elements_in_delta [0.9948804305465, 0.792909968121466], predictions.first(2)
end

# Checks the feature importance values reported by the `booster` helper
# (defined outside this view).
def test_feature_importance
  expected_importance = [280, 285, 335, 148]
  assert_equal expected_importance, booster.feature_importance
end
Expand Down
41 changes: 41 additions & 0 deletions test/categorical_feature_encoder_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
require_relative "test_helper"

# Unit tests for LightGBM::CategoricalFeatureEncoder#apply, driven by an inline
# two-line model fragment that declares columns 1, 2 and 3 as categorical.
class CategoricalFeatureEncoder < Minitest::Test
  # Builds the encoder under test from the lines of the inline model fragment.
  def setup
    model_lines = <<~MODEL.each_line
      [categorical_feature: 1,2,3]
      pandas_categorical:[[-1.0, 0.0, 1.0], ["red", "green", "blue"], [false, true]]
    MODEL

    @encoder = LightGBM::CategoricalFeatureEncoder.new(model_lines)
  end

  # Known category values are replaced by their position in the category list.
  def test_apply_with_categorical_features
    encoded = @encoder.apply([42.0, 0.0, "green", true])

    assert_equal([42.0, 1.0, 1.0, 1.0], encoded)
  end

  # Values absent from a column's category list are encoded as NaN.
  # NOTE(review): NaN never equals itself by value; this comparison presumably
  # passes only because both sides hold the same Float::NAN object - confirm.
  def test_apply_with_non_categorical_features
    encoded = @encoder.apply([42.0, "non_categorical", 39.0, false])

    assert_equal([42.0, Float::NAN, Float::NAN, 0], encoded)
  end

  # nil in a categorical column is encoded as NaN.
  def test_apply_with_missing_values
    encoded = @encoder.apply([42.0, nil, "red", nil])

    assert_equal([42.0, Float::NAN, 0.0, Float::NAN], encoded)
  end

  # Boolean categories are encoded by position (false => 0.0).
  def test_apply_with_boolean_values
    encoded = @encoder.apply([42.0, -1.0, "green", false])

    assert_equal([42.0, 0.0, 1.0, 0.0], encoded)
  end
end
Loading

0 comments on commit fa23e11

Please sign in to comment.