Skip to content

Commit

Permalink
fix: Changes template file path to relative path (feast-dev#4624)
Browse files Browse the repository at this point in the history
* fix: changes following issue 4593

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

* fix: changes following issue 4593
- Fixed file path in templates to be relative path

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

* fix: Fixes to relative path in FileSource

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

---------

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
Signed-off-by: Rob Howley <howley.robert@gmail.com>
  • Loading branch information
tmihalac authored and robhowley committed Oct 17, 2024
1 parent 203c023 commit 03e43b5
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 32 deletions.
4 changes: 2 additions & 2 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def plan(
>>> fs = FeatureStore(repo_path="project/feature_repo")
>>> driver = Entity(name="driver_id", description="driver id")
>>> driver_hourly_stats = FileSource(
... path="project/feature_repo/data/driver_stats.parquet",
... path="data/driver_stats.parquet",
... timestamp_field="event_timestamp",
... created_timestamp_column="created",
... )
Expand Down Expand Up @@ -827,7 +827,7 @@ def apply(
>>> fs = FeatureStore(repo_path="project/feature_repo")
>>> driver = Entity(name="driver_id", description="driver id")
>>> driver_hourly_stats = FileSource(
... path="project/feature_repo/data/driver_stats.parquet",
... path="data/driver_stats.parquet",
... timestamp_field="event_timestamp",
... created_timestamp_column="created",
... )
Expand Down
40 changes: 33 additions & 7 deletions sdk/python/feast/infra/offline_stores/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
self,
evaluation_function: Callable,
full_feature_names: bool,
repo_path: str,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
):
Expand All @@ -67,6 +68,7 @@ def __init__(
self._full_feature_names = full_feature_names
self._on_demand_feature_views = on_demand_feature_views or []
self._metadata = metadata
self.repo_path = repo_path

@property
def full_feature_names(self) -> bool:
Expand Down Expand Up @@ -99,8 +101,13 @@ def persist(
if not allow_overwrite and os.path.exists(storage.file_options.uri):
raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)

if not Path(storage.file_options.uri).is_absolute():
absolute_path = Path(self.repo_path) / storage.file_options.uri
else:
absolute_path = Path(storage.file_options.uri)

filesystem, path = FileSource.create_filesystem_and_path(
storage.file_options.uri,
str(absolute_path),
storage.file_options.s3_endpoint_override,
)

Expand Down Expand Up @@ -243,7 +250,9 @@ def evaluate_historical_retrieval():

all_join_keys = list(set(all_join_keys + join_keys))

df_to_join = _read_datasource(feature_view.batch_source)
df_to_join = _read_datasource(
feature_view.batch_source, config.repo_path
)

df_to_join, timestamp_field = _field_mapping(
df_to_join,
Expand Down Expand Up @@ -297,6 +306,7 @@ def evaluate_historical_retrieval():
min_event_timestamp=entity_df_event_timestamp_range[0],
max_event_timestamp=entity_df_event_timestamp_range[1],
),
repo_path=str(config.repo_path),
)
return job

Expand All @@ -316,7 +326,7 @@ def pull_latest_from_table_or_query(

# Create lazy function that is only called from the RetrievalJob object
def evaluate_offline_job():
source_df = _read_datasource(data_source)
source_df = _read_datasource(data_source, config.repo_path)

source_df = _normalize_timestamp(
source_df, timestamp_field, created_timestamp_column
Expand Down Expand Up @@ -377,6 +387,7 @@ def evaluate_offline_job():
return DaskRetrievalJob(
evaluation_function=evaluate_offline_job,
full_feature_names=False,
repo_path=str(config.repo_path),
)

@staticmethod
Expand Down Expand Up @@ -420,8 +431,13 @@ def write_logged_features(
# Since this code will be mostly used from Go-created thread, it's better to avoid producing new threads
data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)

if config.repo_path is not None and not Path(destination.path).is_absolute():
absolute_path = config.repo_path / destination.path
else:
absolute_path = Path(destination.path)

filesystem, path = FileSource.create_filesystem_and_path(
destination.path,
str(absolute_path),
destination.s3_endpoint_override,
)

Expand Down Expand Up @@ -456,8 +472,14 @@ def offline_write_batch(
)

file_options = feature_view.batch_source.file_options

if config.repo_path is not None and not Path(file_options.uri).is_absolute():
absolute_path = config.repo_path / file_options.uri
else:
absolute_path = Path(file_options.uri)

filesystem, path = FileSource.create_filesystem_and_path(
file_options.uri, file_options.s3_endpoint_override
str(absolute_path), file_options.s3_endpoint_override
)
prev_table = pyarrow.parquet.read_table(
path, filesystem=filesystem, memory_map=True
Expand Down Expand Up @@ -493,7 +515,7 @@ def _get_entity_df_event_timestamp_range(
)


def _read_datasource(data_source) -> dd.DataFrame:
def _read_datasource(data_source, repo_path) -> dd.DataFrame:
storage_options = (
{
"client_kwargs": {
Expand All @@ -504,8 +526,12 @@ def _read_datasource(data_source) -> dd.DataFrame:
else None
)

if not Path(data_source.path).is_absolute():
path = repo_path / data_source.path
else:
path = data_source.path
return dd.read_parquet(
data_source.path,
path,
storage_options=storage_options,
)

Expand Down
17 changes: 14 additions & 3 deletions sdk/python/feast/infra/offline_stores/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig


def _read_data_source(data_source: DataSource) -> Table:
def _read_data_source(data_source: DataSource, repo_path: str) -> Table:
assert isinstance(data_source, FileSource)

if isinstance(data_source.file_format, ParquetFormat):
Expand All @@ -43,21 +43,32 @@ def _read_data_source(data_source: DataSource) -> Table:
def _write_data_source(
table: Table,
data_source: DataSource,
repo_path: str,
mode: str = "append",
allow_overwrite: bool = False,
):
assert isinstance(data_source, FileSource)

file_options = data_source.file_options

if mode == "overwrite" and not allow_overwrite and os.path.exists(file_options.uri):
if not Path(file_options.uri).is_absolute():
absolute_path = Path(repo_path) / file_options.uri
else:
absolute_path = Path(file_options.uri)

if (
mode == "overwrite"
and not allow_overwrite
and os.path.exists(str(absolute_path))
):
raise SavedDatasetLocationAlreadyExists(location=file_options.uri)

if isinstance(data_source.file_format, ParquetFormat):
if mode == "overwrite":
table = table.to_pyarrow()

filesystem, path = FileSource.create_filesystem_and_path(
file_options.uri,
str(absolute_path),
file_options.s3_endpoint_override,
)

Expand Down
11 changes: 10 additions & 1 deletion sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import pyarrow
Expand Down Expand Up @@ -154,8 +155,16 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
if (
config.repo_path is not None
and not Path(self.file_options.uri).is_absolute()
):
absolute_path = config.repo_path / self.file_options.uri
else:
absolute_path = Path(self.file_options.uri)

filesystem, path = FileSource.create_filesystem_and_path(
self.path, self.file_options.s3_endpoint_override
str(absolute_path), self.file_options.s3_endpoint_override
)

# TODO why None check necessary
Expand Down
37 changes: 25 additions & 12 deletions sdk/python/feast/infra/offline_stores/ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def pull_latest_from_table_or_query_ibis(
created_timestamp_column: Optional[str],
start_date: datetime,
end_date: datetime,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
data_source_reader: Callable[[DataSource, str], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
Expand All @@ -57,7 +57,7 @@ def pull_latest_from_table_or_query_ibis(
start_date = start_date.astimezone(tz=timezone.utc)
end_date = end_date.astimezone(tz=timezone.utc)

table = data_source_reader(data_source)
table = data_source_reader(data_source, str(config.repo_path))

table = table.select(*fields)

Expand Down Expand Up @@ -87,6 +87,7 @@ def pull_latest_from_table_or_query_ibis(
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
repo_path=str(config.repo_path),
)


Expand Down Expand Up @@ -147,8 +148,8 @@ def get_historical_features_ibis(
entity_df: Union[pd.DataFrame, str],
registry: BaseRegistry,
project: str,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
data_source_reader: Callable[[DataSource, str], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
full_feature_names: bool = False,
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
Expand All @@ -174,7 +175,9 @@ def get_historical_features_ibis(
def read_fv(
feature_view: FeatureView, feature_refs: List[str], full_feature_names: bool
) -> Tuple:
fv_table: Table = data_source_reader(feature_view.batch_source)
fv_table: Table = data_source_reader(
feature_view.batch_source, str(config.repo_path)
)

for old_name, new_name in feature_view.batch_source.field_mapping.items():
if old_name in fv_table.columns:
Expand Down Expand Up @@ -247,6 +250,7 @@ def read_fv(
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
repo_path=str(config.repo_path),
)


Expand All @@ -258,16 +262,16 @@ def pull_all_from_table_or_query_ibis(
timestamp_field: str,
start_date: datetime,
end_date: datetime,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
data_source_reader: Callable[[DataSource, str], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
fields = join_key_columns + feature_name_columns + [timestamp_field]
start_date = start_date.astimezone(tz=timezone.utc)
end_date = end_date.astimezone(tz=timezone.utc)

table = data_source_reader(data_source)
table = data_source_reader(data_source, str(config.repo_path))

table = table.select(*fields)

Expand All @@ -290,6 +294,7 @@ def pull_all_from_table_or_query_ibis(
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
repo_path=str(config.repo_path),
)


Expand Down Expand Up @@ -319,7 +324,7 @@ def offline_write_batch_ibis(
feature_view: FeatureView,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
):
pa_schema, column_names = get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
Expand All @@ -330,7 +335,9 @@ def offline_write_batch_ibis(
f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
)

data_source_writer(ibis.memtable(table), feature_view.batch_source)
data_source_writer(
ibis.memtable(table), feature_view.batch_source, str(config.repo_path)
)


def deduplicate(
Expand Down Expand Up @@ -469,6 +476,7 @@ def __init__(
data_source_writer,
staging_location,
staging_location_endpoint_override,
repo_path,
) -> None:
super().__init__()
self.table = table
Expand All @@ -480,6 +488,7 @@ def __init__(
self.data_source_writer = data_source_writer
self.staging_location = staging_location
self.staging_location_endpoint_override = staging_location_endpoint_override
self.repo_path = repo_path

def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
return self.table.execute()
Expand All @@ -502,7 +511,11 @@ def persist(
timeout: Optional[int] = None,
):
self.data_source_writer(
self.table, storage.to_data_source(), "overwrite", allow_overwrite
self.table,
storage.to_data_source(),
self.repo_path,
"overwrite",
allow_overwrite,
)

@property
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class RepoConfig(FeastBaseModel):
""" Flags (deprecated field): Feature flags for experimental features """

repo_path: Optional[Path] = None
"""When using relative path in FileSource path, this parameter is mandatory"""

entity_key_serialization_version: StrictInt = 1
""" Entity key serialization version: This version is used to control what serialization scheme is
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/templates/cassandra/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def bootstrap():

# example_repo.py
example_py_file = repo_path / "example_repo.py"
replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
replace_str_in_file(
example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
)

# store config yaml, interact with user and then customize file:
settings = collect_cassandra_store_settings()
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/templates/hazelcast/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def bootstrap():

# example_repo.py
example_py_file = repo_path / "example_repo.py"
replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
replace_str_in_file(
example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
)

# store config yaml, interact with user and then customize file:
settings = collect_hazelcast_online_store_settings()
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/templates/hbase/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def bootstrap():
driver_df.to_parquet(path=str(driver_stats_path), allow_truncated_timestamps=True)

example_py_file = repo_path / "example_repo.py"
replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
replace_str_in_file(
example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
)


if __name__ == "__main__":
Expand Down
8 changes: 6 additions & 2 deletions sdk/python/feast/templates/local/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def bootstrap():

example_py_file = repo_path / "example_repo.py"
replace_str_in_file(example_py_file, "%PROJECT_NAME%", str(project_name))
replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path))
replace_str_in_file(
example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
)
replace_str_in_file(
example_py_file, "%LOGGING_PATH%", str(data_path.relative_to(repo_path))
)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 03e43b5

Please sign in to comment.