diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 144b242a1d..a12e66f109 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -1,3 +1,4 @@ +import contextlib import logging from collections import defaultdict from datetime import datetime @@ -7,14 +8,15 @@ import pytz from psycopg2 import sql from psycopg2.extras import execute_values +from psycopg2.pool import SimpleConnectionPool from pydantic.schema import Literal from feast import Entity from feast.feature_view import FeatureView from feast.infra.key_encoding_utils import serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.utils.postgres.connection_utils import _get_conn -from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig +from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool +from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig @@ -27,12 +29,21 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig): class PostgreSQLOnlineStore(OnlineStore): _conn: Optional[psycopg2._psycopg.connection] = None + _conn_pool: Optional[SimpleConnectionPool] = None + @contextlib.contextmanager def _get_conn(self, config: RepoConfig): - if not self._conn: - assert config.online_store.type == "postgres" - self._conn = _get_conn(config.online_store) - return self._conn + assert config.online_store.type == "postgres" + if config.online_store.conn_type == ConnectionType.pool: + if not self._conn_pool: + self._conn_pool = _get_connection_pool(config.online_store) + connection = self._conn_pool.getconn() + yield connection + self._conn_pool.putconn(connection) + else: + if not self._conn: + self._conn = _get_conn(config.online_store) + yield self._conn @log_exceptions_and_usage(online_store="postgres") def online_write_batch( diff --git a/sdk/python/feast/infra/utils/postgres/connection_utils.py b/sdk/python/feast/infra/utils/postgres/connection_utils.py index 0e9cbf96fe..0d99c8ab99 100644 --- a/sdk/python/feast/infra/utils/postgres/connection_utils.py +++ b/sdk/python/feast/infra/utils/postgres/connection_utils.py @@ -5,6 +5,7 @@ import psycopg2 import psycopg2.extras import pyarrow as pa +from psycopg2.pool import SimpleConnectionPool from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig from feast.type_map import arrow_to_pg_type @@ -22,10 +23,28 @@ def _get_conn(config: PostgreSQLConfig): sslcert=config.sslcert_path, sslrootcert=config.sslrootcert_path, options="-c search_path={}".format(config.db_schema or config.user), + keepalives_idle=config.keepalives_idle, ) return conn +def _get_connection_pool(config: PostgreSQLConfig): + return SimpleConnectionPool( + config.min_conn, + config.max_conn, + dbname=config.database, + host=config.host, + port=int(config.port), + user=config.user, + password=config.password, + sslmode=config.sslmode, + sslkey=config.sslkey_path, + sslcert=config.sslcert_path, + sslrootcert=config.sslrootcert_path, + options="-c search_path={}".format(config.db_schema or config.user), + ) + + def _df_to_create_table_sql(entity_df, table_name) -> str: pa_table = pa.Table.from_pandas(entity_df) columns = [ diff --git a/sdk/python/feast/infra/utils/postgres/postgres_config.py b/sdk/python/feast/infra/utils/postgres/postgres_config.py index f22cc6c204..9fbaed474d 100644 --- a/sdk/python/feast/infra/utils/postgres/postgres_config.py +++ b/sdk/python/feast/infra/utils/postgres/postgres_config.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Optional from pydantic import StrictStr @@ -5,7 +6,15 @@ from feast.repo_config import FeastConfigBaseModel +class ConnectionType(Enum): + singleton = "singleton" + pool = "pool" + + class PostgreSQLConfig(FeastConfigBaseModel): + min_conn: int = 1 + max_conn: int = 10 + conn_type: ConnectionType = ConnectionType.singleton host: StrictStr port: int = 5432 database: StrictStr @@ -16,3 +25,4 @@ class PostgreSQLConfig(FeastConfigBaseModel): sslkey_path: Optional[StrictStr] = None sslcert_path: Optional[StrictStr] = None sslrootcert_path: Optional[StrictStr] = None + keepalives_idle: int = 0 diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index e1ae5f7a42..728bd9b34f 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -393,3 +393,17 @@ def feature_store_for_online_retrieval( ] return fs, feature_refs, entity_rows + + +@pytest.fixture +def fake_ingest_data(): + """Fake data to ingest into the feature store""" + data = { + "driver_id": [1], + "conv_rate": [0.5], + "acc_rate": [0.6], + "avg_daily_trips": [4], + "event_timestamp": [pd.Timestamp(datetime.utcnow()).round("ms")], + "created": [pd.Timestamp(datetime.utcnow()).round("ms")], + } + return pd.DataFrame(data) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 51f39a5667..8218971315 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -17,6 +17,7 @@ from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.field import Field +from feast.infra.utils.postgres.postgres_config import ConnectionType from feast.online_response import TIMESTAMP_POSTFIX from feast.types import Float32, Int32, String from feast.wait import wait_retry_backoff @@ -32,9 +33,45 @@ from tests.utils.data_source_test_creator import prep_file_source +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["postgres"]) +def test_connection_pool_online_stores( + environment, universal_data_sources, fake_ingest_data +): + if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": + return + fs = environment.feature_store + fs.config.online_store.conn_type = ConnectionType.pool + fs.config.online_store.min_conn = 1 + fs.config.online_store.max_conn = 10 + + entities, datasets, data_sources = universal_data_sources + driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) + driver_entity = driver() + + # Register Feature View and Entity + fs.apply([driver_hourly_stats, driver_entity]) + + # directly ingest data into the Online Store + fs.write_to_online_store("driver_stats", fake_ingest_data) + + # assert the right data is in the Online Store + df = fs.get_online_features( + features=[ + "driver_stats:avg_daily_trips", + "driver_stats:acc_rate", + "driver_stats:conv_rate", + ], + entity_rows=[{"driver_id": 1}], + ).to_df() + assertpy.assert_that(df["avg_daily_trips"].iloc[0]).is_equal_to(4) + assertpy.assert_that(df["acc_rate"].iloc[0]).is_close_to(0.6, 1e-6) + assertpy.assert_that(df["conv_rate"].iloc[0]).is_close_to(0.5, 1e-6) + + @pytest.mark.integration @pytest.mark.universal_online_stores(only=["redis"]) -def test_entity_ttl_online_store(environment, universal_data_sources): +def test_entity_ttl_online_store(environment, universal_data_sources, fake_ingest_data): if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": return fs = environment.feature_store @@ -47,19 +84,8 @@ def test_entity_ttl_online_store(environment, universal_data_sources): # Register Feature View and Entity fs.apply([driver_hourly_stats, driver_entity]) - # fake data to ingest into Online Store - data = { - "driver_id": [1], - "conv_rate": [0.5], - "acc_rate": [0.6], - "avg_daily_trips": [4], - "event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], - "created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], - } - df_ingest = pd.DataFrame(data) - # directly ingest data into the Online Store - fs.write_to_online_store("driver_stats", df_ingest) + fs.write_to_online_store("driver_stats", fake_ingest_data) # assert the right data is in the Online Store df = fs.get_online_features(