fix: Dask zero division error if parquet dataset has only one partition #3236

Merged
31 changes: 25 additions & 6 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -662,14 +662,33 @@ def _drop_duplicates(
     created_timestamp_column: str,
     entity_df_event_timestamp_col: str,
 ) -> dd.DataFrame:
-    if created_timestamp_column:
-        df_to_join = df_to_join.sort_values(
-            by=created_timestamp_column, na_position="first"
-        )
+    column_order = df_to_join.columns
+
+    # try-catch block is added to deal with this issue https://github.com/dask/dask/issues/8939.
+    # TODO(kevjumba): remove try catch when fix is merged upstream in Dask.
+    try:
+        if created_timestamp_column:
+            df_to_join = df_to_join.sort_values(
+                by=created_timestamp_column, na_position="first"
+            )
+            df_to_join = df_to_join.persist()
 
-    df_to_join = df_to_join.sort_values(by=timestamp_field, na_position="first")
-    df_to_join = df_to_join.persist()
+        df_to_join = df_to_join.sort_values(by=timestamp_field, na_position="first")
+        df_to_join = df_to_join.persist()
+    except ZeroDivisionError:
+        # Use 1 partition to get around case where everything in timestamp column is the same so the partition algorithm doesn't
+        # try to divide by zero.
+        if created_timestamp_column:
+            df_to_join = df_to_join[column_order].sort_values(
+                by=created_timestamp_column, na_position="first", npartitions=1
+            )
+            df_to_join = df_to_join.persist()
+
+        df_to_join = df_to_join[column_order].sort_values(
+            by=timestamp_field, na_position="first", npartitions=1
+        )
+        df_to_join = df_to_join.persist()
 
     df_to_join = df_to_join.drop_duplicates(
         all_join_keys + [entity_df_event_timestamp_col],
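For context, the pattern above can be exercised outside Feast. The sketch below is illustrative only (column names and data are made up); whether the first `sort_values` actually raises depends on the Dask version in use, which is why the change catches `ZeroDivisionError` rather than checking partitions up front.

```python
import dask.dataframe as dd
import pandas as pd

# A tiny frame whose timestamp column has a single distinct value, loaded as
# one partition -- the shape of dataset that hits dask/dask#8939.
pdf = pd.DataFrame({"ts": pd.to_datetime(["2022-01-01"] * 4), "value": range(4)})
ddf = dd.from_pandas(pdf, npartitions=1)

try:
    # Same call shape as in _drop_duplicates above.
    ddf = ddf.sort_values(by="ts", na_position="first")
except ZeroDivisionError:
    # Fall back to a single output partition so Dask never has to derive
    # partition divisions from identical quantiles.
    ddf = ddf.sort_values(by="ts", na_position="first", npartitions=1)

print(ddf.compute())
```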
35 changes: 35 additions & 0 deletions sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
@@ -5,6 +5,8 @@
 from typing import Any, Dict, List, Optional
 
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 from minio import Minio
 from testcontainers.core.generic import DockerContainer
 from testcontainers.core.waiting_utils import wait_for_logs
@@ -87,6 +89,39 @@ def teardown(self):
             shutil.rmtree(d)
 
 
+class FileParquetDatasetSourceCreator(FileDataSourceCreator):
+    def create_data_source(
+        self,
+        df: pd.DataFrame,
+        destination_name: str,
+        timestamp_field="ts",
+        created_timestamp_column="created_ts",
+        field_mapping: Dict[str, str] = None,
+    ) -> DataSource:
+
+        destination_name = self.get_prefixed_table_name(destination_name)
+
+        dataset_path = tempfile.TemporaryDirectory(
+            prefix=f"{self.project_name}_{destination_name}"
+        )
+        table = pa.Table.from_pandas(df)
+        pq.write_to_dataset(
+            table,
+            base_dir=dataset_path.name,
+            compression="snappy",
+            format="parquet",
+            existing_data_behavior="overwrite_or_ignore",
+        )
+        self.files.append(dataset_path.name)
+        return FileSource(
+            file_format=ParquetFormat(),
+            path=dataset_path.name,
+            timestamp_field=timestamp_field,
+            created_timestamp_column=created_timestamp_column,
+            field_mapping=field_mapping or {"ts_1": "ts"},
+        )
+
+
 class S3FileDataSourceCreator(DataSourceCreator):
     f: Any
     minio: DockerContainer
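A rough usage sketch of what the new creator produces: a directory-based Parquet dataset written with pyarrow rather than a single file. For a small, unpartitioned table, Dask will typically read such a directory back as one partition, which is exactly the case the offline-store fix guards against. Paths and column names below are illustrative, not taken from the test suite.

```python
import tempfile

import dask.dataframe as dd
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

df = pd.DataFrame({"ts": pd.to_datetime(["2022-01-01"] * 3), "value": [1, 2, 3]})
dataset_dir = tempfile.mkdtemp(prefix="feast_parquet_dataset_")

# Directory-based dataset, analogous to what create_data_source() builds above.
pq.write_to_dataset(pa.Table.from_pandas(df), root_path=dataset_dir)

ddf = dd.read_parquet(dataset_dir)
print(ddf.npartitions)  # typically 1 for a small, unpartitioned dataset
```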
6 changes: 6 additions & 0 deletions sdk/python/tests/utils/e2e_test_validation.py
@@ -27,6 +27,7 @@
 )
 from tests.integration.feature_repos.universal.data_sources.file import (
     FileDataSourceCreator,
+    FileParquetDatasetSourceCreator,
 )
 from tests.integration.feature_repos.universal.data_sources.redshift import (
     RedshiftDataSourceCreator,
@@ -211,6 +212,11 @@ def make_feature_store_yaml(
         offline_store_creator=FileDataSourceCreator,
         online_store=None,
     ),
+    IntegrationTestRepoConfig(
+        provider="local",
+        offline_store_creator=FileParquetDatasetSourceCreator,
+        online_store=None,
+    ),
 ]
 
 # Only test if this is NOT a local test