Add RedshiftDataSource (#1669)
* Add RedshiftDataSource

Signed-off-by: Tsotne Tabidze <tsotne@tecton.ai>

* Call parent __init__ first

Signed-off-by: Tsotne Tabidze <tsotne@tecton.ai>
Tsotne Tabidze authored Jun 28, 2021
1 parent 3cb303f commit 1761b10
Showing 15 changed files with 414 additions and 33 deletions.
12 changes: 12 additions & 0 deletions protos/feast/core/DataSource.proto
@@ -33,6 +33,7 @@ message DataSource {
BATCH_BIGQUERY = 2;
STREAM_KAFKA = 3;
STREAM_KINESIS = 4;
BATCH_REDSHIFT = 5;
}
SourceType type = 1;

@@ -100,11 +101,22 @@ message DataSource {
StreamFormat record_format = 3;
}

// Defines options for a DataSource that sources features from a Redshift table or query
message RedshiftOptions {
// Redshift table name
string table = 1;

// SQL query that returns a table containing feature data. Must contain an event_timestamp column and the
// respective entity columns
string query = 2;
}

// DataSource options.
oneof options {
FileOptions file_options = 11;
BigQueryOptions bigquery_options = 12;
KafkaOptions kafka_options = 13;
KinesisOptions kinesis_options = 14;
RedshiftOptions redshift_options = 15;
}
}
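
For illustration, a short sketch of how the new oneof member might be populated from the regenerated Python bindings (the table name "driver_stats" is hypothetical, not from this commit):

from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto

ds = DataSourceProto(
    type=DataSourceProto.BATCH_REDSHIFT,
    redshift_options=DataSourceProto.RedshiftOptions(table="driver_stats"),
)
# Assigning redshift_options selects that member of the options oneof.
assert ds.WhichOneof("options") == "redshift_options"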
2 changes: 2 additions & 0 deletions sdk/python/feast/__init__.py
@@ -8,6 +8,7 @@
FileSource,
KafkaSource,
KinesisSource,
RedshiftSource,
SourceType,
)
from .entity import Entity
@@ -37,6 +38,7 @@
"FileSource",
"KafkaSource",
"KinesisSource",
"RedshiftSource",
"Feature",
"FeatureStore",
"FeatureTable",
266 changes: 260 additions & 6 deletions sdk/python/feast/data_source.py
@@ -17,11 +17,17 @@
from typing import Callable, Dict, Iterable, Optional, Tuple

from pyarrow.parquet import ParquetFile
from tenacity import retry, retry_unless_exception_type, wait_exponential

from feast import type_map
from feast.data_format import FileFormat, StreamFormat
from feast.errors import DataSourceNotFoundException
from feast.errors import (
DataSourceNotFoundException,
RedshiftCredentialsError,
RedshiftQueryError,
)
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.repo_config import RepoConfig
from feast.value_type import ValueType


@@ -477,6 +483,15 @@ def from_proto(data_source):
date_partition_column=data_source.date_partition_column,
query=data_source.bigquery_options.query,
)
elif data_source.redshift_options.table or data_source.redshift_options.query:
data_source_obj = RedshiftSource(
field_mapping=data_source.field_mapping,
table=data_source.redshift_options.table,
event_timestamp_column=data_source.event_timestamp_column,
created_timestamp_column=data_source.created_timestamp_column,
date_partition_column=data_source.date_partition_column,
query=data_source.redshift_options.query,
)
elif (
data_source.kafka_options.bootstrap_servers
and data_source.kafka_options.topic
@@ -520,12 +535,27 @@ def to_proto(self) -> DataSourceProto:
"""
raise NotImplementedError

def validate(self):
def validate(self, config: RepoConfig):
"""
Validates the underlying data source.
"""
raise NotImplementedError

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
"""
Returns the callable that maps a raw column type to the corresponding Feast ValueType
"""
raise NotImplementedError

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
"""
Get the list of column names and raw column types
"""
raise NotImplementedError


class FileSource(DataSource):
def __init__(
@@ -622,15 +652,17 @@ def to_proto(self) -> DataSourceProto:

return data_source_proto

def validate(self):
def validate(self, config: RepoConfig):
# TODO: validate a FileSource
pass

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.pa_to_feast_value_type

def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
schema = ParquetFile(self.path).schema_arrow
return zip(schema.names, map(str, schema.types))

@@ -703,7 +735,7 @@ def to_proto(self) -> DataSourceProto:

return data_source_proto

def validate(self):
def validate(self, config: RepoConfig):
if not self.query:
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
@@ -725,7 +757,9 @@ def get_table_query_string(self) -> str:
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.bq_to_feast_value_type

def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
from google.cloud import bigquery

client = bigquery.Client()
@@ -875,3 +909,223 @@ def to_proto(self) -> DataSourceProto:
data_source_proto.date_partition_column = self.date_partition_column

return data_source_proto


class RedshiftOptions:
"""
DataSource Redshift options used to source features from a Redshift table or query
"""

def __init__(self, table: Optional[str], query: Optional[str]):
self._table = table
self._query = query

@property
def query(self):
"""
Returns the Redshift SQL query referenced by this source
"""
return self._query

@query.setter
def query(self, query):
"""
Sets the Redshift SQL query referenced by this source
"""
self._query = query

@property
def table(self):
"""
Returns the name of this Redshift table
"""
return self._table

@table.setter
def table(self, table_name):
"""
Sets the name of this Redshift table
"""
self._table = table_name

@classmethod
def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
"""
Creates a RedshiftOptions from its protobuf representation
Args:
redshift_options_proto: A protobuf representation of DataSource.RedshiftOptions
Returns:
A RedshiftOptions object based on the redshift_options protobuf
"""

redshift_options = cls(
table=redshift_options_proto.table, query=redshift_options_proto.query,
)

return redshift_options

def to_proto(self) -> DataSourceProto.RedshiftOptions:
"""
Converts a RedshiftOptions object to its protobuf representation.
Returns:
RedshiftOptionsProto protobuf
"""

redshift_options_proto = DataSourceProto.RedshiftOptions(
table=self.table, query=self.query,
)

return redshift_options_proto


class RedshiftSource(DataSource):
def __init__(
self,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
query: Optional[str] = None,
):
super().__init__(
event_timestamp_column,
created_timestamp_column,
field_mapping,
date_partition_column,
)

self._redshift_options = RedshiftOptions(table=table, query=query)

def __eq__(self, other):
if not isinstance(other, RedshiftSource):
raise TypeError(
"Comparisons should only involve RedshiftSource class objects."
)

return (
self.redshift_options.table == other.redshift_options.table
and self.redshift_options.query == other.redshift_options.query
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
)

@property
def table(self):
return self._redshift_options.table

@property
def query(self):
return self._redshift_options.query

@property
def redshift_options(self):
"""
Returns the Redshift options of this data source
"""
return self._redshift_options

@redshift_options.setter
def redshift_options(self, _redshift_options):
"""
Sets the Redshift options of this data source
"""
self._redshift_options = _redshift_options

def to_proto(self) -> DataSourceProto:
data_source_proto = DataSourceProto(
type=DataSourceProto.BATCH_REDSHIFT,
field_mapping=self.field_mapping,
redshift_options=self.redshift_options.to_proto(),
)

data_source_proto.event_timestamp_column = self.event_timestamp_column
data_source_proto.created_timestamp_column = self.created_timestamp_column
data_source_proto.date_partition_column = self.date_partition_column

return data_source_proto

def validate(self, config: RepoConfig):
# As long as the query executes successfully, or the table exists,
# the data source is considered valid. We don't need the results.
self.get_table_column_names_and_types(config)

def get_table_query_string(self) -> str:
"""Returns a string that can directly be used to reference this table in SQL"""
if self.table:
return f"`{self.table}`"
else:
return f"({self.query})"

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.redshift_to_feast_value_type

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

client = boto3.client(
"redshift-data", config=Config(region_name=config.offline_store.region)
)

try:
if self.table is not None:
table = client.describe_table(
ClusterIdentifier=config.offline_store.cluster_id,
Database=config.offline_store.database,
DbUser=config.offline_store.user,
Table=self.table,
)
# The API returns valid JSON with empty column list when the table doesn't exist
if len(table["ColumnList"]) == 0:
raise DataSourceNotFoundException(self.table)

columns = table["ColumnList"]
else:
statement = client.execute_statement(
ClusterIdentifier=config.offline_store.cluster_id,
Database=config.offline_store.database,
DbUser=config.offline_store.user,
Sql=f"SELECT * FROM ({self.query}) LIMIT 1",
)

# We need to retry client.describe_statement(...) until the query finishes. We don't want to bombard
# Redshift with status checks, nor do we want to block for a long time on the initial call.
# The solution is exponential backoff: the wait starts at 0.1 seconds and doubles on each retry
# until it reaches 30 seconds, after which it stays fixed.
@retry(
wait=wait_exponential(multiplier=0.1, max=30),
retry=retry_unless_exception_type(RedshiftQueryError),
)
def wait_for_statement():
desc = client.describe_statement(Id=statement["Id"])
if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"):
raise Exception # Retry
if desc["Status"] != "FINISHED":
raise RedshiftQueryError(desc) # Don't retry. Raise exception.

wait_for_statement()

result = client.get_statement_result(Id=statement["Id"])

columns = result["ColumnMetadata"]
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
raise RedshiftCredentialsError() from e
raise

return [(column["name"], column["typeName"].upper()) for column in columns]
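
For context, a minimal sketch of how the new source might be declared in a feature repo (the table name, query, and column names are hypothetical, not taken from this commit):

from feast import RedshiftSource

# Table-backed source:
driver_stats = RedshiftSource(
    table="driver_stats",
    event_timestamp_column="event_timestamp",
    created_timestamp_column="created",
)

# Query-backed source; typically only one of table/query is set:
driver_hourly = RedshiftSource(
    query="SELECT driver_id, avg_speed, event_timestamp FROM driver_stats",
    event_timestamp_column="event_timestamp",
)

Note that validate() and get_table_column_names_and_types() require a RepoConfig whose offline_store is a RedshiftOfflineStoreConfig, since the column lookup goes through the Redshift Data API.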
10 changes: 10 additions & 0 deletions sdk/python/feast/errors.py
@@ -124,3 +124,13 @@ def __init__(self, repo_obj_type: str, specific_issue: str):
f"Inference to fill in missing information for {repo_obj_type} failed. {specific_issue}. "
"Try filling the information explicitly."
)


class RedshiftCredentialsError(Exception):
def __init__(self):
super().__init__("Redshift API failed due to incorrect credentials")


class RedshiftQueryError(Exception):
def __init__(self, details):
super().__init__(f"Redshift SQL Query failed to finish. Details: {details}")