From c14023f5c15f4233f0ffd30b945f78c1106674b1 Mon Sep 17 00:00:00 2001
From: Tsotne Tabidze
Date: Mon, 28 Jun 2021 11:05:59 -0700
Subject: [PATCH] Add RedshiftDataSource (#1669)

* Add RedshiftDataSource

Signed-off-by: Tsotne Tabidze

* Call parent __init__ first

Signed-off-by: Tsotne Tabidze

Signed-off-by: Mwad22 <51929507+Mwad22@users.noreply.github.com>
---
 protos/feast/core/DataSource.proto            |  12 +
 sdk/python/feast/__init__.py                  |   2 +
 sdk/python/feast/data_source.py               | 266 +++++++++++++++++-
 sdk/python/feast/errors.py                    |  10 +
 sdk/python/feast/feature_store.py             |   6 +-
 sdk/python/feast/feature_view.py              |  13 +-
 sdk/python/feast/inference.py                 |  15 +-
 .../feast/infra/offline_stores/redshift.py    |  60 ++++
 sdk/python/feast/repo_config.py               |   3 +
 sdk/python/feast/repo_operations.py           |   8 +-
 sdk/python/feast/type_map.py                  |  24 ++
 .../tensorflow_metadata/proto/v0/path_pb2.py  |   2 +-
 .../proto/v0/schema_pb2.py                    |   2 +-
 .../proto/v0/statistics_pb2.py                |   2 +-
 sdk/python/tests/test_inference.py            |  22 +-
 15 files changed, 414 insertions(+), 33 deletions(-)
 create mode 100644 sdk/python/feast/infra/offline_stores/redshift.py

diff --git a/protos/feast/core/DataSource.proto b/protos/feast/core/DataSource.proto
index a4c46e7508..1200c1b9be 100644
--- a/protos/feast/core/DataSource.proto
+++ b/protos/feast/core/DataSource.proto
@@ -33,6 +33,7 @@ message DataSource {
     BATCH_BIGQUERY = 2;
     STREAM_KAFKA = 3;
     STREAM_KINESIS = 4;
+    BATCH_REDSHIFT = 5;
   }
 
   SourceType type = 1;
@@ -100,11 +101,22 @@ message DataSource {
     StreamFormat record_format = 3;
   }
 
+  // Defines options for DataSource that sources features from a Redshift Query
+  message RedshiftOptions {
+    // Redshift table name
+    string table = 1;
+
+    // SQL query that returns a table containing feature data. Must contain an event_timestamp column, and respective
+    // entity columns
+    string query = 2;
+  }
+
   // DataSource options.
   oneof options {
     FileOptions file_options = 11;
     BigQueryOptions bigquery_options = 12;
     KafkaOptions kafka_options = 13;
     KinesisOptions kinesis_options = 14;
+    RedshiftOptions redshift_options = 15;
   }
 }
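[Editor's note: for orientation, the new message can be exercised through the generated Python bindings. This is an illustrative sketch with a hypothetical table name, not part of the patch:

    from feast.protos.feast.core.DataSource_pb2 import DataSource

    source = DataSource(
        type=DataSource.BATCH_REDSHIFT,
        redshift_options=DataSource.RedshiftOptions(table="driver_hourly_stats"),
    )
    # The new field participates in the existing `options` oneof.
    assert source.WhichOneof("options") == "redshift_options"
]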
diff --git a/sdk/python/feast/__init__.py b/sdk/python/feast/__init__.py
index 7a1a70f23c..6f1cb58451 100644
--- a/sdk/python/feast/__init__.py
+++ b/sdk/python/feast/__init__.py
@@ -8,6 +8,7 @@
     FileSource,
     KafkaSource,
     KinesisSource,
+    RedshiftSource,
     SourceType,
 )
 from .entity import Entity
@@ -37,6 +38,7 @@
     "FileSource",
     "KafkaSource",
     "KinesisSource",
+    "RedshiftSource",
     "Feature",
     "FeatureStore",
     "FeatureTable",
diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py
index c25b64c82f..a5a620b55b 100644
--- a/sdk/python/feast/data_source.py
+++ b/sdk/python/feast/data_source.py
@@ -17,11 +17,17 @@
 from typing import Callable, Dict, Iterable, Optional, Tuple
 
 from pyarrow.parquet import ParquetFile
+from tenacity import retry, retry_unless_exception_type, wait_exponential
 
 from feast import type_map
 from feast.data_format import FileFormat, StreamFormat
-from feast.errors import DataSourceNotFoundException
+from feast.errors import (
+    DataSourceNotFoundException,
+    RedshiftCredentialsError,
+    RedshiftQueryError,
+)
 from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
+from feast.repo_config import RepoConfig
 from feast.value_type import ValueType
 
@@ -477,6 +483,15 @@ def from_proto(data_source):
                 date_partition_column=data_source.date_partition_column,
                 query=data_source.bigquery_options.query,
             )
+        elif data_source.redshift_options.table or data_source.redshift_options.query:
+            data_source_obj = RedshiftSource(
+                field_mapping=data_source.field_mapping,
+                table=data_source.redshift_options.table,
+                event_timestamp_column=data_source.event_timestamp_column,
+                created_timestamp_column=data_source.created_timestamp_column,
+                date_partition_column=data_source.date_partition_column,
+                query=data_source.redshift_options.query,
+            )
         elif (
             data_source.kafka_options.bootstrap_servers
             and data_source.kafka_options.topic
@@ -520,12 +535,27 @@ def to_proto(self) -> DataSourceProto:
         """
         raise NotImplementedError
 
-    def validate(self):
+    def validate(self, config: RepoConfig):
         """
         Validates the underlying data source.
         """
         raise NotImplementedError
 
+    @staticmethod
+    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
+        """
+        Get the callable method that returns Feast type given the raw column type
+        """
+        raise NotImplementedError
+
+    def get_table_column_names_and_types(
+        self, config: RepoConfig
+    ) -> Iterable[Tuple[str, str]]:
+        """
+        Get the list of column names and raw column types
+        """
+        raise NotImplementedError
+
 
 class FileSource(DataSource):
     def __init__(
@@ -622,7 +652,7 @@ def to_proto(self) -> DataSourceProto:
 
         return data_source_proto
 
-    def validate(self):
+    def validate(self, config: RepoConfig):
         # TODO: validate a FileSource
         pass
 
@@ -630,7 +660,9 @@ def validate(self):
     def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
         return type_map.pa_to_feast_value_type
 
-    def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
+    def get_table_column_names_and_types(
+        self, config: RepoConfig
+    ) -> Iterable[Tuple[str, str]]:
         schema = ParquetFile(self.path).schema_arrow
         return zip(schema.names, map(str, schema.types))
 
@@ -703,7 +735,7 @@ def to_proto(self) -> DataSourceProto:
 
         return data_source_proto
 
-    def validate(self):
+    def validate(self, config: RepoConfig):
         if not self.query:
             from google.api_core.exceptions import NotFound
             from google.cloud import bigquery
@@ -725,7 +757,9 @@ def get_table_query_string(self) -> str:
     def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
         return type_map.bq_to_feast_value_type
 
-    def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
+    def get_table_column_names_and_types(
+        self, config: RepoConfig
+    ) -> Iterable[Tuple[str, str]]:
         from google.cloud import bigquery
 
         client = bigquery.Client()
@@ -875,3 +909,223 @@ def to_proto(self) -> DataSourceProto:
         data_source_proto.date_partition_column = self.date_partition_column
 
         return data_source_proto
+
+
+class RedshiftOptions:
+    """
+    DataSource Redshift options used to source features from a Redshift query
+    """
+
+    def __init__(self, table: Optional[str], query: Optional[str]):
+        self._table = table
+        self._query = query
+
+    @property
+    def query(self):
+        """
+        Returns the Redshift SQL query referenced by this source
+        """
+        return self._query
+
+    @query.setter
+    def query(self, query):
+        """
+        Sets the Redshift SQL query referenced by this source
+        """
+        self._query = query
+
+    @property
+    def table(self):
+        """
+        Returns the table name of this Redshift table
+        """
+        return self._table
+
+    @table.setter
+    def table(self, table_name):
+        """
+        Sets the table ref of this Redshift table
+        """
+        self._table = table_name
+
+    @classmethod
+    def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
+        """
+        Creates a RedshiftOptions from a protobuf representation of a Redshift option
+
+        Args:
+            redshift_options_proto: A protobuf representation of a DataSource
+
+        Returns:
+            Returns a RedshiftOptions object based on the redshift_options protobuf
+        """
+
+        redshift_options = cls(
+            table=redshift_options_proto.table, query=redshift_options_proto.query,
+        )
+
+        return redshift_options
+
+    def to_proto(self) -> DataSourceProto.RedshiftOptions:
+        """
+        Converts a RedshiftOptions object to its protobuf representation.
+
+        Returns:
+            RedshiftOptionsProto protobuf
+        """
+
+        redshift_options_proto = DataSourceProto.RedshiftOptions(
+            table=self.table, query=self.query,
+        )
+
+        return redshift_options_proto
+
+
+class RedshiftSource(DataSource):
+    def __init__(
+        self,
+        event_timestamp_column: Optional[str] = "",
+        table: Optional[str] = None,
+        created_timestamp_column: Optional[str] = "",
+        field_mapping: Optional[Dict[str, str]] = None,
+        date_partition_column: Optional[str] = "",
+        query: Optional[str] = None,
+    ):
+        super().__init__(
+            event_timestamp_column,
+            created_timestamp_column,
+            field_mapping,
+            date_partition_column,
+        )
+
+        self._redshift_options = RedshiftOptions(table=table, query=query)
+
+    def __eq__(self, other):
+        if not isinstance(other, RedshiftSource):
+            raise TypeError(
+                "Comparisons should only involve RedshiftSource class objects."
+            )
+
+        return (
+            self.redshift_options.table == other.redshift_options.table
+            and self.redshift_options.query == other.redshift_options.query
+            and self.event_timestamp_column == other.event_timestamp_column
+            and self.created_timestamp_column == other.created_timestamp_column
+            and self.field_mapping == other.field_mapping
+        )
+
+    @property
+    def table(self):
+        return self._redshift_options.table
+
+    @property
+    def query(self):
+        return self._redshift_options.query
+
+    @property
+    def redshift_options(self):
+        """
+        Returns the Redshift options of this data source
+        """
+        return self._redshift_options
+
+    @redshift_options.setter
+    def redshift_options(self, _redshift_options):
+        """
+        Sets the Redshift options of this data source
+        """
+        self._redshift_options = _redshift_options
+
+    def to_proto(self) -> DataSourceProto:
+        data_source_proto = DataSourceProto(
+            type=DataSourceProto.BATCH_REDSHIFT,
+            field_mapping=self.field_mapping,
+            redshift_options=self.redshift_options.to_proto(),
+        )
+
+        data_source_proto.event_timestamp_column = self.event_timestamp_column
+        data_source_proto.created_timestamp_column = self.created_timestamp_column
+        data_source_proto.date_partition_column = self.date_partition_column
+
+        return data_source_proto
+
+    def validate(self, config: RepoConfig):
+        # As long as the query gets successfully executed, or the table exists,
+        # the data source is validated. We don't need the results though.
+        self.get_table_column_names_and_types(config)
+
+    def get_table_query_string(self) -> str:
+        """Returns a string that can directly be used to reference this table in SQL"""
+        if self.table:
+            # Redshift quotes identifiers with double quotes, not backticks.
+            return f'"{self.table}"'
+        else:
+            return f"({self.query})"
+
+    @staticmethod
+    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
+        return type_map.redshift_to_feast_value_type
+
+    def get_table_column_names_and_types(
+        self, config: RepoConfig
+    ) -> Iterable[Tuple[str, str]]:
+        import boto3
+        from botocore.config import Config
+        from botocore.exceptions import ClientError
+
+        from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig
+
+        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)
+
+        client = boto3.client(
+            "redshift-data", config=Config(region_name=config.offline_store.region)
+        )
+
+        try:
+            if self.table is not None:
+                table = client.describe_table(
+                    ClusterIdentifier=config.offline_store.cluster_id,
+                    Database=config.offline_store.database,
+                    DbUser=config.offline_store.user,
+                    Table=self.table,
+                )
+                # The API returns valid JSON with empty column list when the table doesn't exist
+                if len(table["ColumnList"]) == 0:
+                    raise DataSourceNotFoundException(self.table)
+
+                columns = table["ColumnList"]
+            else:
+                statement = client.execute_statement(
+                    ClusterIdentifier=config.offline_store.cluster_id,
+                    Database=config.offline_store.database,
+                    DbUser=config.offline_store.user,
+                    Sql=f"SELECT * FROM ({self.query}) LIMIT 1",
+                )
+
+                # Need to retry client.describe_statement(...) until the task is finished. We don't want to bombard
+                # Redshift with queries, and neither do we want to wait for a long time on the initial call.
+                # The solution is exponential backoff. The backoff starts with 0.1 seconds and doubles exponentially
+                # until reaching 30 seconds, at which point the backoff is fixed.
+                @retry(
+                    wait=wait_exponential(multiplier=0.1, max=30),
+                    retry=retry_unless_exception_type(RedshiftQueryError),
+                )
+                def wait_for_statement():
+                    desc = client.describe_statement(Id=statement["Id"])
+                    if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"):
+                        raise Exception  # Retry
+                    if desc["Status"] != "FINISHED":
+                        raise RedshiftQueryError(desc)  # Don't retry. Raise exception.
+
+                wait_for_statement()
+
+                result = client.get_statement_result(Id=statement["Id"])
+
+                columns = result["ColumnMetadata"]
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "ValidationException":
+                raise RedshiftCredentialsError() from e
+            raise
+
+        return [(column["name"], column["typeName"].upper()) for column in columns]
diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py
index 58e34dede4..78d3985255 100644
--- a/sdk/python/feast/errors.py
+++ b/sdk/python/feast/errors.py
@@ -133,3 +133,13 @@ def __init__(self, repo_obj_type: str, specific_issue: str):
             f"Inference to fill in missing information for {repo_obj_type} failed. {specific_issue}. "
             "Try filling the information explicitly."
         )
+
+
+class RedshiftCredentialsError(Exception):
+    def __init__(self):
+        super().__init__("Redshift API failed due to incorrect credentials")
+
+
+class RedshiftQueryError(Exception):
+    def __init__(self, details):
+        super().__init__(f"Redshift SQL Query failed to finish. Details: {details}")
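[Editor's note: an illustrative sketch of how the RedshiftSource added above behaves; table, query, and column names are hypothetical and this is not part of the patch:

    from feast import RedshiftSource

    # Table-backed source.
    driver_stats = RedshiftSource(
        table="driver_hourly_stats",
        event_timestamp_column="event_timestamp",
        created_timestamp_column="created",
    )

    # Query-backed source; in practice a source sets either `table` or `query`.
    recent_stats = RedshiftSource(
        query="SELECT * FROM driver_hourly_stats WHERE trips_today > 0",
        event_timestamp_column="event_timestamp",
    )

    # Round-trip through the proto representation defined in DataSource.proto.
    proto = driver_stats.to_proto()
    assert proto.type == proto.BATCH_REDSHIFT
    assert proto.redshift_options.table == "driver_hourly_stats"
]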
diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 22d9df8d17..c2f58cd9fe 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -232,13 +232,13 @@ def apply(
 
         # Make inferences
         update_entities_with_inferred_types_from_feature_views(
-            entities_to_update, views_to_update
+            entities_to_update, views_to_update, self.config
         )
         update_data_sources_with_inferred_event_timestamp_col(
-            [view.input for view in views_to_update]
+            [view.input for view in views_to_update], self.config
         )
         for view in views_to_update:
-            view.infer_features_from_input_source()
+            view.infer_features_from_input_source(self.config)
 
         if len(views_to_update) + len(entities_to_update) != len(objects):
             raise ValueError("Unknown object type provided as part of apply() call")
diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py
index d89a8a7b6b..db22fd7e4a 100644
--- a/sdk/python/feast/feature_view.py
+++ b/sdk/python/feast/feature_view.py
@@ -20,7 +20,7 @@
 from google.protobuf.timestamp_pb2 import Timestamp
 
 from feast import utils
-from feast.data_source import BigQuerySource, DataSource, FileSource
+from feast.data_source import DataSource
 from feast.errors import RegistryInferenceFailure
 from feast.feature import Feature
 from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
@@ -33,6 +33,7 @@
 from feast.protos.feast.core.FeatureView_pb2 import (
     MaterializationInterval as MaterializationIntervalProto,
 )
+from feast.repo_config import RepoConfig
 from feast.usage import log_exceptions
 from feast.value_type import ValueType
 
@@ -48,7 +49,7 @@ class FeatureView:
     tags: Optional[Dict[str, str]]
     ttl: Optional[timedelta]
     online: bool
-    input: Union[BigQuerySource, FileSource]
+    input: DataSource
 
     created_timestamp: Optional[Timestamp] = None
     last_updated_timestamp: Optional[Timestamp] = None
@@ -60,7 +61,7 @@ def __init__(
         name: str,
         entities: List[str],
         ttl: Optional[Union[Duration, timedelta]],
-        input: Union[BigQuerySource, FileSource],
+        input: DataSource,
         features: List[Feature] = [],
         tags: Optional[Dict[str, str]] = None,
         online: bool = True,
@@ -220,14 +221,16 @@ def most_recent_end_time(self) -> Optional[datetime]:
             return None
         return max([interval[1] for interval in self.materialization_intervals])
 
-    def infer_features_from_input_source(self):
+    def infer_features_from_input_source(self, config: RepoConfig):
         if not self.features:
             columns_to_exclude = {
                 self.input.event_timestamp_column,
                 self.input.created_timestamp_column,
             } | set(self.entities)
 
-            for col_name, col_datatype in self.input.get_table_column_names_and_types():
+            for col_name, col_datatype in self.input.get_table_column_names_and_types(
+                config
+            ):
                 if col_name not in columns_to_exclude and not re.match(
                     "^__|__$",
                     col_name,  # double underscores often signal an internal-use column
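[Editor's note: to make the inference flow concrete, a hypothetical sketch, not part of the patch. Assume a Redshift table with columns driver_id, trips_today, avg_speed, and event_timestamp:

    from datetime import timedelta

    from feast import FeatureView, RedshiftSource

    driver_stats_view = FeatureView(
        name="driver_hourly_stats",
        entities=["driver_id"],
        ttl=timedelta(days=1),
        input=RedshiftSource(
            table="driver_hourly_stats", event_timestamp_column="event_timestamp",
        ),
    )

    # With `features` left empty, apply() calls infer_features_from_input_source(config),
    # which would register trips_today and avg_speed as features while skipping the
    # entity column (driver_id) and the timestamp columns.
]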
diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py
index af95f9d255..28b764fd80 100644
--- a/sdk/python/feast/inference.py
+++ b/sdk/python/feast/inference.py
@@ -1,15 +1,16 @@
 import re
-from typing import List, Union
+from typing import List
 
 from feast import Entity
-from feast.data_source import BigQuerySource, FileSource
+from feast.data_source import BigQuerySource, DataSource, FileSource, RedshiftSource
 from feast.errors import RegistryInferenceFailure
 from feast.feature_view import FeatureView
+from feast.repo_config import RepoConfig
 from feast.value_type import ValueType
 
 
 def update_entities_with_inferred_types_from_feature_views(
-    entities: List[Entity], feature_views: List[FeatureView]
+    entities: List[Entity], feature_views: List[FeatureView], config: RepoConfig
 ) -> None:
     """
     Infer entity value type by examining schema of feature view input sources
@@ -25,7 +26,7 @@ def update_entities_with_inferred_types_from_feature_views(
         if not (incomplete_entities_keys & set(view.entities)):
             continue  # skip if view doesn't contain any entities that need inference
 
-        col_names_and_types = view.input.get_table_column_names_and_types()
+        col_names_and_types = view.input.get_table_column_names_and_types(config)
         for entity_name in view.entities:
             if entity_name in incomplete_entities:
                 # get entity information from information extracted from the view input source
@@ -59,7 +60,7 @@ def update_entities_with_inferred_types_from_feature_views(
 
 
 def update_data_sources_with_inferred_event_timestamp_col(
-    data_sources: List[Union[BigQuerySource, FileSource]],
+    data_sources: List[DataSource], config: RepoConfig
 ) -> None:
     ERROR_MSG_PREFIX = "Unable to infer DataSource event_timestamp_column"
 
@@ -74,6 +75,8 @@ def update_data_sources_with_inferred_event_timestamp_col(
             ts_column_type_regex_pattern = r"^timestamp"
         elif isinstance(data_source, BigQuerySource):
             ts_column_type_regex_pattern = "TIMESTAMP|DATETIME"
+        elif isinstance(data_source, RedshiftSource):
+            ts_column_type_regex_pattern = "TIMESTAMP[A-Z]*"
         else:
             raise RegistryInferenceFailure(
                 "DataSource",
@@ -92,7 +95,7 @@ def update_data_sources_with_inferred_event_timestamp_col(
             for (
                 col_name,
                 col_datatype,
-            ) in data_source.get_table_column_names_and_types():
+            ) in data_source.get_table_column_names_and_types(config):
                 if re.match(ts_column_type_regex_pattern, col_datatype):
                     if matched_flag:
                         raise RegistryInferenceFailure(
diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py
new file mode 100644
index 0000000000..06a437564a
--- /dev/null
+++ b/sdk/python/feast/infra/offline_stores/redshift.py
@@ -0,0 +1,60 @@
+from datetime import datetime
+from typing import List, Optional, Union
+
+import pandas as pd
+import pyarrow
+from pydantic import StrictStr
+from pydantic.typing import Literal
+
+from feast.data_source import DataSource
+from feast.feature_view import FeatureView
+from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
+from feast.registry import Registry
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
+
+
+class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
+    """ Offline store config for AWS Redshift """
+
+    type: Literal["redshift"] = "redshift"
+    """ Offline store type selector """
+
+    cluster_id: StrictStr
+    """ Redshift cluster identifier """
+
+    region: StrictStr
+    """ Redshift cluster's AWS region """
+
+    user: StrictStr
+    """ Redshift user name """
+
+    database: StrictStr
+    """ Redshift database name """
+
+    s3_path: StrictStr
+    """ S3 path for importing & exporting data to Redshift """
+
+
+class RedshiftOfflineStore(OfflineStore):
+    @staticmethod
+    def pull_latest_from_table_or_query(
+        data_source: DataSource,
+        join_key_columns: List[str],
+        feature_name_columns: List[str],
+        event_timestamp_column: str,
+        created_timestamp_column: Optional[str],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> pyarrow.Table:
+        pass
+
+    @staticmethod
+    def get_historical_features(
+        config: RepoConfig,
+        feature_views: List[FeatureView],
+        feature_refs: List[str],
+        entity_df: Union[pd.DataFrame, str],
+        registry: Registry,
+        project: str,
+    ) -> RetrievalJob:
+        pass
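[Editor's note: RedshiftOfflineStoreConfig corresponds to the `offline_store` block of feature_store.yaml, selected by `type: redshift` (or implied by `provider: aws` per the repo_config.py change below). An illustrative sketch with placeholder values, not part of the patch:

    from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

    offline_store = RedshiftOfflineStoreConfig(
        cluster_id="my-feast-cluster",          # placeholder cluster name
        region="us-west-2",                     # placeholder region
        user="admin",
        database="feast",
        s3_path="s3://my-bucket/feast-staging",  # placeholder bucket
    )
]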
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index c680d94d07..968af8bc9e 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -21,6 +21,7 @@
 OFFLINE_STORE_CLASS_FOR_TYPE = {
     "file": "feast.infra.offline_stores.file.FileOfflineStore",
     "bigquery": "feast.infra.offline_stores.bigquery.BigQueryOfflineStore",
+    "redshift": "feast.infra.offline_stores.redshift.RedshiftOfflineStore",
 }
 
@@ -154,6 +155,8 @@ def _validate_offline_store_config(cls, values):
             values["offline_store"]["type"] = "file"
         elif values["provider"] == "gcp":
             values["offline_store"]["type"] = "bigquery"
+        elif values["provider"] == "aws":
+            values["offline_store"]["type"] = "redshift"
 
         offline_store_type = values["offline_store"]["type"]
 
diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py
index f707f19c41..eeae8f3d89 100644
--- a/sdk/python/feast/repo_operations.py
+++ b/sdk/python/feast/repo_operations.py
@@ -135,15 +135,15 @@ def apply_total(repo_config: RepoConfig, repo_path: Path):
 
     # Make sure the data source used by this feature view is supported by Feast
     for data_source in data_sources:
-        data_source.validate()
+        data_source.validate(repo_config)
 
     # Make inferences
     update_entities_with_inferred_types_from_feature_views(
-        repo.entities, repo.feature_views
+        repo.entities, repo.feature_views, repo_config
     )
-    update_data_sources_with_inferred_event_timestamp_col(data_sources)
+    update_data_sources_with_inferred_event_timestamp_col(data_sources, repo_config)
     for view in repo.feature_views:
-        view.infer_features_from_input_source()
+        view.infer_features_from_input_source(repo_config)
 
     sys.dont_write_bytecode = False
     for entity in repo.entities:
diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py
index 576a0b7f35..54d30cbd22 100644
--- a/sdk/python/feast/type_map.py
+++ b/sdk/python/feast/type_map.py
@@ -528,3 +528,27 @@ def bq_to_feast_value_type(bq_type_as_str):
     }
 
     return type_map[bq_type_as_str]
+
+
+def redshift_to_feast_value_type(redshift_type_as_str):
+    # Type names from https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html
+    type_map: Dict[str, ValueType] = {
+        "INT": ValueType.INT32,
+        "INT4": ValueType.INT32,
+        "INT8": ValueType.INT64,
+        "FLOAT4": ValueType.FLOAT,
+        "FLOAT8": ValueType.DOUBLE,
+        "FLOAT": ValueType.DOUBLE,
+        "NUMERIC": ValueType.DOUBLE,
+        "BOOL": ValueType.BOOL,
+        "CHARACTER": ValueType.STRING,
+        "NCHAR": ValueType.STRING,
+        "BPCHAR": ValueType.STRING,
+        "CHARACTER VARYING": ValueType.STRING,
+        "NVARCHAR": ValueType.STRING,
+        "TEXT": ValueType.STRING,
+        "TIMESTAMP WITHOUT TIME ZONE": ValueType.UNIX_TIMESTAMP,
+        "TIMESTAMP WITH TIME ZONE": ValueType.UNIX_TIMESTAMP,
+    }
+
+    return type_map[redshift_type_as_str.upper()]
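[Editor's note: a few illustrative assertions for the mapping above, not part of the patch:

    from feast.type_map import redshift_to_feast_value_type
    from feast.value_type import ValueType

    # The lookup upper-cases its input, so Redshift's lowercase type names work too.
    assert redshift_to_feast_value_type("int8") == ValueType.INT64
    assert redshift_to_feast_value_type("timestamp without time zone") == ValueType.UNIX_TIMESTAMP
    # Types absent from the map (e.g. GEOMETRY) raise KeyError rather than degrading silently.
]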
diff --git a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py
index 4b6dec828c..d732119ead 100644
--- a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py
+++ b/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
 # source: tensorflow_metadata/proto/v0/path.proto
-"""Generated protocol buffer code."""
+
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from google.protobuf import reflection as _reflection
diff --git a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py
index d3bfc50616..78fda8003d 100644
--- a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py
+++ b/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
 # source: tensorflow_metadata/proto/v0/schema.proto
-"""Generated protocol buffer code."""
+
 from google.protobuf.internal import enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
diff --git a/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py
index 21473adc75..d8e12bd120 100644
--- a/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py
+++ b/sdk/python/tensorflow_metadata/proto/v0/statistics_pb2.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
 # source: tensorflow_metadata/proto/v0/statistics.proto
-"""Generated protocol buffer code."""
+
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from google.protobuf import reflection as _reflection
diff --git a/sdk/python/tests/test_inference.py b/sdk/python/tests/test_inference.py
index 9405bce1f2..cff5f33f74 100644
--- a/sdk/python/tests/test_inference.py
+++ b/sdk/python/tests/test_inference.py
@@ -5,7 +5,7 @@
     simple_bq_source_using_table_ref_arg,
 )
 
-from feast import Entity, ValueType
+from feast import Entity, RepoConfig, ValueType
 from feast.errors import RegistryInferenceFailure
 from feast.feature_view import FeatureView
 from feast.inference import (
@@ -29,15 +29,21 @@ def test_update_entities_with_inferred_types_from_feature_views(
     actual_1 = Entity(name="id")
     actual_2 = Entity(name="id")
 
-    update_entities_with_inferred_types_from_feature_views([actual_1], [fv1])
-    update_entities_with_inferred_types_from_feature_views([actual_2], [fv2])
+    update_entities_with_inferred_types_from_feature_views(
+        [actual_1], [fv1], RepoConfig(provider="local", project="test")
+    )
+    update_entities_with_inferred_types_from_feature_views(
+        [actual_2], [fv2], RepoConfig(provider="local", project="test")
+    )
     assert actual_1 == Entity(name="id", value_type=ValueType.INT64)
     assert actual_2 == Entity(name="id", value_type=ValueType.STRING)
 
     with pytest.raises(RegistryInferenceFailure):
         # two viable data types
         update_entities_with_inferred_types_from_feature_views(
-            [Entity(name="id")], [fv1, fv2]
+            [Entity(name="id")],
+            [fv1, fv2],
+            RepoConfig(provider="local", project="test"),
         )
 
@@ -52,7 +58,9 @@ def test_update_data_sources_with_inferred_event_timestamp_col(simple_dataset_1)
         simple_bq_source_using_table_ref_arg(simple_dataset_1),
         simple_bq_source_using_query_arg(simple_dataset_1),
     ]
-    update_data_sources_with_inferred_event_timestamp_col(data_sources)
+    update_data_sources_with_inferred_event_timestamp_col(
+        data_sources, RepoConfig(provider="local", project="test")
+    )
     actual_event_timestamp_cols = [
         source.event_timestamp_column for source in data_sources
     ]
@@ -62,4 +70,6 @@ def test_update_data_sources_with_inferred_event_timestamp_col(simple_dataset_1)
     with prep_file_source(df=df_with_two_viable_timestamp_cols) as file_source:
         with pytest.raises(RegistryInferenceFailure):
             # two viable event_timestamp_columns
-            update_data_sources_with_inferred_event_timestamp_col([file_source])
+            update_data_sources_with_inferred_event_timestamp_col(
+                [file_source], RepoConfig(provider="local", project="test")
+            )
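[Editor's note: the describe_statement polling used in get_table_column_names_and_types generalizes to any Redshift Data API call. A self-contained sketch with placeholder cluster details, not part of the patch:

    import boto3
    from tenacity import retry, retry_unless_exception_type, wait_exponential

    from feast.errors import RedshiftQueryError

    client = boto3.client("redshift-data", region_name="us-west-2")  # placeholder region

    statement = client.execute_statement(
        ClusterIdentifier="my-feast-cluster",  # placeholder cluster
        Database="feast",
        DbUser="admin",
        Sql="SELECT 1",
    )

    # Poll with exponential backoff: 0.1s, 0.2s, 0.4s, ... capped at 30s.
    @retry(
        wait=wait_exponential(multiplier=0.1, max=30),
        retry=retry_unless_exception_type(RedshiftQueryError),
    )
    def wait_for_statement():
        desc = client.describe_statement(Id=statement["Id"])
        if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"):
            raise Exception  # still running; retry
        if desc["Status"] != "FINISHED":
            raise RedshiftQueryError(desc)  # FAILED or ABORTED; don't retry

    wait_for_statement()
    result = client.get_statement_result(Id=statement["Id"])
]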