feat: Implement offline_write_batch for BigQuery and Snowflake #2840

Merged: 6 commits, Jun 22, 2022 (diff shows changes from 5 commits)
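For context, a minimal usage sketch of the API this PR adds: BigQueryOfflineStore.offline_write_batch (and the equivalent SnowflakeOfflineStore method) is a static method with the signature shown in the diffs below. The repo path, feature view name, and column values here are illustrative placeholders, not part of this PR; the input table's columns must match the feature view's batch source schema, in the same order.

from datetime import datetime

import pyarrow as pa

from feast import FeatureStore
from feast.infra.offline_stores.bigquery import BigQueryOfflineStore

store = FeatureStore(repo_path=".")  # hypothetical feature repo
feature_view = store.get_feature_view("driver_hourly_stats")  # hypothetical view

# Column names, order, and types must match the view's BigQuery batch source schema.
table = pa.Table.from_pydict(
    {
        "driver_id": [1001],
        "event_timestamp": [datetime.utcnow()],
        "created": [datetime.utcnow()],
        "conv_rate": [0.85],
    }
)

BigQueryOfflineStore.offline_write_batch(
    config=store.config,
    feature_view=feature_view,
    table=table,
    progress=None,  # optional callback invoked as rows are written
)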
53 changes: 53 additions & 0 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -4,6 +4,7 @@
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
@@ -303,6 +304,58 @@ def write_logged_features(
job_config=job_config,
)

@staticmethod
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, BigQueryOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when bigquery type required"
)
if not isinstance(feature_view.batch_source, BigQuerySource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
)

pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

client = _get_bigquery_client(
project=config.offline_store.project_id,
location=config.offline_store.location,
)

job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
schema=arrow_schema_to_bq_schema(pa_schema),
write_disposition="WRITE_APPEND", # Default but included for clarity
)

with tempfile.TemporaryFile() as parquet_temp_file:
pyarrow.parquet.write_table(table=table, where=parquet_temp_file)

parquet_temp_file.seek(0)

client.load_table_from_file(
file_obj=parquet_temp_file,
destination=feature_view.batch_source.table,
job_config=job_config,
)


class BigQueryRetrievalJob(RetrievalJob):
def __init__(
26 changes: 16 additions & 10 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -27,6 +27,7 @@
)
from feast.infra.offline_stores.offline_utils import (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
get_pyarrow_schema,
)
from feast.infra.provider import (
_get_requested_feature_views_to_features_dict,
@@ -408,7 +409,7 @@ def write_logged_features(
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
data: pyarrow.Table,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
@@ -423,20 +424,25 @@ def offline_write_batch(
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not file source"
)

pa_schema, column_names = get_pyarrow_schema(config, feature_view)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

file_options = feature_view.batch_source.file_options
filesystem, path = FileSource.create_filesystem_and_path(
file_options.uri, file_options.s3_endpoint_override
)

prev_table = pyarrow.parquet.read_table(path, memory_map=True)
if prev_table.column_names != data.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
)
if data.schema != prev_table.schema:
data = data.cast(prev_table.schema)
new_table = pyarrow.concat_tables([data, prev_table])
writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
if table.schema != prev_table.schema:
table = table.cast(prev_table.schema)
new_table = pyarrow.concat_tables([table, prev_table])
writer = pyarrow.parquet.ParquetWriter(
path, table.schema, filesystem=filesystem
)
writer.write_table(new_table)
writer.close()

6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/offline_store.py
@@ -275,7 +275,7 @@ def write_logged_features(
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
data: pyarrow.Table,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
"""
@@ -286,8 +286,8 @@ def offline_write_batch(

Args:
config: Repo configuration object
table: FeatureView to write the data to.
data: pyarrow table containing feature data and timestamp column for historical feature retrieval
feature_view: FeatureView to write the data to.
table: pyarrow table containing feature data and timestamp column for historical feature retrieval
progress: Optional function to be called once every mini-batch of rows is written to
the offline store. Can be used to display progress.
"""
29 changes: 29 additions & 0 deletions sdk/python/feast/infra/offline_stores/offline_utils.py
@@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
from pandas import Timestamp

@@ -17,6 +18,8 @@
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.provider import _get_requested_feature_views_to_features_dict
from feast.registry import BaseRegistry
from feast.repo_config import RepoConfig
from feast.type_map import feast_value_type_to_pa
from feast.utils import to_naive_utc

DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp"
@@ -217,3 +220,29 @@ def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore:
class_name = qualified_name.replace("Config", "")
offline_store_class = import_class(module_name, class_name, "OfflineStore")
return offline_store_class()


def get_pyarrow_schema(
Collaborator: could you rename this to get_feature_view_batch_source_pyarrow_schema or something more specific.

Collaborator (Author): done
config: RepoConfig, feature_view: FeatureView
) -> Tuple[pa.Schema, List[str]]:
"""Returns the pyarrow schema and column names for the specified feature view's batch source."""
column_names_and_types = feature_view.batch_source.get_table_column_names_and_types(
config
)

pa_schema = []
column_names = []
for column_name, column_type in column_names_and_types:
pa_schema.append(
(
column_name,
feast_value_type_to_pa(
feature_view.batch_source.source_datatype_to_feast_value_type()(
column_type
)
),
)
)
column_names.append(column_name)

return pa.schema(pa_schema), column_names
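A quick sketch of what the new helper returns, for orientation only; the review comment above asks for a more specific name (get_feature_view_batch_source_pyarrow_schema), so the merged name may differ. The repo path and feature view name are placeholders.

from feast import FeatureStore
from feast.infra.offline_stores import offline_utils

store = FeatureStore(repo_path=".")  # hypothetical feature repo
fv = store.get_feature_view("driver_hourly_stats")  # hypothetical view

# The schema and column names are derived from the view's batch source, so callers
# can validate and cast an incoming pyarrow table before writing it offline.
pa_schema, column_names = offline_utils.get_pyarrow_schema(store.config, fv)
print(pa_schema)     # pyarrow.Schema built from the source's column types
print(column_names)  # column names in batch-source order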
27 changes: 7 additions & 20 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -42,7 +42,6 @@
from feast.registry import BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
from feast.usage import log_exceptions_and_usage


@@ -318,33 +317,21 @@ def offline_write_batch(
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
)
redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)

column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
config
)
pa_schema_list = []
column_names = []
for column_name, redshift_type in column_name_to_type:
pa_schema_list.append(
(
column_name,
feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
)
)
column_names.append(column_name)
pa_schema = pa.schema(pa_schema_list)
pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)
if column_names != table.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

aws_utils.upload_arrow_table_to_redshift(
40 changes: 40 additions & 0 deletions sdk/python/feast/infra/offline_stores/snowflake.py
@@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
@@ -306,6 +307,45 @@ def write_logged_features(
auto_create_table=True,
)

@staticmethod
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, SnowflakeOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when snowflake type required"
)
if not isinstance(feature_view.batch_source, SnowflakeSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
)

pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

snowflake_conn = get_snowflake_conn(config.offline_store)

write_pandas(
snowflake_conn,
table.to_pandas(),
table_name=feature_view.batch_source.table,
auto_create_table=True,
)


class SnowflakeRetrievalJob(RetrievalJob):
def __init__(
@@ -76,7 +76,7 @@

OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
"file": ("local", FileDataSourceCreator),
"gcp": ("gcp", BigQueryDataSourceCreator),
"bigquery": ("gcp", BigQueryDataSourceCreator),
"redshift": ("aws", RedshiftDataSourceCreator),
"snowflake": ("aws", RedshiftDataSourceCreator),
}
@@ -109,7 +109,7 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_offline_stores
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
store = environment.feature_store
@@ -16,7 +16,7 @@


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_offline_stores
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_push_features_and_read_from_offline_store(environment, universal_data_sources):
store = environment.feature_store