# feat: Add AWS Redshift Serverless support #3595

Merged · 5 commits · Apr 21, 2023
22 changes: 22 additions & 0 deletions docs/reference/offline-stores/redshift.md
@@ -155,3 +155,25 @@ While the following trust relationship is necessary to make sure that Redshift,
]
}
```


## Redshift Serverless

To use [AWS Redshift Serverless](https://aws.amazon.com/redshift/redshift-serverless/), specify a `workgroup` instead of a `cluster_id` and `user`.

{% code title="feature_store.yaml" %}
```yaml
project: my_feature_repo
registry: data/registry.db
provider: aws
offline_store:
type: redshift
region: us-west-2
workgroup: feast-workgroup
database: feast-database
s3_staging_location: s3://feast-bucket/redshift
iam_role: arn:aws:iam::123456789012:role/redshift_s3_access_role
```
{% endcode %}

Please note that the IAM policies above will need the [redshift-serverless](https://aws.permissions.cloud/iam/redshift-serverless) versions of the actions, rather than the standard [redshift](https://aws.permissions.cloud/iam/redshift) ones.
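
Under the hood, both modes go through the Redshift Data API; the only difference is how the target is identified. A minimal sketch of the distinction using boto3 directly (the identifiers below are examples, not values from this PR):

```python
import boto3

client = boto3.client("redshift-data", region_name="us-west-2")

# Provisioned cluster: identify the cluster and the database user.
client.execute_statement(
    ClusterIdentifier="my-cluster",  # example identifier
    DbUser="admin",                  # example user
    Database="feast-database",
    Sql="SELECT 1",
)

# Serverless: a workgroup name replaces both cluster_id and user.
client.execute_statement(
    WorkgroupName="feast-workgroup",
    Database="feast-database",
    Sql="SELECT 1",
)
```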
49 changes: 42 additions & 7 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -19,7 +19,7 @@
import pyarrow
import pyarrow as pa
from dateutil import parser
-from pydantic import StrictStr
+from pydantic import StrictStr, root_validator
from pydantic.typing import Literal
from pytz import utc

@@ -51,15 +51,18 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["redshift"] = "redshift"
""" Offline store type selector"""

-cluster_id: StrictStr
-""" Redshift cluster identifier """
+cluster_id: Optional[StrictStr]
+""" Redshift cluster identifier, for provisioned clusters """
+
+user: Optional[StrictStr]
+""" Redshift user name, only required for provisioned clusters """
+
+workgroup: Optional[StrictStr]
+""" Redshift workgroup identifier, for serverless """

region: StrictStr
""" Redshift cluster's AWS region """

-user: StrictStr
-""" Redshift user name """

database: StrictStr
""" Redshift database name """

@@ -69,6 +72,26 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
iam_role: StrictStr
""" IAM Role for Redshift, granting it access to S3 """

+@root_validator
+def require_cluster_and_user_or_workgroup(cls, values):
+    """
+    Provisioned Redshift clusters: Require cluster_id and user, ignore workgroup
+    Serverless Redshift: Require workgroup, ignore cluster_id and user
+    """
+    cluster_id, user, workgroup = (
+        values.get("cluster_id"),
+        values.get("user"),
+        values.get("workgroup"),
+    )
+    if not (cluster_id and user) and not workgroup:
+        raise ValueError(
+            "please specify either cluster_id & user if using provisioned clusters, or workgroup if using serverless"
+        )
+    elif cluster_id and workgroup:
+        raise ValueError("cannot specify both cluster_id and workgroup")
+
+    return values


class RedshiftOfflineStore(OfflineStore):
@staticmethod
@@ -248,6 +271,7 @@ def query_generator() -> Iterator[str]:
aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
+config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
f"DROP TABLE IF EXISTS {table_name}",
@@ -294,6 +318,7 @@ def write_logged_features(
table=data,
redshift_data_client=redshift_client,
cluster_id=config.offline_store.cluster_id,
+workgroup=config.offline_store.workgroup,
database=config.offline_store.database,
user=config.offline_store.user,
s3_resource=s3_resource,
@@ -336,8 +361,10 @@ def offline_write_batch(
table=table,
redshift_data_client=redshift_client,
cluster_id=config.offline_store.cluster_id,
+workgroup=config.offline_store.workgroup,
database=redshift_options.database
-or config.offline_store.database,  # Users can define database in the source if needed but it's not required.
+# Users can define database in the source if needed but it's not required.
+or config.offline_store.database,
user=config.offline_store.user,
s3_resource=s3_resource,
s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
@@ -405,6 +432,7 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
return aws_utils.unload_redshift_query_to_df(
self._redshift_client,
self._config.offline_store.cluster_id,
+self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_resource,
@@ -419,6 +447,7 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
return aws_utils.unload_redshift_query_to_pa(
self._redshift_client,
self._config.offline_store.cluster_id,
+self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_resource,
@@ -439,6 +468,7 @@ def to_s3(self) -> str:
aws_utils.execute_redshift_query_and_unload_to_s3(
self._redshift_client,
self._config.offline_store.cluster_id,
+self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_path,
@@ -455,6 +485,7 @@ def to_redshift(self, table_name: str) -> None:
aws_utils.upload_df_to_redshift(
self._redshift_client,
self._config.offline_store.cluster_id,
+self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_resource,
@@ -471,6 +502,7 @@ def to_redshift(self, table_name: str) -> None:
aws_utils.execute_redshift_statement(
self._redshift_client,
self._config.offline_store.cluster_id,
+self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
query,
@@ -509,6 +541,7 @@ def _upload_entity_df(
aws_utils.upload_df_to_redshift(
redshift_client,
config.offline_store.cluster_id,
+config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
s3_resource,
@@ -522,6 +555,7 @@ def _upload_entity_df(
aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
+config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
f"CREATE TABLE {table_name} AS ({entity_df})",
@@ -577,6 +611,7 @@ def _get_entity_df_event_timestamp_range(
statement_id = aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
+config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max "
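As a usage note for the root_validator added above: exactly one of the two identification modes must be supplied. A quick sketch of the resulting behavior (field values are examples):

```python
from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

common = dict(
    region="us-west-2",
    database="feast-database",
    s3_staging_location="s3://feast-bucket/redshift",
    iam_role="arn:aws:iam::123456789012:role/redshift_s3_access_role",
)

# Serverless: a workgroup alone passes validation.
RedshiftOfflineStoreConfig(workgroup="feast-workgroup", **common)

# Provisioned: cluster_id and user must be supplied together.
RedshiftOfflineStoreConfig(cluster_id="my-cluster", user="admin", **common)

# Neither mode, or both cluster_id and workgroup, raises a pydantic
# ValidationError wrapping the ValueError from the root_validator.
```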
27 changes: 20 additions & 7 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
@@ -207,18 +207,30 @@ def get_table_column_names_and_types(
if self.table:
try:
paginator = client.get_paginator("describe_table")
-response_iterator = paginator.paginate(
-    ClusterIdentifier=config.offline_store.cluster_id,
-    Database=(
-        self.database
-        if self.database
-        else config.offline_store.database
-    ),
-    DbUser=config.offline_store.user,
-    Table=self.table,
-    Schema=self.schema,
-)
+paginator_kwargs = {
+    "Database": (
+        self.database
+        if self.database
+        else config.offline_store.database
+    ),
+    "Table": self.table,
+    "Schema": self.schema,
+}

+if config.offline_store.cluster_id:
+    # Provisioned cluster
+    paginator_kwargs["ClusterIdentifier"] = config.offline_store.cluster_id
+    paginator_kwargs["DbUser"] = config.offline_store.user
+elif config.offline_store.workgroup:
+    # Redshift serverless
+    paginator_kwargs["WorkgroupName"] = config.offline_store.workgroup

+response_iterator = paginator.paginate(**paginator_kwargs)
table = response_iterator.build_full_result()

except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
raise RedshiftCredentialsError() from e
@@ -233,6 +245,7 @@ def get_table_column_names_and_types(
statement_id = aws_utils.execute_redshift_statement(
client,
config.offline_store.cluster_id,
+config.offline_store.workgroup,
self.database if self.database else config.offline_store.database,
config.offline_store.user,
f"SELECT * FROM ({self.query}) LIMIT 1",
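For reference, a minimal standalone sketch of the serverless variant of that describe_table call, using boto3 directly (identifiers are examples):

```python
import boto3

client = boto3.client("redshift-data", region_name="us-west-2")
paginator = client.get_paginator("describe_table")

# Serverless: WorkgroupName replaces ClusterIdentifier + DbUser.
table = paginator.paginate(
    WorkgroupName="feast-workgroup",
    Database="feast-database",
    Schema="public",
    Table="my_feature_table",
).build_full_result()

for column in table["ColumnList"]:
    print(column["name"], column["typeName"])
```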