From 235ede5bc93804a259a2b556bce41fbe980db865 Mon Sep 17 00:00:00 2001
From: tokoko
Date: Sat, 9 Mar 2024 13:53:08 +0000
Subject: [PATCH] fix ci unittest warnings

Signed-off-by: tokoko
---
 sdk/python/feast/driver_test_data.py          |  4 ++--
 sdk/python/feast/infra/offline_stores/file.py | 19 ++++++++++++++-----
 .../feast/infra/offline_stores/file_source.py | 10 +++++++++-
 .../feast/on_demand_pandas_transformation.py  |  3 ---
 sdk/python/feast/utils.py                     | 13 -------------
 .../infra/offline_stores/test_redshift.py     |  1 +
 .../infra/offline_stores/test_snowflake.py    |  1 +
 .../test_local_feature_store.py               |  4 ++++
 .../tests/unit/test_on_demand_feature_view.py |  2 ++
 sdk/python/tests/unit/test_sql_registry.py    | 11 +++++++++--
 sdk/python/tests/utils/test_wrappers.py       |  8 ++++----
 11 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/sdk/python/feast/driver_test_data.py b/sdk/python/feast/driver_test_data.py
index 58c3e8db8f..7959046e6e 100644
--- a/sdk/python/feast/driver_test_data.py
+++ b/sdk/python/feast/driver_test_data.py
@@ -103,7 +103,7 @@ def create_driver_hourly_stats_df(drivers, start_date, end_date) -> pd.DataFrame
             "event_timestamp": [
                 pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
                 for dt in pd.date_range(
-                    start=start_date, end=end_date, freq="1H", inclusive="left"
+                    start=start_date, end=end_date, freq="1h", inclusive="left"
                 )
             ]
             # include a fixed timestamp for get_historical_features in the quickstart
@@ -209,7 +209,7 @@ def create_location_stats_df(locations, start_date, end_date) -> pd.DataFrame:
             "event_timestamp": [
                 pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
                 for dt in pd.date_range(
-                    start=start_date, end=end_date, freq="1H", inclusive="left"
+                    start=start_date, end=end_date, freq="1h", inclusive="left"
                 )
             ]
         }
diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py
index 0b873a2091..d922e98c14 100644
--- a/sdk/python/feast/infra/offline_stores/file.py
+++ b/sdk/python/feast/infra/offline_stores/file.py
@@ -2,7 +2,7 @@
 import uuid
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import dask
 import dask.dataframe as dd
@@ -38,10 +38,7 @@
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
 from feast.usage import log_exceptions_and_usage
-from feast.utils import (
-    _get_requested_feature_views_to_features_dict,
-    _run_dask_field_mapping,
-)
+from feast.utils import _get_requested_feature_views_to_features_dict
 
 # FileRetrievalJob will cast string objects to string[pyarrow] from dask version 2023.7.1
 # This is not the desired behavior for our use case, so we set the convert-string option to False
@@ -512,6 +509,18 @@ def _read_datasource(data_source) -> dd.DataFrame:
     )
 
 
+def _run_dask_field_mapping(
+    table: dd.DataFrame,
+    field_mapping: Dict[str, str],
+):
+    if field_mapping:
+        # run field mapping in the forward direction
+        table = table.rename(columns=field_mapping)
+        table = table.persist()
+
+    return table
+
+
 def _field_mapping(
     df_to_join: dd.DataFrame,
     feature_view: FeatureView,
diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py
index 887b410079..2672cf78bf 100644
--- a/sdk/python/feast/infra/offline_stores/file_source.py
+++ b/sdk/python/feast/infra/offline_stores/file_source.py
@@ -1,5 +1,7 @@
 from typing import Callable, Dict, Iterable, List, Optional, Tuple
 
+import pyarrow
+from packaging import version
 from pyarrow._fs import FileSystem
 from pyarrow._s3fs import S3FileSystem
 from pyarrow.parquet import ParquetDataset
@@ -158,7 +160,13 @@ def get_table_column_names_and_types(
         # Adding support for different file format path
         # based on S3 filesystem
         if filesystem is None:
-            schema = ParquetDataset(path, use_legacy_dataset=False).schema
+            kwargs = (
+                {"use_legacy_dataset": False}
+                if version.parse(pyarrow.__version__) < version.parse("15.0.0")
+                else {}
+            )
+
+            schema = ParquetDataset(path, **kwargs).schema
         if hasattr(schema, "names") and hasattr(schema, "types"):
             # Newer versions of pyarrow doesn't have this method,
             # but this field is good enough.
diff --git a/sdk/python/feast/on_demand_pandas_transformation.py b/sdk/python/feast/on_demand_pandas_transformation.py
index 52d45893c5..32cb44b429 100644
--- a/sdk/python/feast/on_demand_pandas_transformation.py
+++ b/sdk/python/feast/on_demand_pandas_transformation.py
@@ -30,9 +30,6 @@ def __eq__(self, other):
                 "Comparisons should only involve OnDemandPandasTransformation class objects."
             )
 
-        if not super().__eq__(other):
-            return False
-
         if (
             self.udf_string != other.udf_string
             or self.udf.__code__.co_code != other.udf.__code__.co_code
diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py
index 50b1e73c86..70fbda964d 100644
--- a/sdk/python/feast/utils.py
+++ b/sdk/python/feast/utils.py
@@ -7,7 +7,6 @@
 
 import pandas as pd
 import pyarrow
-from dask import dataframe as dd
 from dateutil.tz import tzlocal
 from pytz import utc
 
@@ -174,18 +173,6 @@ def _run_pyarrow_field_mapping(
     return table
 
 
-def _run_dask_field_mapping(
-    table: dd.DataFrame,
-    field_mapping: Dict[str, str],
-):
-    if field_mapping:
-        # run field mapping in the forward direction
-        table = table.rename(columns=field_mapping)
-        table = table.persist()
-
-    return table
-
-
 def _coerce_datetime(ts):
     """
     Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
diff --git a/sdk/python/tests/unit/infra/offline_stores/test_redshift.py b/sdk/python/tests/unit/infra/offline_stores/test_redshift.py
index 48ee99e89f..a9ed4c2b59 100644
--- a/sdk/python/tests/unit/infra/offline_stores/test_redshift.py
+++ b/sdk/python/tests/unit/infra/offline_stores/test_redshift.py
@@ -33,6 +33,7 @@ def test_offline_write_batch(
             s3_staging_location="s3://bucket/path",
             workgroup="",
         ),
+        entity_key_serialization_version=2,
     )
 
     batch_source = RedshiftSource(
diff --git a/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py b/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py
index ac55f123bb..6e27cba341 100644
--- a/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py
+++ b/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py
@@ -38,6 +38,7 @@ def retrieval_job(request):
             provider="snowflake.offline",
             online_store=SqliteOnlineStoreConfig(type="sqlite"),
             offline_store=offline_store_config,
+            entity_key_serialization_version=2,
         ),
         full_feature_names=True,
         on_demand_feature_views=[],
diff --git a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py
index 2cced75eb2..b3e6762c17 100644
--- a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py
+++ b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py
@@ -130,6 +130,7 @@ def test_apply_feature_view_with_inline_batch_source(
     driver_fv = FeatureView(
         name="driver_fv",
         entities=[entity],
+        schema=[Field(name="test_key", dtype=Int64)],
         source=file_source,
     )
 
@@ -178,6 +179,7 @@ def test_apply_feature_view_with_inline_stream_source(
     driver_fv = FeatureView(
         name="driver_fv",
         entities=[entity],
+        schema=[Field(name="test_key", dtype=Int64)],
         source=stream_source,
     )
 
@@ -332,6 +334,7 @@ def test_apply_conflicting_feature_view_names(feature_store_with_local_registry)
     driver_stats = FeatureView(
         name="driver_hourly_stats",
         entities=[driver],
+        schema=[Field(name="driver_id", dtype=Int64)],
         ttl=timedelta(seconds=10),
         online=False,
         source=FileSource(path="driver_stats.parquet"),
@@ -341,6 +344,7 @@ def test_apply_conflicting_feature_view_names(feature_store_with_local_registry)
     customer_stats = FeatureView(
         name="DRIVER_HOURLY_STATS",
         entities=[customer],
+        schema=[Field(name="customer_id", dtype=Int64)],
         ttl=timedelta(seconds=10),
         online=False,
         source=FileSource(path="customer_stats.parquet"),
diff --git a/sdk/python/tests/unit/test_on_demand_feature_view.py b/sdk/python/tests/unit/test_on_demand_feature_view.py
index 721026ea46..66d02c65d1 100644
--- a/sdk/python/tests/unit/test_on_demand_feature_view.py
+++ b/sdk/python/tests/unit/test_on_demand_feature_view.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import pandas as pd
+import pytest
 
 from feast.feature_view import FeatureView
 from feast.field import Field
@@ -38,6 +39,7 @@ def udf2(features_df: pd.DataFrame) -> pd.DataFrame:
     return df
 
 
+@pytest.mark.filterwarnings("ignore:udf and udf_string parameters are deprecated")
 def test_hash():
     file_source = FileSource(name="my-file-source", path="test.parquet")
     feature_view = FeatureView(
diff --git a/sdk/python/tests/unit/test_sql_registry.py b/sdk/python/tests/unit/test_sql_registry.py
index 4ca41423c1..094b8967c1 100644
--- a/sdk/python/tests/unit/test_sql_registry.py
+++ b/sdk/python/tests/unit/test_sql_registry.py
@@ -93,7 +93,7 @@ def mysql_registry():
     container.start()
 
     # The log string uses '8.0.*' since the version might be changed as new Docker images are pushed.
-    log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '(\d+(\.\d+){1,2})' socket: '/var/run/mysqld/mysqld.sock' port: 3306"  # noqa: W605
+    log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '(\\d+(\\.\\d+){1,2})' socket: '/var/run/mysqld/mysqld.sock' port: 3306"  # noqa: W605
     waited = wait_for_logs(
         container=container,
         predicate=log_string_to_wait_for,
@@ -218,6 +218,7 @@ def test_apply_feature_view_success(sql_registry):
     fv1 = FeatureView(
         name="my_feature_view_1",
         schema=[
+            Field(name="test", dtype=Int64),
             Field(name="fs1_my_feature_1", dtype=Int64),
             Field(name="fs1_my_feature_2", dtype=String),
             Field(name="fs1_my_feature_3", dtype=Array(String)),
@@ -313,6 +314,7 @@ def test_apply_on_demand_feature_view_success(sql_registry):
         entities=[driver()],
         ttl=timedelta(seconds=8640000000),
         schema=[
+            Field(name="driver_id", dtype=Int64),
             Field(name="daily_miles_driven", dtype=Float32),
             Field(name="lat", dtype=Float32),
             Field(name="lon", dtype=Float32),
@@ -403,7 +405,10 @@ def test_modify_feature_views_success(sql_registry):
 
     fv1 = FeatureView(
         name="my_feature_view_1",
-        schema=[Field(name="fs1_my_feature_1", dtype=Int64)],
+        schema=[
+            Field(name="test", dtype=Int64),
+            Field(name="fs1_my_feature_1", dtype=Int64),
+        ],
         entities=[entity],
         tags={"team": "matchmaking"},
         source=batch_source,
@@ -527,6 +532,7 @@ def test_apply_data_source(sql_registry):
     fv1 = FeatureView(
         name="my_feature_view_1",
         schema=[
+            Field(name="test", dtype=Int64),
             Field(name="fs1_my_feature_1", dtype=Int64),
             Field(name="fs1_my_feature_2", dtype=String),
             Field(name="fs1_my_feature_3", dtype=Array(String)),
@@ -596,6 +602,7 @@ def test_registry_cache(sql_registry):
     fv1 = FeatureView(
         name="my_feature_view_1",
         schema=[
+            Field(name="test", dtype=Int64),
             Field(name="fs1_my_feature_1", dtype=Int64),
             Field(name="fs1_my_feature_2", dtype=String),
             Field(name="fs1_my_feature_3", dtype=Array(String)),
diff --git a/sdk/python/tests/utils/test_wrappers.py b/sdk/python/tests/utils/test_wrappers.py
index efee675790..eb5e3ef3f1 100644
--- a/sdk/python/tests/utils/test_wrappers.py
+++ b/sdk/python/tests/utils/test_wrappers.py
@@ -1,14 +1,14 @@
-import pytest
+import warnings
 
 
 def no_warnings(func):
     def wrapper_no_warnings(*args, **kwargs):
-        with pytest.warns(None) as warnings:
+        with warnings.catch_warnings(record=True) as record:
             func(*args, **kwargs)
 
-        if len(warnings) > 0:
+        if len(record) > 0:
             raise AssertionError(
-                "Warnings were raised: " + ", ".join([str(w) for w in warnings])
+                "Warnings were raised: " + ", ".join([str(w) for w in record])
             )
 
     return wrapper_no_warnings
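
Reviewer note (not part of the patch): pytest.warns(None) was deprecated in pytest 7.0 and later removed, which is why no_warnings now records warnings with the standard library instead. One caveat worth knowing: unlike pytest.warns, warnings.catch_warnings(record=True) does not reset the warning filters by itself, so a warning already suppressed by the default "once per location" filters may never reach the record list. Below is a minimal self-contained sketch of the pattern, assuming it is acceptable to force-capture everything; the simplefilter("always") call and the two demo functions are illustrative additions, not code from this patch:

import warnings


def no_warnings(func):
    # Same shape as the patched decorator in test_wrappers.py.
    def wrapper_no_warnings(*args, **kwargs):
        with warnings.catch_warnings(record=True) as record:
            # Not in the patch: disable the default filters so repeated
            # or previously suppressed warnings are recorded too.
            warnings.simplefilter("always")
            func(*args, **kwargs)

        if len(record) > 0:
            raise AssertionError(
                "Warnings were raised: " + ", ".join([str(w) for w in record])
            )

    return wrapper_no_warnings


@no_warnings
def quiet():
    pass  # nothing is recorded, so the wrapper returns normally


@no_warnings
def noisy():
    # One DeprecationWarning is recorded, so the wrapper raises AssertionError.
    warnings.warn("udf is deprecated", DeprecationWarning)

The @pytest.mark.filterwarnings("ignore:...") marker added in test_on_demand_feature_view.py is the complementary tool: it silences one known deprecation warning for a single test rather than hiding warnings globally.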