Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enforce kw args in datasources #2567

Merged
merged 15 commits into from
Apr 19, 2022
203 changes: 168 additions & 35 deletions sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class DataSource(ABC):

def __init__(
self,
*,
kevjumba marked this conversation as resolved.
Show resolved Hide resolved
event_timestamp_column: Optional[str] = None,
created_timestamp_column: Optional[str] = None,
field_mapping: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -354,11 +355,12 @@ def get_table_column_names_and_types(

def __init__(
self,
name: str,
event_timestamp_column: str,
bootstrap_servers: str,
message_format: StreamFormat,
topic: str,
*args,
name: Optional[str] = None,
event_timestamp_column: Optional[str] = "",
adchia marked this conversation as resolved.
Show resolved Hide resolved
bootstrap_servers: Optional[str] = None,
message_format: Optional[StreamFormat] = None,
topic: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
Expand All @@ -368,22 +370,62 @@ def __init__(
timestamp_field: Optional[str] = "",
batch_source: Optional[DataSource] = None,
):
positional_attributes = [
"name",
"event_timestamp_column",
"bootstrap_servers",
"message_format",
"topic",
]
_name = name
_event_timestamp_column = event_timestamp_column
_bootstrap_servers = bootstrap_servers or ""
_message_format = message_format
_topic = topic or ""

if args:
warnings.warn(
(
"Kafka parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct Kafka sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"Kafka sources, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_event_timestamp_column = args[1]
if len(args) >= 3:
_bootstrap_servers = args[2]
if len(args) >= 4:
_message_format = args[3]
if len(args) >= 5:
_topic = args[4]

if _message_format is None:
raise ValueError("Message format must be specified for Kafka source")
print("Asdfasdf")
super().__init__(
event_timestamp_column=event_timestamp_column,
event_timestamp_column=_event_timestamp_column,
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping,
date_partition_column=date_partition_column,
description=description,
tags=tags,
owner=owner,
name=name,
name=_name,
timestamp_field=timestamp_field,
)
self.batch_source = batch_source
self.kafka_options = KafkaOptions(
bootstrap_servers=bootstrap_servers,
message_format=message_format,
topic=topic,
bootstrap_servers=_bootstrap_servers,
message_format=_message_format,
topic=_topic,
)

def __eq__(self, other):
Expand Down Expand Up @@ -472,32 +514,56 @@ class RequestSource(DataSource):

def __init__(
self,
name: str,
schema: Union[Dict[str, ValueType], List[Field]],
*args,
name: Optional[str] = None,
schema: Optional[Union[Dict[str, ValueType], List[Field]]] = None,
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
):
"""Creates a RequestSource object."""
super().__init__(name=name, description=description, tags=tags, owner=owner)
if isinstance(schema, Dict):
positional_attributes = ["name", "schema"]
_name = name
_schema = schema
if args:
warnings.warn(
(
"Request source parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct request sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"feature views, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_schema = args[1]

super().__init__(name=_name, description=description, tags=tags, owner=owner)
if not _schema:
raise ValueError("Schema needs to be provided for Request Source")
if isinstance(_schema, Dict):
warnings.warn(
"Schema in RequestSource is changing type. The schema data type Dict[str, ValueType] is being deprecated in Feast 0.23. "
"Please use List[Field] instead for the schema",
DeprecationWarning,
)
schemaList = []
for key, valueType in schema.items():
for key, valueType in _schema.items():
schemaList.append(
Field(name=key, dtype=VALUE_TYPES_TO_FEAST_TYPES[valueType])
)
self.schema = schemaList
elif isinstance(schema, List):
self.schema = schema
elif isinstance(_schema, List):
self.schema = _schema
else:
raise Exception(
"Schema type must be either dictionary or list, not "
+ str(type(schema))
+ str(type(_schema))
)

def validate(self, config: RepoConfig):
Expand Down Expand Up @@ -643,12 +709,13 @@ def get_table_query_string(self) -> str:

def __init__(
self,
name: str,
event_timestamp_column: str,
created_timestamp_column: str,
record_format: StreamFormat,
region: str,
stream_name: str,
*args,
name: Optional[str] = None,
event_timestamp_column: Optional[str] = "",
created_timestamp_column: Optional[str] = "",
record_format: Optional[StreamFormat] = None,
region: Optional[str] = "",
stream_name: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
description: Optional[str] = "",
Expand All @@ -657,10 +724,53 @@ def __init__(
timestamp_field: Optional[str] = "",
batch_source: Optional[DataSource] = None,
):
positional_attributes = [
"name",
"event_timestamp_column",
"created_timestamp_column",
"record_format",
"region",
"stream_name",
]
_name = name
_event_timestamp_column = event_timestamp_column
_created_timestamp_column = created_timestamp_column
_record_format = record_format
_region = region or ""
_stream_name = stream_name or ""
if args:
warnings.warn(
(
"Kinesis parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct kinesis sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"kinesis sources, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_event_timestamp_column = args[1]
if len(args) >= 3:
_created_timestamp_column = args[2]
if len(args) >= 4:
_record_format = args[3]
if len(args) >= 5:
_region = args[4]
if len(args) >= 6:
_stream_name = args[5]

if _record_format is None:
raise ValueError("Record format must be specified for kinesis source")

super().__init__(
name=name,
event_timestamp_column=event_timestamp_column,
created_timestamp_column=created_timestamp_column,
name=_name,
event_timestamp_column=_event_timestamp_column,
created_timestamp_column=_created_timestamp_column,
field_mapping=field_mapping,
date_partition_column=date_partition_column,
description=description,
Expand All @@ -670,7 +780,7 @@ def __init__(
)
self.batch_source = batch_source
self.kinesis_options = KinesisOptions(
record_format=record_format, region=region, stream_name=stream_name
record_format=_record_format, region=_region, stream_name=_stream_name
)

def __eq__(self, other):
Expand Down Expand Up @@ -725,9 +835,9 @@ class PushSource(DataSource):

def __init__(
self,
*,
name: str,
batch_source: DataSource,
*args,
name: Optional[str] = None,
batch_source: Optional[DataSource] = None,
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
Expand All @@ -744,10 +854,33 @@ def __init__(
maintainer.

"""
super().__init__(name=name, description=description, tags=tags, owner=owner)
self.batch_source = batch_source
if not self.batch_source:
raise ValueError(f"batch_source is needed for push source {self.name}")
positional_attributes = ["name", "batch_source"]
_name = name
_batch_source = batch_source
if args:
warnings.warn(
(
"Push source parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct push sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"push sources, for backwards compatibility."
)
if len(args) >= 1:
_name = args[0]
if len(args) >= 2:
_batch_source = args[1]

super().__init__(name=_name, description=description, tags=tags, owner=owner)
if not _batch_source:
raise ValueError(
f"batch_source parameter is needed for push source {self.name}"
)
self.batch_source = _batch_source

def __eq__(self, other):
if not isinstance(other, PushSource):
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
ValueError: A field mapping conflicts with an Entity or a Feature.
"""

positional_attributes = ["name, entities, ttl"]
positional_attributes = ["name", "entities", "ttl"]
kevjumba marked this conversation as resolved.
Show resolved Hide resolved

_name = name
_entities = entities
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class BigQuerySource(DataSource):
def __init__(
self,
*,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SparkSourceFormat(Enum):
class SparkSource(DataSource):
def __init__(
self,
*,
name: Optional[str] = None,
table: Optional[str] = None,
query: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def __init__(
table: Optional[str] = None,
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = None,
adchia marked this conversation as resolved.
Show resolved Hide resolved
query: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = "",
Expand Down
25 changes: 22 additions & 3 deletions sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
class FileSource(DataSource):
def __init__(
self,
path: str,
*args,
path: Optional[str] = None,
event_timestamp_column: Optional[str] = "",
file_format: Optional[FileFormat] = None,
created_timestamp_column: Optional[str] = "",
Expand Down Expand Up @@ -58,13 +59,31 @@ def __init__(
>>> from feast import FileSource
>>> file_source = FileSource(path="my_features.parquet", timestamp_field="event_timestamp")
"""
if path is None:
positional_attributes = ["path"]
_path = path
if args:
if args:
warnings.warn(
(
"File Source parameters should be specified as a keyword argument instead of a positional arg."
"Feast 0.23+ will not support positional arguments to construct File sources"
),
DeprecationWarning,
)
if len(args) > len(positional_attributes):
raise ValueError(
f"Only {', '.join(positional_attributes)} are allowed as positional args when defining "
f"File sources, for backwards compatibility."
)
if len(args) >= 1:
_path = args[0]
if _path is None:
raise ValueError(
'No "path" argument provided. Please set "path" to the location of your file source.'
)
self.file_options = FileOptions(
file_format=file_format,
uri=path,
uri=_path,
s3_endpoint_override=s3_endpoint_override,
)

Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class RedshiftSource(DataSource):
def __init__(
self,
*,
event_timestamp_column: Optional[str] = "",
table: Optional[str] = None,
schema: Optional[str] = None,
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
class SnowflakeSource(DataSource):
def __init__(
self,
*,
database: Optional[str] = None,
warehouse: Optional[str] = None,
schema: Optional[str] = None,
Expand Down
Loading