Skip to content

Commit

Permalink
Solve adding missing category when predicting
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Aug 17, 2023
1 parent 9bd3914 commit fc79f44
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 5 deletions.
12 changes: 11 additions & 1 deletion src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
_least_squares_solver,
_trust_constr_solver,
)
from ._util import _align_df_categories, _safe_toarray
from ._util import _add_missing_categories, _align_df_categories, _safe_toarray

_float_itemsize_to_dtype = {8: np.float64, 4: np.float32, 2: np.float16}

Expand Down Expand Up @@ -833,13 +833,23 @@ def _get_start_coef(

def _convert_from_pandas(self, df: pd.DataFrame) -> tm.MatrixBase:
    """Turn a pandas data frame into a tabmat matrix.

    If the estimator already carries ``feature_dtypes_`` (presumably set
    while fitting -- confirm against the fit code), ``df`` is first passed
    through ``_align_df_categories`` with those dtypes.  When, in addition,
    ``cat_missing_method`` equals ``"convert"``, the frame is also passed
    through ``_add_missing_categories`` so that the missing-value category
    named by ``cat_missing_name`` can be represented.

    Parameters
    ----------
    df : pandas.DataFrame
        The data to convert.

    Returns
    -------
    tm.MatrixBase
        The result of ``tm.from_pandas`` called with this estimator's
        ``drop_first``, ``categorical_format`` and ``cat_missing_method``.
    """
    has_stored_dtypes = hasattr(self, "feature_dtypes_")

    if has_stored_dtypes:
        df = _align_df_categories(df, self.feature_dtypes_)

    # The short-circuit keeps ``cat_missing_method`` from being read when no
    # stored dtypes exist, exactly as in the nested form.
    if has_stored_dtypes and self.cat_missing_method == "convert":
        df = _add_missing_categories(
            df=df,
            dtypes=self.feature_dtypes_,
            feature_names=self.feature_names_,
            cat_missing_name=self.cat_missing_name,
            categorical_format=self.categorical_format,
        )

    return tm.from_pandas(
        df,
        drop_first=self.drop_first,
        categorical_format=self.categorical_format,
        cat_missing_method=self.cat_missing_method,
    )
Expand Down
36 changes: 35 additions & 1 deletion src/glum/_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Union
from typing import Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -53,6 +53,40 @@ def _align_df_categories(df, dtypes) -> pd.DataFrame:
return df


def _add_missing_categories(
    df,
    dtypes,
    feature_names: Sequence[str],
    categorical_format: str,
    cat_missing_name: str,
) -> pd.DataFrame:
    """Add the missing-value category to categorical columns where needed.

    A column is processed when ``dtypes`` marks it as categorical, it is
    present in ``df``, and the name
    ``categorical_format.format(name=column, category=cat_missing_name)``
    occurs in ``feature_names``.  For such a column ``cat_missing_name`` is
    appended to its categories and null entries are replaced by it.  All
    other columns are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        The data frame to adjust.
    dtypes
        Mapping-like object whose ``.items()`` yields ``(column, dtype)``
        pairs, e.g. a ``dict`` or ``DataFrame.dtypes``.
    feature_names : Sequence[str]
        Feature names to search for the formatted missing-category name.
    categorical_format : str
        Format string with ``{name}`` and ``{category}`` placeholders.
    cat_missing_name : str
        Name of the category that stands in for missing values.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when no column had to change; otherwise the result of
        ``df.assign`` with the adjusted columns.

    Raises
    ------
    TypeError
        If ``df`` is not a ``pandas.DataFrame``.
    ValueError
        If ``cat_missing_name`` is already one of the categories of a
        column that needs the missing category.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"Expected `pandas.DataFrame'; got {type(df)}.")

    # Build the lookup once: membership is tested for every categorical
    # column, and a list scan per column is needlessly linear.
    known_features = set(feature_names)

    changed_dtypes = {}

    for column, dtype in dtypes.items():
        # ``isinstance`` replaces ``pd.api.types.is_categorical_dtype``,
        # which is deprecated in recent pandas releases.
        if not isinstance(dtype, pd.CategoricalDtype) or column not in df:
            continue

        missing_feature = categorical_format.format(
            name=column, category=cat_missing_name
        )
        if missing_feature not in known_features:
            continue

        # ``add_categories`` would fail on a duplicate with a generic
        # message; fail first with one that names the column.
        if cat_missing_name in df[column].cat.categories:
            raise ValueError(
                f"Missing category {cat_missing_name} already exists in {column}."
            )

        _logger.info("Adding missing category %s to %s.", cat_missing_name, column)
        converted = df[column].cat.add_categories(cat_missing_name)
        if df[column].isnull().any():
            converted = converted.fillna(cat_missing_name)
        changed_dtypes[column] = converted

    if changed_dtypes:
        df = df.assign(**changed_dtypes)

    return df


def _safe_lin_pred(
X: Union[MatrixBase, StandardizedMatrix],
coef: np.ndarray,
Expand Down
5 changes: 3 additions & 2 deletions tests/glm/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2358,13 +2358,13 @@ def test_store_covariance_matrix_cv(
def test_cat_missing(cat_missing_method):
X = pd.DataFrame(
{
"cat_1": pd.Categorical([1, 2, 1, 2, 1]),
"cat_1": pd.Categorical([1, 2, pd.NA, 2, 1]),
"cat_2": pd.Categorical([1, 2, pd.NA, 1, 2]),
}
)
X_unseen = pd.DataFrame(
{
"cat_1": pd.Categorical([1, 1]),
"cat_1": pd.Categorical([1, pd.NA]),
"cat_2": pd.Categorical([1, 2]),
}
)
Expand All @@ -2385,6 +2385,7 @@ def test_cat_missing(cat_missing_method):
feature_names = ["cat_1[1]", "cat_1[2]", "cat_2[1]", "cat_2[2]"]

if cat_missing_method == "convert":
feature_names.insert(2, "cat_1[(MISSING)]")
feature_names.append("cat_2[(MISSING)]")

np.testing.assert_array_equal(model.feature_names_, feature_names)
Expand Down
51 changes: 50 additions & 1 deletion tests/glm/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import pytest

from glum._util import _align_df_categories
from glum._util import _add_missing_categories, _align_df_categories


@pytest.fixture()
Expand Down Expand Up @@ -96,3 +96,52 @@ def test_align_df_categories_missing_columns(df):
def test_align_df_categories_not_df():
    """``_align_df_categories`` must raise ``TypeError`` for a NumPy array."""
    not_a_frame = np.array([[0], [1]])
    dtypes = {"x0": np.float64}

    with pytest.raises(TypeError):
        _align_df_categories(not_a_frame, dtypes)


@pytest.fixture()
def df_na():
    """Frame with one float column and three categorical columns.

    Only ``cat_na`` contains a missing value (``pd.NA``).
    """
    columns = {
        "num": np.array([0, 1], dtype="float64"),
        "cat": pd.Categorical(["a", "b"]),
        "cat_na": pd.Categorical(["a", pd.NA]),
        "cat2": pd.Categorical(["a", "b"]),
    }
    return pd.DataFrame(columns)


def test_add_missing_categories(df_na):
    """Check which columns of ``df_na`` receive the ``"(M)"`` category.

    Expected outcome, per column:

    * ``num``    -- not categorical, unchanged even though ``"num[(M)]"``
      is listed among the feature names;
    * ``cat``    -- gains the ``"(M)"`` category, values unchanged;
    * ``cat_na`` -- gains the ``"(M)"`` category and its null becomes
      ``"(M)"``;
    * ``cat2``   -- no ``"cat2[(M)]"`` feature name, so unchanged.
    """
    result = _add_missing_categories(
        df=df_na,
        dtypes=df_na.dtypes,
        feature_names=[
            "num",
            "num[(M)]",
            "cat[a]",
            "cat[b]",
            "cat[(M)]",
            "cat_na[a]",
            "cat_na[(M)]",
            "cat2[a]",
            "cat2[b]",
        ],
        categorical_format="{name}[{category}]",
        cat_missing_name="(M)",
    )

    expected_columns = {
        "num": np.array([0, 1], dtype="float64"),
        "cat": pd.Categorical(["a", "b"], categories=["a", "b", "(M)"]),
        "cat_na": pd.Categorical(["a", "(M)"], categories=["a", "(M)"]),
        "cat2": pd.Categorical(["a", "b"], categories=["a", "b"]),
    }
    expected = pd.DataFrame(expected_columns)

    pd.testing.assert_frame_equal(result, expected)

0 comments on commit fc79f44

Please sign in to comment.