Skip to content

Commit

Permalink
Add schema parameter to RedshiftSource
Browse files Browse the repository at this point in the history
Signed-off-by: Felix Wang <wangfelix98@gmail.com>
  • Loading branch information
felixwang9817 committed Sep 8, 2021
1 parent 0dc13f0 commit 809ce38
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 30 deletions.
3 changes: 3 additions & 0 deletions protos/feast/core/DataSource.proto
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ message DataSource {
// SQL query that returns a table containing feature data. Must contain an event_timestamp column, and respective
// entity columns
string query = 2;

// Redshift schema name
string schema = 3;
}

// Defines configuration for custom third-party data sources.
Expand Down
113 changes: 83 additions & 30 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,56 @@ def __init__(
self,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
schema: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
query: Optional[str] = None,
):
"""
Creates a RedshiftSource object.
Args:
event_timestamp_column (optional): Event timestamp column used for point in
time joins of feature values.
table (optional): Redshift table where the features are stored.
schema (optional): Redshift schema in which the table is located.
created_timestamp_column (optional): Timestamp column indicating when the
row was created, used for deduplicating rows.
field_mapping (optional): A dictionary mapping of column names in this data
source to column names in a feature table or view.
date_partition_column (optional): Timestamp column used for partitioning.
query (optional): The query to be executed to obtain the features.
"""
super().__init__(
event_timestamp_column,
created_timestamp_column,
field_mapping,
date_partition_column,
)

self._redshift_options = RedshiftOptions(table=table, query=query)
# The default Redshift schema is named "public".
_schema = "public" if table and not schema else schema

self._redshift_options = RedshiftOptions(
table=table, schema=_schema, query=query
)

@staticmethod
def from_proto(data_source: DataSourceProto):
"""
Creates a RedshiftSource from a protobuf representation of a RedshiftSource.
Args:
data_source: A protobuf representation of a RedshiftSource
Returns:
A RedshiftSource object based on the data_source protobuf.
"""
return RedshiftSource(
field_mapping=dict(data_source.field_mapping),
table=data_source.redshift_options.table,
schema=data_source.redshift_options.schema,
event_timestamp_column=data_source.event_timestamp_column,
created_timestamp_column=data_source.created_timestamp_column,
date_partition_column=data_source.date_partition_column,
Expand All @@ -46,6 +77,7 @@ def __eq__(self, other):

return (
self.redshift_options.table == other.redshift_options.table
and self.redshift_options.schema == other.redshift_options.schema
and self.redshift_options.query == other.redshift_options.query
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
Expand All @@ -54,27 +86,36 @@ def __eq__(self, other):

@property
def table(self):
"""Returns the table of this Redshift source."""
return self._redshift_options.table

@property
def schema(self):
"""Returns the schema of this Redshift source."""
return self._redshift_options.schema

@property
def query(self):
"""Returns the Redshift options of this Redshift source."""
return self._redshift_options.query

@property
def redshift_options(self):
"""
Returns the Redshift options of this data source
"""
"""Returns the Redshift options of this Redshift source."""
return self._redshift_options

@redshift_options.setter
def redshift_options(self, _redshift_options):
"""
Sets the Redshift options of this data source
"""
"""Sets the Redshift options of this Redshift source."""
self._redshift_options = _redshift_options

def to_proto(self) -> DataSourceProto:
"""
Converts a RedshiftSource object to its protobuf representation.
Returns:
A DataSourceProto object.
"""
data_source_proto = DataSourceProto(
type=DataSourceProto.BATCH_REDSHIFT,
field_mapping=self.field_mapping,
Expand All @@ -93,9 +134,9 @@ def validate(self, config: RepoConfig):
self.get_table_column_names_and_types(config)

def get_table_query_string(self) -> str:
"""Returns a string that can directly be used to reference this table in SQL"""
"""Returns a string that can directly be used to reference this table in SQL."""
if self.table:
return f'"{self.table}"'
return f'"{self.schema}"."{self.table}"'
else:
return f"({self.query})"

Expand All @@ -106,6 +147,12 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
"""
Returns a mapping of column names to types for this Redshift source.
Args:
config: A RepoConfig describing the feature repo
"""
from botocore.exceptions import ClientError

from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig
Expand All @@ -122,6 +169,7 @@ def get_table_column_names_and_types(
Database=config.offline_store.database,
DbUser=config.offline_store.user,
Table=self.table,
Schema=self.schema,
)
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
Expand Down Expand Up @@ -150,55 +198,61 @@ def get_table_column_names_and_types(

class RedshiftOptions:
"""
DataSource Redshift options used to source features from Redshift query
DataSource Redshift options used to source features from Redshift query.
"""

def __init__(self, table: Optional[str], query: Optional[str]):
def __init__(
self, table: Optional[str], schema: Optional[str], query: Optional[str]
):
self._table = table
self._schema = schema
self._query = query

@property
def query(self):
"""
Returns the Redshift SQL query referenced by this source
"""
"""Returns the Redshift SQL query referenced by this source."""
return self._query

@query.setter
def query(self, query):
"""
Sets the Redshift SQL query referenced by this source
"""
"""Sets the Redshift SQL query referenced by this source."""
self._query = query

@property
def table(self):
"""
Returns the table name of this Redshift table
"""
"""Returns the table name of this Redshift table."""
return self._table

@table.setter
def table(self, table_name):
"""
Sets the table ref of this Redshift table
"""
"""Sets the table ref of this Redshift table."""
self._table = table_name

@property
def schema(self):
"""Returns the schema name of this Redshift table."""
return self._schema

@schema.setter
def schema(self, schema):
"""Sets the schema of this Redshift table."""
self._schema = schema

@classmethod
def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
"""
Creates a RedshiftOptions from a protobuf representation of a Redshift option
Creates a RedshiftOptions from a protobuf representation of a Redshift option.
Args:
redshift_options_proto: A protobuf representation of a DataSource
Returns:
Returns a RedshiftOptions object based on the redshift_options protobuf
A RedshiftOptions object based on the redshift_options protobuf.
"""

redshift_options = cls(
table=redshift_options_proto.table, query=redshift_options_proto.query,
table=redshift_options_proto.table,
schema=redshift_options_proto.schema,
query=redshift_options_proto.query,
)

return redshift_options
Expand All @@ -208,11 +262,10 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions:
Converts an RedshiftOptionsProto object to its protobuf representation.
Returns:
RedshiftOptionsProto protobuf
A RedshiftOptionsProto protobuf.
"""

redshift_options_proto = DataSourceProto.RedshiftOptions(
table=self.table, query=self.query,
table=self.table, schema=self.schema, query=self.query,
)

return redshift_options_proto

0 comments on commit 809ce38

Please sign in to comment.