Skip to content

Commit

Permalink
feat: Add async feature retrieval for Postgres Online Store (#4327)
Browse files Browse the repository at this point in the history
* Add async retrieval for postgres

Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com>

* Format

Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com>

* Update _prepare_keys method

Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com>

* Fix typo

Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com>

---------

Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com>
  • Loading branch information
TomSteenbergen authored Jul 8, 2024
1 parent 0d89d15 commit cea52e9
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 63 deletions.
186 changes: 126 additions & 60 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Expand All @@ -12,18 +13,24 @@
Optional,
Sequence,
Tuple,
Union,
)

import pytz
from psycopg import sql
from psycopg import AsyncConnection, sql
from psycopg.connection import Connection
from psycopg_pool import ConnectionPool
from psycopg_pool import AsyncConnectionPool, ConnectionPool

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
from feast.infra.utils.postgres.connection_utils import (
_get_conn,
_get_conn_async,
_get_connection_pool,
_get_connection_pool_async,
)
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
Expand Down Expand Up @@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[Connection] = None
_conn_pool: Optional[ConnectionPool] = None

_conn_async: Optional[AsyncConnection] = None
_conn_pool_async: Optional[AsyncConnectionPool] = None

@contextlib.contextmanager
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
assert config.online_store.type == "postgres"
Expand All @@ -67,6 +77,24 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
self._conn = _get_conn(config.online_store)
yield self._conn

@contextlib.asynccontextmanager
async def _get_conn_async(
self, config: RepoConfig
) -> AsyncGenerator[AsyncConnection, Any]:
if config.online_store.conn_type == ConnectionType.pool:
if not self._conn_pool_async:
self._conn_pool_async = await _get_connection_pool_async(
config.online_store
)
await self._conn_pool_async.open()
connection = await self._conn_pool_async.getconn()
yield connection
await self._conn_pool_async.putconn(connection)
else:
if not self._conn_async:
self._conn_async = await _get_conn_async(config.online_store)
yield self._conn_async

def online_write_batch(
self,
config: RepoConfig,
Expand Down Expand Up @@ -132,69 +160,107 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
query, params = self._construct_query_and_params(
config, table, keys, requested_features
)

project = config.project
with self._get_conn(config) as conn, conn.cursor() as cur:
# Collecting all the keys to a list allows us to make fewer round trips
# to PostgreSQL
keys = []
for entity_key in entity_keys:
keys.append(
serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
)
cur.execute(query, params)
rows = cur.fetchall()

if not requested_features:
cur.execute(
sql.SQL(
"""
SELECT entity_key, feature_name, value, event_ts
FROM {} WHERE entity_key = ANY(%s);
"""
).format(
sql.Identifier(_table_id(project, table)),
),
(keys,),
)
else:
cur.execute(
sql.SQL(
"""
SELECT entity_key, feature_name, value, event_ts
FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s);
"""
).format(
sql.Identifier(_table_id(project, table)),
),
(keys, requested_features),
)
return self._process_rows(keys, rows)

rows = cur.fetchall()
async def online_read_async(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
query, params = self._construct_query_and_params(
config, table, keys, requested_features
)

# Since we don't know the order returned from PostgreSQL we'll need
# to construct a dict to be able to quickly look up the correct row
# when we iterate through the keys since they are in the correct order
values_dict = defaultdict(list)
for row in rows if rows is not None else []:
values_dict[
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
].append(row[1:])

for key in keys:
if key in values_dict:
value = values_dict[key]
res = {}
for feature_name, value_bin, event_ts in value:
val = ValueProto()
val.ParseFromString(bytes(value_bin))
res[feature_name] = val
result.append((event_ts, res))
else:
result.append((None, None))
async with self._get_conn_async(config) as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
rows = await cur.fetchall()

return self._process_rows(keys, rows)

@staticmethod
def _construct_query_and_params(
config: RepoConfig,
table: FeatureView,
keys: List[bytes],
requested_features: Optional[List[str]] = None,
) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]:
"""Construct the SQL query based on the given parameters."""
if requested_features:
query = sql.SQL(
"""
SELECT entity_key, feature_name, value, event_ts
FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
"""
).format(
sql.Identifier(_table_id(config.project, table)),
)
params = (keys, requested_features)
else:
query = sql.SQL(
"""
SELECT entity_key, feature_name, value, event_ts
FROM {} WHERE entity_key = ANY(%s);
"""
).format(
sql.Identifier(_table_id(config.project, table)),
)
params = (keys, [])
return query, params

@staticmethod
def _prepare_keys(
entity_keys: List[EntityKeyProto], entity_key_serialization_version: int
) -> List[bytes]:
"""Prepare all keys in a list to make fewer round trips to the database."""
return [
serialize_entity_key(
entity_key,
entity_key_serialization_version=entity_key_serialization_version,
)
for entity_key in entity_keys
]

@staticmethod
def _process_rows(
keys: List[bytes], rows: List[Tuple]
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
"""Transform the retrieved rows in the desired output.
PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict`
is created to quickly look up the correct row using the keys, since these are
actually in the correct order.
"""
values_dict = defaultdict(list)
for row in rows if rows is not None else []:
values_dict[
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
].append(row[1:])

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
for key in keys:
if key in values_dict:
value = values_dict[key]
res = {}
for feature_name, value_bin, event_ts in value:
val = ValueProto()
val.ParseFromString(bytes(value_bin))
res[feature_name] = val
result.append((event_ts, res))
else:
result.append((None, None))
return result

def update(
Expand Down
25 changes: 23 additions & 2 deletions sdk/python/feast/infra/utils/postgres/connection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd
import psycopg
import pyarrow as pa
from psycopg.connection import Connection
from psycopg_pool import ConnectionPool
from psycopg import AsyncConnection, Connection
from psycopg_pool import AsyncConnectionPool, ConnectionPool

from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig
from feast.type_map import arrow_to_pg_type
Expand All @@ -21,6 +21,16 @@ def _get_conn(config: PostgreSQLConfig) -> Connection:
return conn


async def _get_conn_async(config: PostgreSQLConfig) -> AsyncConnection:
"""Get a psycopg `AsyncConnection`."""
conn = await psycopg.AsyncConnection.connect(
conninfo=_get_conninfo(config),
keepalives_idle=config.keepalives_idle,
**_get_conn_kwargs(config),
)
return conn


def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool:
"""Get a psycopg `ConnectionPool`."""
return ConnectionPool(
Expand All @@ -32,6 +42,17 @@ def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool:
)


async def _get_connection_pool_async(config: PostgreSQLConfig) -> AsyncConnectionPool:
"""Get a psycopg `AsyncConnectionPool`."""
return AsyncConnectionPool(
conninfo=_get_conninfo(config),
min_size=config.min_conn,
max_size=config.max_conn,
open=False,
kwargs=_get_conn_kwargs(config),
)


def _get_conninfo(config: PostgreSQLConfig) -> str:
"""Get the `conninfo` argument required for connection objects."""
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def test_online_retrieval_with_event_timestamps(environment, universal_data_sour


@pytest.mark.integration
@pytest.mark.universal_online_stores(only=["redis", "dynamodb"])
@pytest.mark.universal_online_stores(only=["redis", "dynamodb", "postgres"])
def test_async_online_retrieval_with_event_timestamps(
environment, universal_data_sources
):
Expand Down

0 comments on commit cea52e9

Please sign in to comment.