feat: Add column reordering to write_to_offline_store (#2876)
* Add feature extraction logic to batch writer

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Enable StreamProcessor to write to both online and offline stores

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Fix incorrect columns error message

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Reorder columns in _write_to_offline_store

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Make _write_to_offline_store a public method

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Import FeatureStore correctly

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Remove defaults for `processing_time` and `query_timeout`

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Clean up `test_offline_write.py`

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Do not do any custom logic for double underscore columns

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Lint

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Switch entity values for all tests using push sources to not affect other tests

Signed-off-by: Felix Wang <wangfelix98@gmail.com>
felixwang9817 authored Jun 30, 2022
1 parent 51df8be commit 8abc2ef
Showing 11 changed files with 153 additions and 109 deletions.
26 changes: 22 additions & 4 deletions sdk/python/feast/feature_store.py
@@ -1383,7 +1383,7 @@ def push(
fv.name, df, allow_registry_cache=allow_registry_cache
)
if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE:
self._write_to_offline_store(
self.write_to_offline_store(
fv.name, df, allow_registry_cache=allow_registry_cache
)

@@ -1415,14 +1415,18 @@ def write_to_online_store(
provider.ingest_df(feature_view, entities, df)

@log_exceptions_and_usage
def _write_to_offline_store(
def write_to_offline_store(
self,
feature_view_name: str,
df: pd.DataFrame,
allow_registry_cache: bool = True,
reorder_columns: bool = True,
):
"""
ingests data directly into the Online store
Persists the dataframe directly into the batch data source for the given feature view.
Fails if the dataframe columns do not match the columns of the batch data source. Optionally
reorders the columns of the dataframe to match.
"""
# TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
try:
@@ -1433,7 +1437,21 @@ def _write_to_offline_store(
feature_view = self.get_feature_view(
feature_view_name, allow_registry_cache=allow_registry_cache
)
df.reset_index(drop=True)

# Get columns of the batch source and the input dataframe.
column_names_and_types = feature_view.batch_source.get_table_column_names_and_types(
self.config
)
source_columns = [column for column, _ in column_names_and_types]
input_columns = df.columns.values.tolist()

if set(input_columns) != set(source_columns):
raise ValueError(
f"The input dataframe has columns {set(input_columns)} but the batch source has columns {set(source_columns)}."
)

if reorder_columns:
df = df.reindex(columns=source_columns)

table = pa.Table.from_pandas(df)
provider = self._get_provider()
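For context, a minimal usage sketch of the new public `write_to_offline_store` API (the feature view name, columns, and values below are hypothetical, not taken from this commit): the dataframe may arrive with its columns in any order, and the default `reorder_columns=True` reindexes them to the batch source's column order before writing.

```python
from datetime import datetime

import pandas as pd

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes an existing, configured feature repo

# Hypothetical dataframe whose columns are shuffled relative to the batch source schema.
df = pd.DataFrame(
    {
        "temperature": [4],
        "event_timestamp": [datetime.utcnow()],
        "created": [datetime.utcnow()],
        "location_id": [102],
    }
)

# The columns are validated against the batch source; since reorder_columns defaults to
# True, the dataframe is reindexed into the source's column order before being persisted.
store.write_to_offline_store("location_stats", df, reorder_columns=True)
```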
57 changes: 41 additions & 16 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,12 +1,14 @@
from types import MethodType
from typing import List
from typing import List, Optional

import pandas as pd
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.functions import col, from_json

from feast.data_format import AvroFormat, JsonFormat
from feast.data_source import KafkaSource
from feast.data_source import KafkaSource, PushMode
from feast.feature_store import FeatureStore
from feast.infra.contrib.stream_processor import (
ProcessorConfig,
StreamProcessor,
@@ -24,16 +26,16 @@ class SparkProcessorConfig(ProcessorConfig):
class SparkKafkaProcessor(StreamProcessor):
spark: SparkSession
format: str
write_function: MethodType
preprocess_fn: Optional[MethodType]
join_keys: List[str]

def __init__(
self,
*,
fs: FeatureStore,
sfv: StreamFeatureView,
config: ProcessorConfig,
write_function: MethodType,
processing_time: str = "30 seconds",
query_timeout: int = 15,
preprocess_fn: Optional[MethodType] = None,
):
if not isinstance(sfv.stream_source, KafkaSource):
raise ValueError("data source is not kafka source")
@@ -55,15 +57,16 @@ def __init__(
if not isinstance(config, SparkProcessorConfig):
raise ValueError("config is not spark processor config")
self.spark = config.spark_session
self.write_function = write_function
self.processing_time = processing_time
self.query_timeout = query_timeout
super().__init__(sfv=sfv, data_source=sfv.stream_source)
self.preprocess_fn = preprocess_fn
self.processing_time = config.processing_time
self.query_timeout = config.query_timeout
self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
super().__init__(fs=fs, sfv=sfv, data_source=sfv.stream_source)

def ingest_stream_feature_view(self) -> None:
def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
ingested_stream_df = self._ingest_stream_data()
transformed_df = self._construct_transformation_plan(ingested_stream_df)
online_store_query = self._write_to_online_store(transformed_df)
online_store_query = self._write_stream_data(transformed_df, to)
return online_store_query

def _ingest_stream_data(self) -> StreamTable:
@@ -119,13 +122,35 @@ def _ingest_stream_data(self) -> StreamTable:
def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
return self.sfv.udf.__call__(df) if self.sfv.udf else df

def _write_to_online_store(self, df: StreamTable):
def _write_stream_data(self, df: StreamTable, to: PushMode):
# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def batch_write(row: DataFrame, batch_id: int):
pd_row = row.toPandas()
self.write_function(
pd_row, input_timestamp="event_timestamp", output_timestamp=""
rows: pd.DataFrame = row.toPandas()

# Extract the latest feature values for each unique entity row (i.e. the join keys).
# Also add a 'created' column.
rows = (
rows.sort_values(
by=self.join_keys + [self.sfv.timestamp_field], ascending=True
)
.groupby(self.join_keys)
.nth(0)
)
rows["created"] = pd.to_datetime("now", utc=True)

# Reset indices to ensure the dataframe has all the required columns.
rows = rows.reset_index()

# Optionally execute preprocessor before writing to the online store.
if self.preprocess_fn:
rows = self.preprocess_fn(rows)

# Finally persist the data to the online store and/or offline store.
if rows.size > 0:
if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE:
self.fs.write_to_online_store(self.sfv.name, rows)
if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE:
self.fs.write_to_offline_store(self.sfv.name, rows)

query = (
df.writeStream.outputMode("update")
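The `batch_write` helper above reduces each micro-batch to a single row per set of join keys, stamps a `created` column, and only then writes. A standalone pandas sketch of that per-entity deduplication idea, with hypothetical column names and data; here an ascending sort followed by `tail(1)` keeps the newest row per entity, whereas the processor expresses the same reduction with `groupby(...).nth(0)`.

```python
import pandas as pd

join_keys = ["driver_id"]            # hypothetical entity join keys
timestamp_field = "event_timestamp"  # hypothetical timestamp field

rows = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "event_timestamp": pd.to_datetime(
            ["2022-06-30 10:00", "2022-06-30 11:00", "2022-06-30 10:30"]
        ),
        "conv_rate": [0.1, 0.2, 0.5],
    }
)

# Sort oldest-to-newest within each entity, then keep the last (newest) row per group.
latest = rows.sort_values(by=join_keys + [timestamp_field]).groupby(join_keys).tail(1)

# Mirror the processor: add a 'created' column before handing the rows to the store.
latest = latest.copy()
latest["created"] = pd.to_datetime("now", utc=True)
```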
38 changes: 26 additions & 12 deletions sdk/python/feast/infra/contrib/stream_processor.py
@@ -1,14 +1,17 @@
from abc import ABC
from typing import Callable
from types import MethodType
from typing import TYPE_CHECKING, Optional

import pandas as pd
from pyspark.sql import DataFrame

from feast.data_source import DataSource
from feast.data_source import DataSource, PushMode
from feast.importer import import_class
from feast.repo_config import FeastConfigBaseModel
from feast.stream_feature_view import StreamFeatureView

if TYPE_CHECKING:
from feast.feature_store import FeatureStore

STREAM_PROCESSOR_CLASS_FOR_TYPE = {
("spark", "kafka"): "feast.infra.contrib.spark_kafka_processor.SparkKafkaProcessor",
}
@@ -30,21 +33,26 @@ class StreamProcessor(ABC):
and persist that data to the online store.
Attributes:
fs: The feature store where data should be persisted.
sfv: The stream feature view on which the stream processor operates.
data_source: The stream data source from which data will be ingested.
"""

fs: "FeatureStore"
sfv: StreamFeatureView
data_source: DataSource

def __init__(self, sfv: StreamFeatureView, data_source: DataSource):
def __init__(
self, fs: "FeatureStore", sfv: StreamFeatureView, data_source: DataSource
):
self.fs = fs
self.sfv = sfv
self.data_source = data_source

def ingest_stream_feature_view(self) -> None:
def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
"""
Ingests data from the stream source attached to the stream feature view; transforms the data
and then persists it to the online store.
and then persists it to the online store and/or offline store, depending on the 'to' parameter.
"""
pass

@@ -62,26 +70,32 @@ def _construct_transformation_plan(self, table: StreamTable) -> StreamTable:
"""
pass

def _write_to_online_store(self, table: StreamTable) -> None:
def _write_stream_data(self, table: StreamTable, to: PushMode) -> None:
"""
Returns query for persisting data to the online store.
Launches a job to persist stream data to the online store and/or offline store, depending
on the 'to' parameter, and returns a handle for the job.
"""
pass


def get_stream_processor_object(
config: ProcessorConfig,
fs: "FeatureStore",
sfv: StreamFeatureView,
write_function: Callable[[pd.DataFrame, str, str], None],
preprocess_fn: Optional[MethodType] = None,
):
"""
Returns a stream processor object based on the config mode and stream source type. The write function is a
function that wraps the feature store "write_to_online_store" capability.
Returns a stream processor object based on the config.
The returned object will be capable of launching an ingestion job that reads data from the
given stream feature view's stream source, transforms it if the stream feature view has a
transformation, and then writes it to the online store. It will also preprocess the data
if a preprocessor method is defined.
"""
if config.mode == "spark" and config.source == "kafka":
stream_processor = STREAM_PROCESSOR_CLASS_FOR_TYPE[("spark", "kafka")]
module_name, class_name = stream_processor.rsplit(".", 1)
cls = import_class(module_name, class_name, "StreamProcessor")
return cls(sfv=sfv, config=config, write_function=write_function,)
return cls(fs=fs, sfv=sfv, config=config, preprocess_fn=preprocess_fn)
else:
raise ValueError("other processors besides spark-kafka not supported")
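A sketch of how the reworked constructor arguments might be used end to end (the feature view name, preprocessor, and config values are hypothetical, and the sketch assumes `SparkProcessorConfig` accepts these fields as keyword arguments): the caller now passes the `FeatureStore` and an optional `preprocess_fn` instead of a `write_function`, and `processing_time`/`query_timeout` live on the config.

```python
from pyspark.sql import SparkSession

from feast import FeatureStore
from feast.data_source import PushMode
from feast.infra.contrib.spark_kafka_processor import SparkProcessorConfig
from feast.infra.contrib.stream_processor import get_stream_processor_object

store = FeatureStore(repo_path=".")  # assumes an existing, configured feature repo
sfv = store.get_stream_feature_view("driver_hourly_stats_stream")  # hypothetical view


def preprocess_fn(rows):
    # Hypothetical preprocessor applied to each micro-batch before it is written.
    return rows.dropna()


config = SparkProcessorConfig(
    mode="spark",
    source="kafka",
    spark_session=SparkSession.builder.getOrCreate(),
    processing_time="30 seconds",  # no longer defaulted; must be set on the config
    query_timeout=15,
)

processor = get_stream_processor_object(
    config=config,
    fs=store,
    sfv=sfv,
    preprocess_fn=preprocess_fn,
)

# Launch the streaming query, writing each micro-batch to both stores.
processor.ingest_stream_feature_view(PushMode.ONLINE_AND_OFFLINE)
```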
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -329,8 +329,8 @@ def offline_write_batch(
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
)

if table.schema != pa_schema:
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -430,8 +430,8 @@ def offline_write_batch(
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
)

file_options = feature_view.batch_source.file_options
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -323,8 +323,8 @@ def offline_write_batch(
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
)

if table.schema != pa_schema:
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/snowflake.py
@@ -332,8 +332,8 @@ def offline_write_batch(
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
)

if table.schema != pa_schema:
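All four offline stores above raise the improved error whenever the incoming Arrow table's columns do not match the batch source schema in the exact expected order, which is why `write_to_offline_store` reorders the dataframe first. A minimal sketch of the check being performed (the schema and data are hypothetical):

```python
import pandas as pd
import pyarrow as pa

# Hypothetical batch source schema, in its canonical column order.
pa_schema = pa.schema([("location_id", pa.int64()), ("temperature", pa.int64())])
column_names = pa_schema.names

# A dataframe with the right columns in the wrong order.
df = pd.DataFrame({"temperature": [4], "location_id": [102]})
table = pa.Table.from_pandas(df)

# pa.Table.from_pandas preserves the dataframe's column order, so this check fails
# unless the dataframe was reindexed to the source's column order beforehand.
if column_names != table.column_names:
    raise ValueError(
        f"The input pyarrow table has schema {table.schema} with the incorrect columns "
        f"{table.column_names}. The schema is expected to be {pa_schema} with the columns "
        f"(in this exact order) to be {column_names}."
    )
```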
11 changes: 7 additions & 4 deletions sdk/python/tests/integration/e2e/test_python_feature_server.py
@@ -63,13 +63,16 @@ def test_get_online_features(python_fs_client):
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_push(python_fs_client):
initial_temp = get_temperatures(python_fs_client, location_ids=[1])[0]
# TODO(felixwang9817): Note that we choose an entity value of 102 here since it is not included
# in the existing range of entity values (1-49). This allows us to push data for this test
# without affecting other tests. This decision is tech debt, and should be resolved by finding a
# better way to isolate data sources across tests.
json_data = json.dumps(
{
"push_source_name": "location_stats_push_source",
"df": {
"location_id": [1],
"temperature": [initial_temp * 100],
"location_id": [102],
"temperature": [4],
"event_timestamp": [str(datetime.utcnow())],
"created": [str(datetime.utcnow())],
},
Expand All @@ -79,7 +82,7 @@ def test_push(python_fs_client):

# Check new pushed temperature is fetched
assert response.status_code == 200
assert get_temperatures(python_fs_client, location_ids=[1]) == [initial_temp * 100]
assert get_temperatures(python_fs_client, location_ids=[102]) == [4]


def get_temperatures(client, location_ids: List[int]):