diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index de52b9e3f3..78431e2d61 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1383,7 +1383,7 @@ def push( fv.name, df, allow_registry_cache=allow_registry_cache ) if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE: - self._write_to_offline_store( + self.write_to_offline_store( fv.name, df, allow_registry_cache=allow_registry_cache ) @@ -1415,14 +1415,18 @@ def write_to_online_store( provider.ingest_df(feature_view, entities, df) @log_exceptions_and_usage - def _write_to_offline_store( + def write_to_offline_store( self, feature_view_name: str, df: pd.DataFrame, allow_registry_cache: bool = True, + reorder_columns: bool = True, ): """ - ingests data directly into the Online store + Persists the dataframe directly into the batch data source for the given feature view. + + Fails if the dataframe columns do not match the columns of the batch data source. Optionally + reorders the columns of the dataframe to match. """ # TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type try: @@ -1433,7 +1437,21 @@ def _write_to_offline_store( feature_view = self.get_feature_view( feature_view_name, allow_registry_cache=allow_registry_cache ) - df.reset_index(drop=True) + + # Get columns of the batch source and the input dataframe. + column_names_and_types = feature_view.batch_source.get_table_column_names_and_types( + self.config + ) + source_columns = [column for column, _ in column_names_and_types] + input_columns = df.columns.values.tolist() + + if set(input_columns) != set(source_columns): + raise ValueError( + f"The input dataframe has columns {set(input_columns)} but the batch source has columns {set(source_columns)}." 
+ ) + + if reorder_columns: + df = df.reindex(columns=source_columns) table = pa.Table.from_pandas(df) provider = self._get_provider() diff --git a/sdk/python/feast/infra/contrib/spark_kafka_processor.py b/sdk/python/feast/infra/contrib/spark_kafka_processor.py index 4dfb615773..32d91b2010 100644 --- a/sdk/python/feast/infra/contrib/spark_kafka_processor.py +++ b/sdk/python/feast/infra/contrib/spark_kafka_processor.py @@ -1,12 +1,14 @@ from types import MethodType -from typing import List +from typing import List, Optional +import pandas as pd from pyspark.sql import DataFrame, SparkSession from pyspark.sql.avro.functions import from_avro from pyspark.sql.functions import col, from_json from feast.data_format import AvroFormat, JsonFormat -from feast.data_source import KafkaSource +from feast.data_source import KafkaSource, PushMode +from feast.feature_store import FeatureStore from feast.infra.contrib.stream_processor import ( ProcessorConfig, StreamProcessor, @@ -24,16 +26,16 @@ class SparkProcessorConfig(ProcessorConfig): class SparkKafkaProcessor(StreamProcessor): spark: SparkSession format: str - write_function: MethodType + preprocess_fn: Optional[MethodType] join_keys: List[str] def __init__( self, + *, + fs: FeatureStore, sfv: StreamFeatureView, config: ProcessorConfig, - write_function: MethodType, - processing_time: str = "30 seconds", - query_timeout: int = 15, + preprocess_fn: Optional[MethodType] = None, ): if not isinstance(sfv.stream_source, KafkaSource): raise ValueError("data source is not kafka source") @@ -55,15 +57,16 @@ def __init__( if not isinstance(config, SparkProcessorConfig): raise ValueError("config is not spark processor config") self.spark = config.spark_session - self.write_function = write_function - self.processing_time = processing_time - self.query_timeout = query_timeout - super().__init__(sfv=sfv, data_source=sfv.stream_source) + self.preprocess_fn = preprocess_fn + self.processing_time = config.processing_time + self.query_timeout = config.query_timeout + self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities] + super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source) - def ingest_stream_feature_view(self) -> None: + def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None: ingested_stream_df = self._ingest_stream_data() transformed_df = self._construct_transformation_plan(ingested_stream_df) - online_store_query = self._write_to_online_store(transformed_df) + online_store_query = self._write_stream_data(transformed_df, to) return online_store_query def _ingest_stream_data(self) -> StreamTable: @@ -119,13 +122,35 @@ def _ingest_stream_data(self) -> StreamTable: def _construct_transformation_plan(self, df: StreamTable) -> StreamTable: return self.sfv.udf.__call__(df) if self.sfv.udf else df - def _write_to_online_store(self, df: StreamTable): + def _write_stream_data(self, df: StreamTable, to: PushMode): # Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema. def batch_write(row: DataFrame, batch_id: int): - pd_row = row.toPandas() - self.write_function( - pd_row, input_timestamp="event_timestamp", output_timestamp="" + rows: pd.DataFrame = row.toPandas() + + # Extract the latest feature values for each unique entity row (i.e. the join keys). + # Also add a 'created' column. 
+ rows = ( + rows.sort_values( + by=self.join_keys + [self.sfv.timestamp_field], ascending=True + ) + .groupby(self.join_keys) + .nth(0) ) + rows["created"] = pd.to_datetime("now", utc=True) + + # Reset indices to ensure the dataframe has all the required columns. + rows = rows.reset_index() + + # Optionally execute preprocessor before writing to the online store. + if self.preprocess_fn: + rows = self.preprocess_fn(rows) + + # Finally persist the data to the online store and/or offline store. + if rows.size > 0: + if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE: + self.fs.write_to_online_store(self.sfv.name, rows) + if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE: + self.fs.write_to_offline_store(self.sfv.name, rows) query = ( df.writeStream.outputMode("update") diff --git a/sdk/python/feast/infra/contrib/stream_processor.py b/sdk/python/feast/infra/contrib/stream_processor.py index 2ccf1e59f8..24817c82ea 100644 --- a/sdk/python/feast/infra/contrib/stream_processor.py +++ b/sdk/python/feast/infra/contrib/stream_processor.py @@ -1,14 +1,17 @@ from abc import ABC -from typing import Callable +from types import MethodType +from typing import TYPE_CHECKING, Optional -import pandas as pd from pyspark.sql import DataFrame -from feast.data_source import DataSource +from feast.data_source import DataSource, PushMode from feast.importer import import_class from feast.repo_config import FeastConfigBaseModel from feast.stream_feature_view import StreamFeatureView +if TYPE_CHECKING: + from feast.feature_store import FeatureStore + STREAM_PROCESSOR_CLASS_FOR_TYPE = { ("spark", "kafka"): "feast.infra.contrib.spark_kafka_processor.SparkKafkaProcessor", } @@ -30,21 +33,26 @@ class StreamProcessor(ABC): and persist that data to the online store. Attributes: + fs: The feature store where data should be persisted. sfv: The stream feature view on which the stream processor operates. data_source: The stream data source from which data will be ingested. """ + fs: "FeatureStore" sfv: StreamFeatureView data_source: DataSource - def __init__(self, sfv: StreamFeatureView, data_source: DataSource): + def __init__( + self, fs: "FeatureStore", sfv: StreamFeatureView, data_source: DataSource + ): + self.fs = fs self.sfv = sfv self.data_source = data_source - def ingest_stream_feature_view(self) -> None: + def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None: """ Ingests data from the stream source attached to the stream feature view; transforms the data - and then persists it to the online store. + and then persists it to the online store and/or offline store, depending on the 'to' parameter. """ pass @@ -62,26 +70,32 @@ def _construct_transformation_plan(self, table: StreamTable) -> StreamTable: """ pass - def _write_to_online_store(self, table: StreamTable) -> None: + def _write_stream_data(self, table: StreamTable, to: PushMode) -> None: """ - Returns query for persisting data to the online store. + Launches a job to persist stream data to the online store and/or offline store, depending + on the 'to' parameter, and returns a handle for the job. """ pass def get_stream_processor_object( config: ProcessorConfig, + fs: "FeatureStore", sfv: StreamFeatureView, - write_function: Callable[[pd.DataFrame, str, str], None], + preprocess_fn: Optional[MethodType] = None, ): """ - Returns a stream processor object based on the config mode and stream source type. The write function is a - function that wraps the feature store "write_to_online_store" capability. 
+ Returns a stream processor object based on the config. + + The returned object will be capable of launching an ingestion job that reads data from the + given stream feature view's stream source, transforms it if the stream feature view has a + transformation, and then writes it to the online store. It will also preprocess the data + if a preprocessor method is defined. """ if config.mode == "spark" and config.source == "kafka": stream_processor = STREAM_PROCESSOR_CLASS_FOR_TYPE[("spark", "kafka")] module_name, class_name = stream_processor.rsplit(".", 1) cls = import_class(module_name, class_name, "StreamProcessor") - return cls(sfv=sfv, config=config, write_function=write_function,) + return cls(fs=fs, sfv=sfv, config=config, preprocess_fn=preprocess_fn) else: raise ValueError("other processors besides spark-kafka not supported") diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 259a3af7d9..cb5b3a045a 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -329,8 +329,8 @@ def offline_write_batch( ) if column_names != table.column_names: raise ValueError( - f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. " - f"The columns are expected to be (in this order): {column_names}." + f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. " + f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}." ) if table.schema != pa_schema: diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 75968146de..10012c2d80 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -430,8 +430,8 @@ def offline_write_batch( ) if column_names != table.column_names: raise ValueError( - f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. " - f"The columns are expected to be (in this order): {column_names}." + f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. " + f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}." ) file_options = feature_view.batch_source.file_options diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 8667989268..5f071a814f 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -323,8 +323,8 @@ def offline_write_batch( ) if column_names != table.column_names: raise ValueError( - f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. " - f"The columns are expected to be (in this order): {column_names}." + f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. " + f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}." 
) if table.schema != pa_schema: diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index ec06d8dce1..a5befc33e2 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -332,8 +332,8 @@ def offline_write_batch( ) if column_names != table.column_names: raise ValueError( - f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. " - f"The columns are expected to be (in this order): {column_names}." + f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. " + f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}." ) if table.schema != pa_schema: diff --git a/sdk/python/tests/integration/e2e/test_python_feature_server.py b/sdk/python/tests/integration/e2e/test_python_feature_server.py index ea4c35a1ca..7195594d02 100644 --- a/sdk/python/tests/integration/e2e/test_python_feature_server.py +++ b/sdk/python/tests/integration/e2e/test_python_feature_server.py @@ -63,13 +63,16 @@ def test_get_online_features(python_fs_client): @pytest.mark.integration @pytest.mark.universal_online_stores def test_push(python_fs_client): - initial_temp = get_temperatures(python_fs_client, location_ids=[1])[0] + # TODO(felixwang9817): Note that we choose an entity value of 102 here since it is not included + # in the existing range of entity values (1-49). This allows us to push data for this test + # without affecting other tests. This decision is tech debt, and should be resolved by finding a + # better way to isolate data sources across tests. json_data = json.dumps( { "push_source_name": "location_stats_push_source", "df": { - "location_id": [1], - "temperature": [initial_temp * 100], + "location_id": [102], + "temperature": [4], "event_timestamp": [str(datetime.utcnow())], "created": [str(datetime.utcnow())], }, @@ -79,7 +82,7 @@ def test_push(python_fs_client): # Check new pushed temperature is fetched assert response.status_code == 200 - assert get_temperatures(python_fs_client, location_ids=[1]) == [initial_temp * 100] + assert get_temperatures(python_fs_client, location_ids=[102]) == [4] def get_temperatures(client, location_ids: List[int]): diff --git a/sdk/python/tests/integration/offline_store/test_offline_write.py b/sdk/python/tests/integration/offline_store/test_offline_write.py index 30ead98389..3335da0df7 100644 --- a/sdk/python/tests/integration/offline_store/test_offline_write.py +++ b/sdk/python/tests/integration/offline_store/test_offline_write.py @@ -9,69 +9,25 @@ from feast.types import Float32, Int32 from tests.integration.feature_repos.universal.entities import driver - -@pytest.mark.integration -@pytest.mark.universal_offline_stores(only=["file", "redshift"]) -@pytest.mark.universal_online_stores(only=["sqlite"]) -def test_writing_columns_in_incorrect_order_fails(environment, universal_data_sources): - # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in - store = environment.feature_store - _, _, data_sources = universal_data_sources - driver_stats = FeatureView( - name="driver_stats", - entities=["driver"], - schema=[ - Field(name="avg_daily_trips", dtype=Int32), - Field(name="conv_rate", dtype=Float32), - ], - source=data_sources.driver, - ) - - now = datetime.utcnow() - ts = pd.Timestamp(now).round("ms") - - entity_df = pd.DataFrame.from_dict( - {"driver_id": [1001, 1002], 
"event_timestamp": [ts - timedelta(hours=3), ts]} - ) - - store.apply([driver(), driver_stats]) - df = store.get_historical_features( - entity_df=entity_df, - features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"], - full_feature_names=False, - ).to_df() - - assert df["conv_rate"].isnull().all() - assert df["avg_daily_trips"].isnull().all() - - expected_df = pd.DataFrame.from_dict( - { - "driver_id": [1001, 1002], - "event_timestamp": [ts - timedelta(hours=3), ts], - "conv_rate": [random.random(), random.random()], - "avg_daily_trips": [random.randint(0, 10), random.randint(0, 10)], - "created": [ts, ts], - }, - ) - with pytest.raises(ValueError): - store._write_to_offline_store( - driver_stats.name, expected_df, allow_registry_cache=False - ) +# TODO(felixwang9817): Add a unit test that checks that write_to_offline_store can reorder columns. +# This should only happen after https://github.com/feast-dev/feast/issues/2797 is fixed. @pytest.mark.integration -@pytest.mark.universal_offline_stores(only=["file", "redshift"]) +@pytest.mark.universal_offline_stores @pytest.mark.universal_online_stores(only=["sqlite"]) def test_writing_incorrect_schema_fails(environment, universal_data_sources): - # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in + """Tests that writing a dataframe with an incorrect schema fails.""" store = environment.feature_store _, _, data_sources = universal_data_sources + driver_entity = driver() driver_stats = FeatureView( name="driver_stats", - entities=["driver"], + entities=[driver_entity], schema=[ Field(name="avg_daily_trips", dtype=Int32), Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), ], source=data_sources.driver, ) @@ -83,14 +39,19 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources): {"driver_id": [1001, 1002], "event_timestamp": [ts - timedelta(hours=3), ts]} ) - store.apply([driver(), driver_stats]) + store.apply([driver_entity, driver_stats]) df = store.get_historical_features( entity_df=entity_df, - features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"], + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], full_feature_names=False, ).to_df() assert df["conv_rate"].isnull().all() + assert df["acc_rate"].isnull().all() assert df["avg_daily_trips"].isnull().all() expected_df = pd.DataFrame.from_dict( @@ -103,7 +64,7 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources): }, ) with pytest.raises(ValueError): - store._write_to_offline_store( + store.write_to_offline_store( driver_stats.name, expected_df, allow_registry_cache=False ) @@ -114,9 +75,10 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources): def test_writing_consecutively_to_offline_store(environment, universal_data_sources): store = environment.feature_store _, _, data_sources = universal_data_sources + driver_entity = driver() driver_stats = FeatureView( name="driver_stats", - entities=["driver"], + entities=[driver_entity], schema=[ Field(name="avg_daily_trips", dtype=Int32), Field(name="conv_rate", dtype=Float32), @@ -138,14 +100,19 @@ def test_writing_consecutively_to_offline_store(environment, universal_data_sour } ) - store.apply([driver(), driver_stats]) + store.apply([driver_entity, driver_stats]) df = store.get_historical_features( entity_df=entity_df, - features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"], + features=[ + 
"driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], full_feature_names=False, ).to_df() assert df["conv_rate"].isnull().all() + assert df["acc_rate"].isnull().all() assert df["avg_daily_trips"].isnull().all() first_df = pd.DataFrame.from_dict( @@ -158,13 +125,17 @@ def test_writing_consecutively_to_offline_store(environment, universal_data_sour "created": [ts, ts], }, ) - store._write_to_offline_store( + store.write_to_offline_store( driver_stats.name, first_df, allow_registry_cache=False ) after_write_df = store.get_historical_features( entity_df=entity_df, - features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"], + features=[ + "driver_stats:conv_rate", + "driver_stats:acc_rate", + "driver_stats:avg_daily_trips", + ], full_feature_names=False, ).to_df() @@ -173,6 +144,10 @@ def test_writing_consecutively_to_offline_store(environment, universal_data_sour after_write_df["conv_rate"].reset_index(drop=True) == first_df["conv_rate"].reset_index(drop=True) ) + assert np.where( + after_write_df["acc_rate"].reset_index(drop=True) + == first_df["acc_rate"].reset_index(drop=True) + ) assert np.where( after_write_df["avg_daily_trips"].reset_index(drop=True) == first_df["avg_daily_trips"].reset_index(drop=True) @@ -189,7 +164,7 @@ def test_writing_consecutively_to_offline_store(environment, universal_data_sour }, ) - store._write_to_offline_store( + store.write_to_offline_store( driver_stats.name, second_df, allow_registry_cache=False ) diff --git a/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py b/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py index 5cea8a36ef..23bb0f98a7 100644 --- a/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py @@ -26,7 +26,7 @@ def test_push_features_and_read_from_offline_store(environment, universal_data_s now = pd.Timestamp(datetime.datetime.utcnow()).round("ms") store.apply([driver(), customer(), location(), *feature_views.values()]) - entity_df = pd.DataFrame.from_dict({"location_id": [1], "event_timestamp": [now]}) + entity_df = pd.DataFrame.from_dict({"location_id": [100], "event_timestamp": [now]}) before_df = store.get_historical_features( entity_df=entity_df, @@ -34,9 +34,13 @@ def test_push_features_and_read_from_offline_store(environment, universal_data_s full_feature_names=False, ).to_df() + # TODO(felixwang9817): Note that we choose an entity value of 100 here since it is not included + # in the existing range of entity values (1-49). This allows us to push data for this test + # without affecting other tests. This decision is tech debt, and should be resolved by finding a + # better way to isolate data sources across tests. 
data = { "event_timestamp": [now], - "location_id": [1], + "location_id": [100], "temperature": [4], "created": [now], } diff --git a/sdk/python/tests/integration/online_store/test_push_online_retrieval.py b/sdk/python/tests/integration/online_store/test_push_online_retrieval.py index aa7e3e7f53..436f87715f 100644 --- a/sdk/python/tests/integration/online_store/test_push_online_retrieval.py +++ b/sdk/python/tests/integration/online_store/test_push_online_retrieval.py @@ -22,8 +22,13 @@ def test_push_features_and_read(environment, universal_data_sources): feature_views = construct_universal_feature_views(data_sources) store.apply([driver(), customer(), location(), *feature_views.values()]) + + # TODO(felixwang9817): Note that we choose an entity value of 101 here since it is not included + # in the existing range of entity values (1-49). This allows us to push data for this test + # without affecting other tests. This decision is tech debt, and should be resolved by finding a + # better way to isolate data sources across tests. data = { - "location_id": [1], + "location_id": [101], "temperature": [4], "event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], "created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], @@ -34,8 +39,8 @@ def test_push_features_and_read(environment, universal_data_sources): online_resp = store.get_online_features( features=["pushable_location_stats:temperature"], - entity_rows=[{"location_id": 1}], + entity_rows=[{"location_id": 101}], ) online_resp_dict = online_resp.to_dict() - assert online_resp_dict["location_id"] == [1] + assert online_resp_dict["location_id"] == [101] assert online_resp_dict["temperature"] == [4]
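
For reference, a minimal usage sketch of how the reworked APIs in this patch fit together. Everything below is illustrative, not part of the patch: the repo path, the `driver_stats` feature view, the `driver_hourly_stats_stream` stream feature view, and the Spark/Kafka setup are assumed to exist in a hypothetical feature repo, and the dataframe columns are assumed to match that feature view's batch source exactly (which `write_to_offline_store` now requires).

```python
# Usage sketch only; names and infrastructure below are assumptions, not part of the patch.
from datetime import datetime

import pandas as pd
from pyspark.sql import SparkSession

from feast import FeatureStore
from feast.data_source import PushMode
from feast.infra.contrib.spark_kafka_processor import SparkProcessorConfig
from feast.infra.contrib.stream_processor import get_stream_processor_object

store = FeatureStore(repo_path=".")  # assumes an existing feature repo

# write_to_offline_store() is now public. It validates the dataframe columns against
# the feature view's batch source and, with reorder_columns=True (the default),
# reorders them to match before writing. The dataframe must contain exactly the
# batch source's columns (hypothetical schema shown here).
now = pd.Timestamp(datetime.utcnow()).round("ms")
df = pd.DataFrame(
    {
        "driver_id": [1001],
        "conv_rate": [0.85],
        "avg_daily_trips": [14],
        "event_timestamp": [now],
        "created": [now],
    }
)
store.write_to_offline_store("driver_stats", df, allow_registry_cache=False)


def preprocess_fn(rows: pd.DataFrame) -> pd.DataFrame:
    # Optional hook run on each micro-batch before it is written to the stores.
    return rows


# The stream processor is now constructed with the feature store and an optional
# preprocess_fn instead of a write_function; processing_time and query_timeout come
# from the config. Assumes the stream feature view is already registered.
processor = get_stream_processor_object(
    config=SparkProcessorConfig(
        mode="spark",
        source="kafka",
        spark_session=SparkSession.builder.getOrCreate(),
        processing_time="30 seconds",
        query_timeout=15,
    ),
    fs=store,
    sfv=store.get_stream_feature_view("driver_hourly_stats_stream"),
    preprocess_fn=preprocess_fn,
)

# Ingestion can now target the online store, the offline store, or both.
processor.ingest_stream_feature_view(to=PushMode.ONLINE_AND_OFFLINE)
```

With this shape, the processor no longer wraps a bare write function: each micro-batch is deduplicated per entity key, optionally passed through `preprocess_fn`, and then routed through the store's public `write_to_online_store` / `write_to_offline_store` methods according to the `PushMode`.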