Skip to content

Commit

Permalink
fix: Redshift push ignores schema (feast-dev#3671)
Browse files Browse the repository at this point in the history
* Add fully-qualified-table-name Redshift prop

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* pre-commit

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* Docstring

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* Test fully_qualified_table_name

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* Simplify logic

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* pre-commit

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* pre-commit

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* Test offline_write_batch

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* Bump to trigger CI

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

* another bump for ci

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>

---------

Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>
  • Loading branch information
metavee authored Jul 24, 2023
1 parent 9527183 commit 76270f6
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def offline_write_batch(
s3_resource=s3_resource,
s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
iam_role=config.offline_store.iam_role,
table_name=redshift_options.table,
table_name=redshift_options.fully_qualified_table_name,
schema=pa_schema,
fail_if_exists=False,
)
Expand Down
37 changes: 36 additions & 1 deletion sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,42 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):

return redshift_options

@property
def fully_qualified_table_name(self) -> str:
    """
    The fully qualified table name of this Redshift table.

    ``self.table`` may itself carry a schema, or a database and a schema, as
    dot-separated prefixes. Any component it does not supply is taken from
    ``self.database`` / ``self.schema``.

    Returns:
        A string in the format of <database>.<schema>.<table>,
        <schema>.<table> when no database is known, or just <table> when no
        schema is known. An empty string if the table is not set.

    Raises:
        ValueError: If the table name has more than three dot-separated
            parts, or if any of its parts is empty.
    """
    if not self.table:
        return ""

    # self.table may already contain the database and schema
    parts = self.table.split(".")
    if len(parts) > 3:
        raise ValueError(
            f"Invalid table name: {self.table} - can't determine database and schema"
        )
    if not all(parts):
        # Names such as "schema..table" or "schema." would otherwise be
        # silently truncated or produce a malformed identifier.
        raise ValueError(
            f"Invalid table name: {self.table} - contains an empty name component"
        )

    # Read the components right-to-left; anything the table string does not
    # provide falls back to the separately configured attribute.
    table = parts[-1]
    schema = parts[-2] if len(parts) >= 2 else self.schema
    database = parts[-3] if len(parts) == 3 else self.database

    if database and schema:
        return f"{database}.{schema}.{table}"
    if schema:
        return f"{schema}.{table}"
    # Without a schema the database cannot be used as a qualifier, so it is
    # intentionally left out.
    return table

def to_proto(self) -> DataSourceProto.RedshiftOptions:
"""
Converts an RedshiftOptionsProto object to its protobuf representation.
Expand Down Expand Up @@ -323,7 +359,6 @@ def __init__(self, table_ref: str):

@staticmethod
def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
    """
    Builds a SavedDatasetRedshiftStorage from its protobuf representation.

    Args:
        storage_proto: The saved dataset storage proto whose
            ``redshift_storage`` field holds the Redshift options.

    Returns:
        A SavedDatasetRedshiftStorage whose table reference is the table
        recorded in those options.
    """
    redshift_options = RedshiftOptions.from_proto(storage_proto.redshift_storage)
    return SavedDatasetRedshiftStorage(table_ref=redshift_options.table)
Expand Down
67 changes: 67 additions & 0 deletions sdk/python/tests/unit/infra/offline_stores/test_redshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from unittest.mock import MagicMock, patch

import pandas as pd
import pyarrow as pa

from feast import FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.redshift import (
RedshiftOfflineStore,
RedshiftOfflineStoreConfig,
)
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.utils import aws_utils
from feast.repo_config import RepoConfig


@patch.object(aws_utils, "upload_arrow_table_to_redshift")
def test_offline_write_batch(
    mock_upload_arrow_table_to_redshift: MagicMock,
    simple_dataset_1: pd.DataFrame,
):
    """
    Check that offline_write_batch uploads to the schema-qualified table.

    The Redshift upload helper is mocked out, so the test only inspects the
    keyword arguments it was called with.
    """
    repo_config = RepoConfig(
        registry="registry",
        project="project",
        provider="local",
        offline_store=RedshiftOfflineStoreConfig(
            type="redshift",
            region="us-west-2",
            cluster_id="cluster_id",
            database="database",
            user="user",
            iam_role="abcdef",
            s3_staging_location="s3://bucket/path",
        ),
    )

    batch_source = RedshiftSource(
        name="test_source",
        timestamp_field="ts",
        table="table_name",
        schema="schema_name",
    )
    feature_view = FeatureView(
        name="test_view",
        source=batch_source,
    )

    pa_dataset = pa.Table.from_pandas(simple_dataset_1)

    # Stub the schema lookup so that the function can run. The stub returns a
    # (schema, column names) pair taken from the dataset being written.
    with patch.object(
        offline_utils,
        "get_pyarrow_schema_from_batch_source",
        return_value=(pa_dataset.schema, pa_dataset.column_names),
    ):
        RedshiftOfflineStore.offline_write_batch(
            repo_config, feature_view, pa_dataset, progress=None
        )

    # check that we have included the fully qualified table name
    mock_upload_arrow_table_to_redshift.assert_called_once()

    upload_kwargs = mock_upload_arrow_table_to_redshift.call_args.kwargs
    assert upload_kwargs["table_name"] == "schema_name.table_name"
43 changes: 43 additions & 0 deletions sdk/python/tests/unit/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,46 @@ def test_column_conflict():
timestamp_field="event_timestamp",
created_timestamp_column="event_timestamp",
)


@pytest.mark.parametrize(
    "source_kwargs,expected_name",
    [
        # database, schema and table all given separately
        (
            dict(database="test_database", schema="test_schema", table="test_table"),
            "test_database.test_schema.test_table",
        ),
        # schema omitted
        (
            dict(database="test_database", table="test_table"),
            "test_database.public.test_table",
        ),
        # database and schema omitted
        (dict(table="test_table"), "public.test_table"),
        # table string already carries a schema
        (dict(database="test_database", table="b.c"), "test_database.b.c"),
        # table string already carries a database and a schema
        (dict(database="test_database", table="a.b.c"), "a.b.c"),
        # query-based source: no table, so no qualified name
        (
            dict(
                database="test_database",
                schema="test_schema",
                query="select * from abc",
            ),
            "",
        ),
    ],
)
def test_redshift_fully_qualified_table_name(source_kwargs, expected_name):
    """Check fully_qualified_table_name for each database/schema/table combination."""
    common_kwargs = dict(
        name="test_source",
        timestamp_field="event_timestamp",
        created_timestamp_column="created_timestamp",
        field_mapping={"foo": "bar"},
        description="test description",
        tags={"test": "test"},
        owner="test@gmail.com",
    )
    redshift_source = RedshiftSource(**common_kwargs, **source_kwargs)

    actual_name = redshift_source.redshift_options.fully_qualified_table_name
    assert actual_name == expected_name

0 comments on commit 76270f6

Please sign in to comment.