# Reconstructed from a git patch that had been collapsed onto a single line:
#   commit de8ba0acdd46ca29a35e84dd53fe7284a34386b6
#   From: RektPunk, Wed, 11 Sep 2024 10:52:15 +0900
#   Subject: [PATCH] add test for encoder
# The patch creates tests/test_encoder.py; the module below is its payload.
"""Tests for ``MQLabelEncoder`` from ``mqboost.encoder``."""

import numpy as np
import pandas as pd
import pytest

# XdataLike / YdataLike are kept from the original patch; no test in this
# module references them.
from mqboost.base import XdataLike, YdataLike
from mqboost.encoder import MQLabelEncoder


# Test data for categorical variables
@pytest.fixture
def sample_data():
    """Return a categorical series containing strings, ``None`` and ``np.nan``."""
    return pd.Series(["apple", "banana", "orange", None, "kiwi", np.nan])


# Test data for label encoding
@pytest.fixture
def sample_label_data():
    """Return six integer labels, the same length as ``sample_data``."""
    return np.array([2, 3, 5, 0, 4, 0])


def test_fit_transform(sample_data):
    """``fit_transform`` returns one integer code per input element."""
    encoder = MQLabelEncoder()
    transformed = encoder.fit_transform(sample_data)

    # Check that the transformed result is numeric.
    assert transformed is not None
    # Accept any integer dtype: comparing against the builtin ``int`` would pin
    # the result to the platform's default integer width.
    assert np.issubdtype(transformed.dtype, np.integer)
    assert len(transformed) == len(sample_data)


def test_fit_transform_labels(sample_data, sample_label_data):
    """``fit_transform`` on ``sample_data`` yields ``sample_label_data``."""
    # NOTE(review): assumes the wrapped label encoder numbers its classes in
    # sorted order ("NaN", "Unseen", "apple", "banana", "kiwi", "orange") and
    # that missing values are encoded as "NaN"; the fixture values match that
    # ordering -- confirm against MQLabelEncoder.
    encoder = MQLabelEncoder()
    transformed = encoder.fit_transform(sample_data)

    np.testing.assert_array_equal(transformed, sample_label_data)


def test_unseen_and_nan_values(sample_data):
    """Unseen categories and missing values get the "Unseen" / "NaN" codes."""
    encoder = MQLabelEncoder()
    encoder.fit(sample_data)

    # Include new unseen values and missing entries and check the behavior.
    test_data = pd.Series(["apple", "unknown", None, "melon", np.nan])
    transformed = encoder.transform(test_data)

    # Expected codes are looked up through the fitted inner encoder so the
    # test does not hard-code integer label values.
    expected = encoder.label_encoder.transform(
        ["apple", "Unseen", "NaN", "Unseen", "NaN"]
    )

    # assert_array_equal also verifies the shapes match and reports the
    # mismatching elements on failure, unlike ``(a == b).all()``.
    np.testing.assert_array_equal(transformed, expected)