diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 6ae8a9d4f6..83ade9bbc3 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -15,7 +15,7 @@ from feast.infra.key_encoding_utils import serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore from feast.infra.utils.postgres.connection_utils import _get_connection_pool, _get_conn -from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig, Connection +from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig, ConnectionType from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig @@ -32,7 +32,7 @@ class PostgreSQLOnlineStore(OnlineStore): @contextlib.contextmanager def _get_conn(self, config: RepoConfig): assert config.online_store.type == "postgres" - if config.online_store.conn_type == Connection.pool: + if config.online_store.conn_type == ConnectionType.pool: if not self._conn_pool: self._conn_pool = _get_connection_pool(config.online_store) connection = self._conn_pool.getconn() diff --git a/sdk/python/feast/infra/utils/postgres/postgres_config.py b/sdk/python/feast/infra/utils/postgres/postgres_config.py index cff8dbea7f..47b6efa173 100644 --- a/sdk/python/feast/infra/utils/postgres/postgres_config.py +++ b/sdk/python/feast/infra/utils/postgres/postgres_config.py @@ -5,14 +5,14 @@ from feast.repo_config import FeastConfigBaseModel -class Connection(Enum): +class ConnectionType(Enum): singleton = 'singleton' pool = 'pool' class PostgreSQLConfig(FeastConfigBaseModel): min_conn: int = 1 max_conn: int = 10 - conn_type: Connection = Connection.singleton + conn_type: ConnectionType = ConnectionType.singleton host: StrictStr port: int = 5432 database: StrictStr diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index e1ae5f7a42..d53bf96882 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -393,3 +393,17 @@ def feature_store_for_online_retrieval( ] return fs, feature_refs, entity_rows + + +@pytest.fixture +def fake_ingest_data(): + """Fake data to ingest into the feature store""" + data = { + "driver_id": [1], + "conv_rate": [0.5], + "acc_rate": [0.6], + "avg_daily_trips": [4], + "event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], + "created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], + } + return pd.DataFrame(data) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 51f39a5667..484a1cd9b5 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -30,11 +30,46 @@ driver_feature_view, ) from tests.utils.data_source_test_creator import prep_file_source +from feast.infra.utils.postgres.postgres_config import ConnectionType + + +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["postgres"]) +def test_connection_pool_online_stores(environment, universal_data_sources, fake_ingest_data): + if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": + return + fs = environment.feature_store + fs.config.online_store.conn_type = ConnectionType.pool + fs.config.online_store.min_conn = 1 + fs.config.online_store.max_conn = 10 + + entities, datasets, data_sources = universal_data_sources + driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) + driver_entity = driver() + + # Register Feature View and Entity + fs.apply([driver_hourly_stats, driver_entity]) + + # directly ingest data into the Online Store + fs.write_to_online_store("driver_stats", fake_ingest_data) + + # assert the right data is in the Online Store + df = fs.get_online_features( + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:acc_rate", + "driver_stats:conv_rate", + ], + entity_rows=[{"driver_id": 1}], + ).to_df() + assertpy.assert_that(df["avg_daily_trips"].iloc[0]).is_equal_to(4) + assertpy.assert_that(df["acc_rate"].iloc[0]).is_close_to(0.6, 1e-6) + assertpy.assert_that(df["conv_rate"].iloc[0]).is_close_to(0.5, 1e-6) @pytest.mark.integration @pytest.mark.universal_online_stores(only=["redis"]) -def test_entity_ttl_online_store(environment, universal_data_sources): +def test_entity_ttl_online_store(environment, universal_data_sources, fake_ingest_data): if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": return fs = environment.feature_store @@ -47,19 +82,8 @@ def test_entity_ttl_online_store(environment, universal_data_sources): # Register Feature View and Entity fs.apply([driver_hourly_stats, driver_entity]) - # fake data to ingest into Online Store - data = { - "driver_id": [1], - "conv_rate": [0.5], - "acc_rate": [0.6], - "avg_daily_trips": [4], - "event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], - "created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], - } - df_ingest = pd.DataFrame(data) - # directly ingest data into the Online Store - fs.write_to_online_store("driver_stats", df_ingest) + fs.write_to_online_store("driver_stats", fake_ingest_data) # assert the right data is in the Online Store df = fs.get_online_features(