Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Environment class and DataSourceCreator API, and use fixtures for datasets and data sources #1790

Merged
merged 53 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
789d7a1
Fix API cruft from DataSourceCreator
achals Aug 18, 2021
8b7eab5
Remove the need for get_prefixed_table_name
achals Aug 18, 2021
0cc21ab
major refactor
achals Aug 19, 2021
6ddd5b0
move start time
achals Aug 19, 2021
a700745
Remove one dimension of variation to be added in later
achals Aug 19, 2021
c25b6cb
Fix default
achals Aug 19, 2021
a3e4473
Fixups
achals Aug 19, 2021
3f632fd
Fixups
achals Aug 19, 2021
9aeea86
Fix up tests
achals Aug 19, 2021
84cbefd
Add retries to execute_redshift_statement_async
achals Aug 19, 2021
fe180eb
Add retries to execute_redshift_statement_async
achals Aug 19, 2021
0cd08d2
refactoooor
achals Aug 19, 2021
4701183
remove retries
achals Aug 19, 2021
4cde284
Remove provider variation since they don't really play a big role
achals Aug 20, 2021
974bb0b
Session scoped cache for test datasets and skipping older tests whose…
achals Aug 23, 2021
dbbd6fb
make format
achals Aug 23, 2021
473b630
make format
achals Aug 23, 2021
95982e2
remove import
achals Aug 23, 2021
fc1b4ee
merge from master
achals Aug 25, 2021
df9596b
fix merge
achals Aug 25, 2021
fc46edb
Use an enum for the stopping procedure instead of the bools
achals Aug 25, 2021
0daaff0
Fix refs
achals Aug 26, 2021
ef6a6b3
fix step
achals Aug 26, 2021
d63691d
WIP fixes
achals Aug 26, 2021
3d75977
Fix for feature inferencing
achals Aug 26, 2021
1abbaa8
C901 '_python_value_to_proto_value' is too complex :(
achals Aug 26, 2021
21996b7
Split out construct_test_repo and construct_universal_test_repo
achals Aug 27, 2021
4541561
remove import
achals Aug 27, 2021
2f20b5c
add unsafe_hash
achals Aug 27, 2021
6cc32b5
Update testrepoconfig
achals Aug 27, 2021
68f2997
Update testrepoconfig
achals Aug 27, 2021
a73868d
Remove kwargs from construct_universal_test_environment
achals Aug 27, 2021
ff9d49d
Remove unneeded method
achals Aug 27, 2021
b95c3ee
Docs
achals Aug 27, 2021
3e5aada
Kill skipped tests
achals Aug 27, 2021
b6abcaa
reorder
achals Aug 27, 2021
9580654
add todo
achals Aug 27, 2021
171a1ef
Split universal vs non data_source_cache
achals Aug 27, 2021
de9d8aa
make format
achals Aug 27, 2021
6454775
WIP fixtures
achals Aug 31, 2021
3fbb2e4
WIP Trying fixtures more effectively
achals Sep 1, 2021
2961d58
fix refs
achals Sep 1, 2021
1739815
Fix refs
achals Sep 1, 2021
fe14ba3
Fix refs
achals Sep 1, 2021
43d8e27
Fix refs
achals Sep 1, 2021
58c75b0
fix historical tests
achals Sep 1, 2021
8b5ba5a
renames
achals Sep 1, 2021
c445f53
CR updates
achals Sep 1, 2021
e0ccd9b
use the actual ref to data source creators
achals Sep 1, 2021
4d4efad
merge from master
achals Sep 1, 2021
a510ca6
format
achals Sep 1, 2021
6c10502
unused imports'
achals Sep 1, 2021
ff1fb81
Add ids for pytest params
achals Sep 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions sdk/python/feast/infra/utils/aws_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tenacity import retry, retry_if_exception_type, wait_exponential
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from feast.errors import RedshiftCredentialsError, RedshiftQueryError
from feast.type_map import pa_to_redshift_value_type

try:
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, ConnectionClosedError
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError

Expand Down Expand Up @@ -50,6 +55,11 @@ def get_bucket_and_key(s3_path: str) -> Tuple[str, str]:
return bucket, key


@retry(
achals marked this conversation as resolved.
Show resolved Hide resolved
wait=wait_exponential(multiplier=1, max=30),
retry=retry_if_exception_type(ConnectionClosedError),
stop=stop_after_attempt(3),
)
def execute_redshift_statement_async(
redshift_data_client, cluster_id: str, database: str, user: str, query: str
) -> dict:
Expand Down Expand Up @@ -82,7 +92,7 @@ class RedshiftStatementNotFinishedError(Exception):


@retry(
wait=wait_exponential(multiplier=0.1, max=30),
wait=wait_exponential(multiplier=1, max=30),
retry=retry_if_exception_type(RedshiftStatementNotFinishedError),
)
def wait_for_redshift_statement(redshift_data_client, statement: dict) -> None:
Expand Down
198 changes: 110 additions & 88 deletions sdk/python/tests/integration/feature_repos/test_repo_configuration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import tempfile
achals marked this conversation as resolved.
Show resolved Hide resolved
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, replace
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import pandas as pd
import pytest

from feast import FeatureStore, FeatureView, RepoConfig, driver_test_data, importer
Expand Down Expand Up @@ -71,9 +72,63 @@ def ds_creator_path(cls: str):
]


OFFLINE_STORES: List[str] = []
ONLINE_STORES: List[str] = []
PROVIDERS: List[str] = []
def construct_universal_entities() -> Dict[str, List[Any]]:
    """Build the standard entity-key sets used by the universal test suite.

    Returns:
        A mapping from entity name to its list of ids: 109 customer ids
        (1001-1109) and 109 driver ids (5001-5109).
    """
    customer_ids = [cid for cid in range(1001, 1110)]
    driver_ids = [did for did in range(5001, 5110)]
    return {"customer": customer_ids, "driver": driver_ids}


def construct_universal_datasets(
    entities: Dict[str, List[Any]], start_time: datetime, end_time: datetime
) -> Dict[str, pd.DataFrame]:
    """Generate the universal test dataframes for the given entities and window.

    Args:
        entities: Mapping with "customer" and "driver" id lists (see
            construct_universal_entities).
        start_time: Start of the feature-data time window.
        end_time: End of the feature-data time window.

    Returns:
        Dataframes keyed by "customer", "driver", and "orders".
    """
    frames: Dict[str, pd.DataFrame] = {}
    frames["customer"] = driver_test_data.create_customer_daily_profile_df(
        entities["customer"], start_time, end_time
    )
    frames["driver"] = driver_test_data.create_driver_hourly_stats_df(
        entities["driver"], start_time, end_time
    )
    # Orders deliberately span a full year on either side of end_time so that
    # point-in-time joins see rows both before and after the feature window.
    one_year = timedelta(days=365)
    frames["orders"] = driver_test_data.create_orders_df(
        customers=entities["customer"],
        drivers=entities["driver"],
        start_date=end_time - one_year,
        end_date=end_time + one_year,
        order_count=1000,
    )
    return frames


def construct_universal_data_sources(
    datasets: Dict[str, pd.DataFrame], data_source_creator: DataSourceCreator
) -> Dict[str, DataSource]:
    """Register one data source per universal dataset via the supplied creator.

    Args:
        datasets: Dataframes keyed by "customer", "driver", and "orders"
            (see construct_universal_datasets).
        data_source_creator: Offline-store-specific creator used to persist
            each dataframe and wrap it in a DataSource.

    Returns:
        Data sources keyed the same way as the input datasets.
    """
    # Destination-table suffix for each dataset; every source shares the same
    # timestamp column names.
    suffixes = {
        "customer": "customer_profile",
        "driver": "driver_hourly",
        "orders": "orders",
    }
    return {
        key: data_source_creator.create_data_sources(
            datasets[key],
            suffix=table_suffix,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
        for key, table_suffix in suffixes.items()
    }


def construct_universal_feature_views(
    data_sources: Dict[str, DataSource]
) -> Dict[str, FeatureView]:
    """Wrap the universal data sources in their matching feature views.

    Only the "customer" and "driver" sources get feature views; the "orders"
    source is used directly as the entity dataframe source in tests.
    """
    customer_fv = create_customer_daily_profile_feature_view(
        data_sources["customer"]
    )
    driver_fv = create_driver_hourly_stats_feature_view(data_sources["driver"])
    return {"customer": customer_fv, "driver": driver_fv}


@dataclass
Expand All @@ -84,78 +139,47 @@ class Environment:
data_source: DataSource
data_source_creator: DataSourceCreator

end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=7)
before_start_date = end_date - timedelta(days=365)
after_end_date = end_date + timedelta(days=365)

customer_entities = list(range(1001, 1110))
customer_df = driver_test_data.create_customer_daily_profile_df(
customer_entities, start_date, end_date
entities_creator: Callable[[], Dict[str, List[Any]]] = field(
default=construct_universal_entities
)
_customer_feature_view: Optional[FeatureView] = None

driver_entities = list(range(5001, 5110))
driver_df = driver_test_data.create_driver_hourly_stats_df(
driver_entities, start_date, end_date
datasets_creator: Callable[
[Dict[str, List[Any]], datetime, datetime], Dict[str, pd.DataFrame]
] = field(default=construct_universal_datasets)
datasources_creator: Callable[
[Dict[str, pd.DataFrame], DataSourceCreator], Dict[str, DataSource]
] = field(default=construct_universal_data_sources)
feature_views_creator: Callable[
[Dict[str, DataSource]], Dict[str, FeatureView]
] = field(default=construct_universal_feature_views)

entities: Dict[str, List[Any]] = field(default_factory=dict)
datasets: Dict[str, pd.DataFrame] = field(default_factory=dict)
datasources: Dict[str, DataSource] = field(default_factory=dict)
feature_views: Dict[str, FeatureView] = field(default_factory=list)

end_date: datetime = field(
default=datetime.now().replace(microsecond=0, second=0, minute=0)
)
_driver_stats_feature_view: Optional[FeatureView] = None

orders_df = driver_test_data.create_orders_df(
customers=customer_entities,
drivers=driver_entities,
start_date=before_start_date,
end_date=after_end_date,
order_count=1000,
)
_orders_table: Optional[str] = None

def customer_feature_view(self) -> FeatureView:
if self._customer_feature_view is None:
customer_table_id = self.data_source_creator.get_prefixed_table_name(
self.name, "customer_profile"
)
ds = self.data_source_creator.create_data_sources(
customer_table_id,
self.customer_df,
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
self._customer_feature_view = create_customer_daily_profile_feature_view(ds)
return self._customer_feature_view

def driver_stats_feature_view(self) -> FeatureView:
if self._driver_stats_feature_view is None:
driver_table_id = self.data_source_creator.get_prefixed_table_name(
self.name, "driver_hourly"
)
ds = self.data_source_creator.create_data_sources(
driver_table_id,
self.driver_df,
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
self._driver_stats_feature_view = create_driver_hourly_stats_feature_view(
ds
)
return self._driver_stats_feature_view

def orders_table(self) -> Optional[str]:
if self._orders_table is None:
orders_table_id = self.data_source_creator.get_prefixed_table_name(
self.name, "orders"
)
ds = self.data_source_creator.create_data_sources(
orders_table_id,
self.orders_df,
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
if hasattr(ds, "table_ref"):
self._orders_table = ds.table_ref
elif hasattr(ds, "table"):
self._orders_table = ds.table
return self._orders_table
def __post_init__(self):
self.start_date: datetime = self.end_date - timedelta(days=7)

self.entities = self.entities_creator()
self.datasets = self.datasets_creator(
self.entities, self.start_date, self.end_date
)
self.datasources = self.datasources_creator(
self.datasets, self.data_source_creator
)
self.feature_views = self.feature_views_creator(self.datasources)


def table_name_from_data_source(ds: DataSource) -> Optional[str]:
    """Best-effort extraction of the destination table name from a data source.

    BigQuery sources expose ``table_ref`` while Redshift sources expose
    ``table``; ``table_ref`` wins when both exist. Returns None when the
    source has neither attribute (e.g. file-based sources).
    """
    for attr in ("table_ref", "table"):
        if hasattr(ds, attr):
            return getattr(ds, attr)
    return None


def vary_full_feature_names(configs: List[TestRepoConfig]) -> List[TestRepoConfig]:
Expand Down Expand Up @@ -195,7 +219,7 @@ def vary_providers_for_offline_stores(


@contextmanager
def construct_test_environment(
def construct_universal_test_environment(
test_repo_config: TestRepoConfig,
create_and_apply: bool = False,
materialize: bool = False,
Expand All @@ -210,7 +234,6 @@ def construct_test_environment(
:param test_repo_config: configuration
:return: A feature store built using the supplied configuration.
"""
df = create_dataset()

project = f"test_correctness_{str(uuid.uuid4()).replace('-', '')[:8]}"

Expand All @@ -221,9 +244,13 @@ def construct_test_environment(
offline_creator: DataSourceCreator = importer.get_class_from_type(
module_name, config_class_name, "DataSourceCreator"
)(project)

# This needs to be abstracted away for test_e2e_universal which uses a different dataset.
df = create_dataset()
ds = offline_creator.create_data_sources(
project, df, field_mapping={"ts_1": "ts", "id": "driver_id"}
df, destination=project, field_mapping={"ts_1": "ts", "id": "driver_id"}
achals marked this conversation as resolved.
Show resolved Hide resolved
)

offline_store = offline_creator.create_offline_store_config()
online_store = test_repo_config.online_store

Expand All @@ -250,12 +277,7 @@ def construct_test_environment(
try:
if create_and_apply:
entities.extend([driver(), customer()])
fvs.extend(
[
environment.driver_stats_feature_view(),
environment.customer_feature_view(),
]
)
fvs.extend(environment.feature_views.values())
fs.apply(fvs + entities)

if materialize:
Expand Down Expand Up @@ -283,7 +305,7 @@ def parametrize_e2e_test(e2e_test):
@pytest.mark.integration
@pytest.mark.parametrize("config", FULL_REPO_CONFIGS, ids=lambda v: str(v))
def inner_test(config):
with construct_test_environment(config) as environment:
with construct_universal_test_environment(config) as environment:
e2e_test(environment)

return inner_test
Expand All @@ -305,12 +327,13 @@ def parametrize_offline_retrieval_test(offline_retrieval_test):

configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
configs = vary_full_feature_names(configs)
configs = vary_infer_event_timestamp_col(configs)

@pytest.mark.integration
@pytest.mark.parametrize("config", configs, ids=lambda v: str(v))
def inner_test(config):
with construct_test_environment(config, create_and_apply=True) as environment:
with construct_universal_test_environment(
config, create_and_apply=True
) as environment:
offline_retrieval_test(environment)

return inner_test
Expand All @@ -330,12 +353,11 @@ def parametrize_online_test(online_test):

configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
configs = vary_full_feature_names(configs)
configs = vary_infer_event_timestamp_col(configs)

@pytest.mark.integration
@pytest.mark.parametrize("config", configs, ids=lambda v: str(v))
def inner_test(config):
with construct_test_environment(
with construct_universal_test_environment(
config, create_and_apply=True, materialize=True
) as environment:
online_test(environment)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict
from typing import Dict, Optional

import pandas as pd

Expand All @@ -11,8 +11,9 @@ class DataSourceCreator(ABC):
@abstractmethod
def create_data_sources(
self,
destination: str,
df: pd.DataFrame,
destination: Optional[str] = None,
suffix: Optional[str] = None,
event_timestamp_column="ts",
created_timestamp_column="created_ts",
field_mapping: Dict[str, str] = None,
Expand All @@ -28,5 +29,5 @@ def teardown(self):
...

@abstractmethod
def get_prefixed_table_name(self, name: str, suffix: str) -> str:
def get_prefixed_table_name(self, table_name: str) -> str:
achals marked this conversation as resolved.
Show resolved Hide resolved
...
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Optional

import pandas as pd
from google.cloud import bigquery
Expand Down Expand Up @@ -42,14 +42,19 @@ def create_offline_store_config(self):

def create_data_sources(
self,
destination: str,
df: pd.DataFrame,
destination: Optional[str] = None,
suffix: Optional[str] = None,
event_timestamp_column="ts",
created_timestamp_column="created_ts",
field_mapping: Dict[str, str] = None,
**kwargs,
) -> DataSource:

assert destination or suffix
achals marked this conversation as resolved.
Show resolved Hide resolved
if not destination:
destination = self.get_prefixed_table_name(suffix)

job_config = bigquery.LoadJobConfig()
if self.gcp_project not in destination:
destination = f"{self.gcp_project}.{self.project_name}.{destination}"
Expand All @@ -69,5 +74,5 @@ def create_data_sources(
field_mapping=field_mapping or {"ts_1": "ts"},
)

def get_prefixed_table_name(self, name: str, suffix: str) -> str:
return f"{self.client.project}.{name}.{suffix}"
def get_prefixed_table_name(self, suffix: str) -> str:
return f"{self.client.project}.{self.project_name}.{suffix}"
Loading