Skip to content

Commit

Permalink
include float preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Dec 3, 2024
1 parent a696987 commit c354c00
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
12 changes: 12 additions & 0 deletions mambular/preprocessing/prepro_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,15 @@ def get_feature_names_out(self, input_features=None):
"input_features must be provided to generate feature names."
)
return np.array(input_features)


class ToFloatTransformer(TransformerMixin, BaseEstimator):
"""
A transformer that converts input data to float type.
"""

def fit(self, X, y=None):
return self

def transform(self, X):
return X.astype(float)
19 changes: 12 additions & 7 deletions mambular/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CustomBinner,
OneHotFromOrdinal,
NoTransformer,
ToFloatTransformer,
)


Expand Down Expand Up @@ -339,6 +340,7 @@ def fit(self, X, y=None):
[
("imputer", SimpleImputer(strategy="most_frequent")),
("onehot", OneHotEncoder()),
("to_float", ToFloatTransformer()),
]
)

Expand Down Expand Up @@ -453,17 +455,20 @@ def _split_transformed_output(self, X, transformed_X):
"""
start = 0
transformed_dict = {}
for (
name,
transformer,
columns,
) in self.column_transformer.transformers_:
for name, transformer, columns in self.column_transformer.transformers_:
if transformer != "drop":
end = start + transformer.transform(X[[columns[0]]]).shape[1]
dtype = int if "cat" in name else float

# Determine dtype based on the transformer steps
transformer_steps = [step[0] for step in transformer.steps]
if "continuous_ordinal" in transformer_steps:
dtype = int # Use int for ordinal/integer encoding
else:
dtype = float # Default to float for other encodings

# Assign transformed data with the correct dtype
transformed_dict[name] = transformed_X[:, start:end].astype(dtype)
start = end

return transformed_dict

def fit_transform(self, X, y=None):
Expand Down

0 comments on commit c354c00

Please sign in to comment.