Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: CI unittest warnings #4006

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sdk/python/feast/driver_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def create_driver_hourly_stats_df(drivers, start_date, end_date) -> pd.DataFrame
"event_timestamp": [
pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
for dt in pd.date_range(
start=start_date, end=end_date, freq="1H", inclusive="left"
start=start_date, end=end_date, freq="1h", inclusive="left"
)
]
# include a fixed timestamp for get_historical_features in the quickstart
Expand Down Expand Up @@ -209,7 +209,7 @@ def create_location_stats_df(locations, start_date, end_date) -> pd.DataFrame:
"event_timestamp": [
pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
for dt in pd.date_range(
start=start_date, end=end_date, freq="1H", inclusive="left"
start=start_date, end=end_date, freq="1h", inclusive="left"
)
]
}
Expand Down
19 changes: 14 additions & 5 deletions sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import dask
import dask.dataframe as dd
Expand Down Expand Up @@ -38,10 +38,7 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.usage import log_exceptions_and_usage
from feast.utils import (
_get_requested_feature_views_to_features_dict,
_run_dask_field_mapping,
)
from feast.utils import _get_requested_feature_views_to_features_dict

# FileRetrievalJob will cast string objects to string[pyarrow] from dask version 2023.7.1
# This is not the desired behavior for our use case, so we set the convert-string option to False
Expand Down Expand Up @@ -512,6 +509,18 @@ def _read_datasource(data_source) -> dd.DataFrame:
)


def _run_dask_field_mapping(
table: dd.DataFrame,
field_mapping: Dict[str, str],
):
if field_mapping:
# run field mapping in the forward direction
table = table.rename(columns=field_mapping)
table = table.persist()

return table


def _field_mapping(
df_to_join: dd.DataFrame,
feature_view: FeatureView,
Expand Down
10 changes: 9 additions & 1 deletion sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import pyarrow
from packaging import version
from pyarrow._fs import FileSystem
from pyarrow._s3fs import S3FileSystem
from pyarrow.parquet import ParquetDataset
Expand Down Expand Up @@ -158,7 +160,13 @@ def get_table_column_names_and_types(
# Adding support for different file format path
# based on S3 filesystem
if filesystem is None:
schema = ParquetDataset(path, use_legacy_dataset=False).schema
kwargs = (
{"use_legacy_dataset": False}
if version.parse(pyarrow.__version__) < version.parse("15.0.0")
else {}
)

schema = ParquetDataset(path, **kwargs).schema
if hasattr(schema, "names") and hasattr(schema, "types"):
# Newer versions of pyarrow doesn't have this method,
# but this field is good enough.
Expand Down
3 changes: 0 additions & 3 deletions sdk/python/feast/on_demand_pandas_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def __eq__(self, other):
"Comparisons should only involve OnDemandPandasTransformation class objects."
)

if not super().__eq__(other):
return False

if (
self.udf_string != other.udf_string
or self.udf.__code__.co_code != other.udf.__code__.co_code
Expand Down
13 changes: 0 additions & 13 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pandas as pd
import pyarrow
from dask import dataframe as dd
from dateutil.tz import tzlocal
from pytz import utc

Expand Down Expand Up @@ -174,18 +173,6 @@ def _run_pyarrow_field_mapping(
return table


def _run_dask_field_mapping(
table: dd.DataFrame,
field_mapping: Dict[str, str],
):
if field_mapping:
# run field mapping in the forward direction
table = table.rename(columns=field_mapping)
table = table.persist()

return table


def _coerce_datetime(ts):
"""
Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_offline_write_batch(
s3_staging_location="s3://bucket/path",
workgroup="",
),
entity_key_serialization_version=2,
)

batch_source = RedshiftSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def retrieval_job(request):
provider="snowflake.offline",
online_store=SqliteOnlineStoreConfig(type="sqlite"),
offline_store=offline_store_config,
entity_key_serialization_version=2,
),
full_feature_names=True,
on_demand_feature_views=[],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_apply_feature_view_with_inline_batch_source(
driver_fv = FeatureView(
name="driver_fv",
entities=[entity],
schema=[Field(name="test_key", dtype=Int64)],
source=file_source,
)

Expand Down Expand Up @@ -178,6 +179,7 @@ def test_apply_feature_view_with_inline_stream_source(
driver_fv = FeatureView(
name="driver_fv",
entities=[entity],
schema=[Field(name="test_key", dtype=Int64)],
source=stream_source,
)

Expand Down Expand Up @@ -332,6 +334,7 @@ def test_apply_conflicting_feature_view_names(feature_store_with_local_registry)
driver_stats = FeatureView(
name="driver_hourly_stats",
entities=[driver],
schema=[Field(name="driver_id", dtype=Int64)],
ttl=timedelta(seconds=10),
online=False,
source=FileSource(path="driver_stats.parquet"),
Expand All @@ -341,6 +344,7 @@ def test_apply_conflicting_feature_view_names(feature_store_with_local_registry)
customer_stats = FeatureView(
name="DRIVER_HOURLY_STATS",
entities=[customer],
schema=[Field(name="customer_id", dtype=Int64)],
ttl=timedelta(seconds=10),
online=False,
source=FileSource(path="customer_stats.parquet"),
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/tests/unit/test_on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import pandas as pd
import pytest

from feast.feature_view import FeatureView
from feast.field import Field
Expand All @@ -38,6 +39,7 @@ def udf2(features_df: pd.DataFrame) -> pd.DataFrame:
return df


@pytest.mark.filterwarnings("ignore:udf and udf_string parameters are deprecated")
def test_hash():
file_source = FileSource(name="my-file-source", path="test.parquet")
feature_view = FeatureView(
Expand Down
11 changes: 9 additions & 2 deletions sdk/python/tests/unit/test_sql_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def mysql_registry():
container.start()

# The log string uses '8.0.*' since the version might be changed as new Docker images are pushed.
log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '(\d+(\.\d+){1,2})' socket: '/var/run/mysqld/mysqld.sock' port: 3306" # noqa: W605
log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '(\\d+(\\.\\d+){1,2})' socket: '/var/run/mysqld/mysqld.sock' port: 3306" # noqa: W605
waited = wait_for_logs(
container=container,
predicate=log_string_to_wait_for,
Expand Down Expand Up @@ -218,6 +218,7 @@ def test_apply_feature_view_success(sql_registry):
fv1 = FeatureView(
name="my_feature_view_1",
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
Field(name="fs1_my_feature_2", dtype=String),
Field(name="fs1_my_feature_3", dtype=Array(String)),
Expand Down Expand Up @@ -313,6 +314,7 @@ def test_apply_on_demand_feature_view_success(sql_registry):
entities=[driver()],
ttl=timedelta(seconds=8640000000),
schema=[
Field(name="driver_id", dtype=Int64),
Field(name="daily_miles_driven", dtype=Float32),
Field(name="lat", dtype=Float32),
Field(name="lon", dtype=Float32),
Expand Down Expand Up @@ -403,7 +405,10 @@ def test_modify_feature_views_success(sql_registry):

fv1 = FeatureView(
name="my_feature_view_1",
schema=[Field(name="fs1_my_feature_1", dtype=Int64)],
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
],
entities=[entity],
tags={"team": "matchmaking"},
source=batch_source,
Expand Down Expand Up @@ -527,6 +532,7 @@ def test_apply_data_source(sql_registry):
fv1 = FeatureView(
name="my_feature_view_1",
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
Field(name="fs1_my_feature_2", dtype=String),
Field(name="fs1_my_feature_3", dtype=Array(String)),
Expand Down Expand Up @@ -596,6 +602,7 @@ def test_registry_cache(sql_registry):
fv1 = FeatureView(
name="my_feature_view_1",
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
Field(name="fs1_my_feature_2", dtype=String),
Field(name="fs1_my_feature_3", dtype=Array(String)),
Expand Down
8 changes: 4 additions & 4 deletions sdk/python/tests/utils/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest
import warnings


def no_warnings(func):
def wrapper_no_warnings(*args, **kwargs):
with pytest.warns(None) as warnings:
with warnings.catch_warnings(record=True) as record:
func(*args, **kwargs)

if len(warnings) > 0:
if len(record) > 0:
raise AssertionError(
"Warnings were raised: " + ", ".join([str(w) for w in warnings])
"Warnings were raised: " + ", ".join([str(w) for w in record])
)

return wrapper_no_warnings
Loading