Refactor component package #654

Merged
31 commits
35338b6
Update component spec schema validation
mrchtr Nov 16, 2023
a269e3c
Update component spec tests to validate new component spec
mrchtr Nov 16, 2023
ad0dab6
Add additional fields to json schema
mrchtr Nov 16, 2023
7b91535
Update manifest json schema for validation
mrchtr Nov 16, 2023
5d1bf5e
Update manifest creation
mrchtr Nov 17, 2023
d8ecd01
Reduce PR to core module
mrchtr Nov 21, 2023
12c78ca
Addresses comments
mrchtr Nov 21, 2023
c1cad60
Restructure test directory
mrchtr Nov 21, 2023
fd0699c
Remove additional fields in common.json
mrchtr Nov 21, 2023
0f8117f
Test structure
mrchtr Nov 21, 2023
7e8a1d6
Refactor component package
mrchtr Nov 21, 2023
9f67c61
Update src/fondant/core/component_spec.py
mrchtr Nov 21, 2023
40955bf
Update src/fondant/core/manifest.py
mrchtr Nov 21, 2023
6b246a4
Update src/fondant/core/component_spec.py
mrchtr Nov 21, 2023
8ef38d9
Update src/fondant/core/manifest.py
mrchtr Nov 21, 2023
e8c8135
Update src/fondant/core/schema.py
mrchtr Nov 21, 2023
df9a60e
Addresses comments
mrchtr Nov 21, 2023
2256118
Addresses comments
mrchtr Nov 21, 2023
3042fb5
Addresses comments
mrchtr Nov 21, 2023
8fa8be7
Update src/fondant/core/manifest.py
mrchtr Nov 21, 2023
25eb492
Addresses comments
mrchtr Nov 22, 2023
c0fb47a
Merge branch 'feature/implement-new-dataset-format' into feautre/refa…
mrchtr Nov 22, 2023
0701662
Addresses comments
mrchtr Nov 22, 2023
365ca6d
Update test examples
mrchtr Nov 22, 2023
4dc7dc7
Update src/fondant/core/manifest.py
mrchtr Nov 22, 2023
a60ca3e
addresses comments
mrchtr Nov 22, 2023
d2182a0
Merge feature/implement-new-dataset-format into feature/refactore-com…
mrchtr Nov 22, 2023
43a7b68
Addressing comments regarding data_io
mrchtr Nov 23, 2023
83a5de6
Merge feature/redesign-dataset-format-and-interface into feature/refa…
mrchtr Nov 23, 2023
5ac5e42
Update tests
mrchtr Nov 23, 2023
6616bf2
Remove set_index on during merging
mrchtr Nov 23, 2023
175 changes: 61 additions & 114 deletions src/fondant/component/data_io.py
@@ -1,16 +1,19 @@
import logging
import os
import typing as t
from collections import defaultdict

import dask.dataframe as dd
from dask.diagnostics import ProgressBar
from dask.distributed import Client

from fondant.core.component_spec import ComponentSpec, ComponentSubset
from fondant.core.component_spec import ComponentSpec
from fondant.core.manifest import Manifest

logger = logging.getLogger(__name__)

DEFAULT_INDEX_NAME = "id"


class DataIO:
def __init__(self, *, manifest: Manifest, component_spec: ComponentSpec) -> None:
@@ -82,73 +85,48 @@ def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame:

return dataframe

def _load_subset(self, subset_name: str, fields: t.List[str]) -> dd.DataFrame:
def load_dataframe(self) -> dd.DataFrame:
"""
Function that loads a subset from the manifest as a Dask dataframe.

Args:
subset_name: the name of the subset to load
fields: the fields to load from the subset
Function that loads the subsets defined in the component spec as a single Dask dataframe for
the user.

Returns:
The subset as a dask dataframe
The Dask dataframe with all columns defined in the manifest field mapping
"""
subset = self.manifest.subsets[subset_name]
remote_path = subset.location

logger.info(f"Loading subset {subset_name} with fields {fields}...")
dataframe = None
field_mapping = defaultdict(list)

subset_df = dd.read_parquet(
remote_path,
columns=fields,
calculate_divisions=True,
# Add index field to field mapping to guarantee start reading with the index dataframe
field_mapping[self.manifest.get_field_location(DEFAULT_INDEX_NAME)].append(
DEFAULT_INDEX_NAME,
)

# add subset prefix to columns
subset_df = subset_df.rename(
columns={col: subset_name + "_" + col for col in subset_df.columns},
)
for field_name in self.component_spec.consumes:
location = self.manifest.get_field_location(field_name)
field_mapping[location].append(field_name)

return subset_df

def _load_index(self) -> dd.DataFrame:
"""
Function that loads the index from the manifest as a Dask dataframe.

Returns:
The index as a dask dataframe
"""
# get index subset from the manifest
index = self.manifest.index
# get remote path
remote_path = index.location

# load index from parquet, expecting id and source columns
return dd.read_parquet(remote_path, calculate_divisions=True)

def load_dataframe(self) -> dd.DataFrame:
"""
Function that loads the subsets defined in the component spec as a single Dask dataframe for
the user.
for location, fields in field_mapping.items():
if DEFAULT_INDEX_NAME in fields:
fields.remove(DEFAULT_INDEX_NAME)

Returns:
The Dask dataframe with the field columns in the format (<subset>_<column_name>)
as well as the index columns.
"""
# load index into dataframe
dataframe = self._load_index()
for name, subset in self.component_spec.consumes.items():
fields = list(subset.fields.keys())
subset_df = self._load_subset(name, fields)
# left joins -> filter on index
dataframe = dd.merge(
dataframe,
subset_df,
left_index=True,
right_index=True,
how="left",
partial_df = dd.read_parquet(
location,
columns=fields,
index=DEFAULT_INDEX_NAME,
calculate_divisions=True,
)

if dataframe is None:
# ensure that the index is set correctly and divisions are known.
dataframe = partial_df
else:
dataframe = dataframe.merge(
partial_df,
how="left",
left_index=True,
right_index=True,
)

dataframe = self.partition_loaded_dataframe(dataframe)

logging.info(f"Columns of dataframe: {list(dataframe.columns)}")
@@ -170,79 +148,48 @@ def write_dataframe(
dataframe: dd.DataFrame,
dask_client: t.Optional[Client] = None,
) -> None:
write_tasks = []
columns_to_produce = [
column_name for column_name, field in self.component_spec.produces.items()
]

dataframe.index = dataframe.index.rename("id")
dataframe.index = dataframe.index.rename(DEFAULT_INDEX_NAME)

# Turn index into an empty dataframe so we can write it
index_df = dataframe.index.to_frame().drop(columns=["id"])
write_index_task = self._write_subset(
index_df,
subset_name="index",
subset_spec=self.component_spec.index,
)
write_tasks.append(write_index_task)
# validation that all columns are in the dataframe
self.validate_dataframe_columns(dataframe, columns_to_produce)

for subset_name, subset_spec in self.component_spec.produces.items():
subset_df = self._extract_subset_dataframe(
dataframe,
subset_name=subset_name,
subset_spec=subset_spec,
)
write_subset_task = self._write_subset(
subset_df,
subset_name=subset_name,
subset_spec=subset_spec,
)
write_tasks.append(write_subset_task)
dataframe = dataframe[columns_to_produce]
write_task = self._write_dataframe(dataframe)

with ProgressBar():
logging.info("Writing data...")
# alternative implementation possible: futures = client.compute(...)
dd.compute(*write_tasks, scheduler=dask_client)
dd.compute(write_task, scheduler=dask_client)

@staticmethod
def _extract_subset_dataframe(
dataframe: dd.DataFrame,
*,
subset_name: str,
subset_spec: ComponentSubset,
) -> dd.DataFrame:
"""Create subset dataframe to save with the original field name as the column name."""
# Create a new dataframe with only the columns needed for the output subset
subset_columns = [f"{subset_name}_{field}" for field in subset_spec.fields]
try:
subset_df = dataframe[subset_columns]
except KeyError as e:
def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):
"""Validates that all columns are available in the dataset."""
missing_fields = []
for col in columns:
if col not in dataframe.columns:
missing_fields.append(col)

if missing_fields:
msg = (
f"Field {e.args[0]} defined in output subset {subset_name} "
f"Fields {missing_fields} defined in output dataset "
f"but not found in dataframe"
)
raise ValueError(
msg,
)

# Remove the subset prefix from the column names
subset_df = subset_df.rename(
columns={col: col[(len(f"{subset_name}_")) :] for col in subset_columns},
def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
"""Create dataframe writing task."""
location = (
self.manifest.base_path + "/" + self.component_spec.component_folder_name
)

return subset_df

def _write_subset(
self,
dataframe: dd.DataFrame,
*,
subset_name: str,
subset_spec: ComponentSubset,
) -> dd.core.Scalar:
if subset_name == "index":
location = self.manifest.index.location
else:
location = self.manifest.subsets[subset_name].location

schema = {field.name: field.type.value for field in subset_spec.fields.values()}

schema = {
field.name: field.type.value
for field in self.component_spec.produces.values()
}
return self._create_write_task(dataframe, location=location, schema=schema)

@staticmethod
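On the write side, the per-subset machinery (`_extract_subset_dataframe`, `_write_subset`, prefix stripping) collapses into a validate-select-write sequence over the flat `produces` mapping. Below is a hedged sketch of that flow; the `produces` dict of column name to pyarrow type is hypothetical (the real code reads the types from the component spec as `field.type.value`), and the location would come from the manifest's base path plus the component folder name.

```python
# Hedged sketch of the validate -> select -> write flow with a hypothetical
# `produces` mapping of column name -> pyarrow type.
import typing as t

import dask.dataframe as dd
import pyarrow as pa


def write_dataset(dataframe: dd.DataFrame, produces: t.Dict[str, pa.DataType], location: str):
    # Fail early if the transform did not return every declared output column.
    missing = [col for col in produces if col not in dataframe.columns]
    if missing:
        msg = f"Fields {missing} defined in output dataset but not found in dataframe"
        raise ValueError(msg)

    dataframe = dataframe[list(produces)]  # drop any undeclared columns
    # Return the uncomputed write task; the caller hands it to dd.compute(...).
    return dataframe.to_parquet(location, schema=produces, compute=False)
```

Returning the uncomputed task mirrors `_write_dataframe`, which hands the single task to `dd.compute(write_task, scheduler=dask_client)` under a `ProgressBar` instead of collecting one task per subset.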
34 changes: 9 additions & 25 deletions src/fondant/component/executor.py
@@ -491,42 +491,25 @@ def optional_fondant_arguments() -> t.List[str]:
@staticmethod
def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable:
"""Factory that creates a function to wrap the component transform function. The wrapper:
- Converts the columns to hierarchical format before passing the dataframe to the
transform function
- Removes extra columns from the returned dataframe which are not defined in the component
spec `produces` section
- Sorts the columns from the returned dataframe according to the order in the component
spec `produces` section to match the order in the `meta` argument passed to Dask's
`map_partitions`.
- Flattens the returned dataframe columns.

Args:
transform: Transform method to wrap
spec: Component specification to base behavior on
"""

def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
# Switch to hierarchical columns
dataframe.columns = pd.MultiIndex.from_tuples(
tuple(column.split("_")) for column in dataframe.columns
)

# Call transform method
dataframe = transform(dataframe)

# Drop columns not in specification
columns = [
(subset_name, field)
for subset_name, subset in spec.produces.items()
for field in subset.fields
]
dataframe = dataframe[columns]

# Switch to flattened columns
dataframe.columns = [
"_".join(column) for column in dataframe.columns.to_flat_index()
]
return dataframe
columns = [name for name, field in spec.produces.items()]

return dataframe[columns]

return wrapped_transform
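With subsets gone, `wrap_transform` no longer round-trips through hierarchical (`MultiIndex`) columns; the wrapper only selects and orders the produced columns so each returned partition matches the `meta` frame passed to Dask's `map_partitions`. A small self-contained illustration, using hypothetical column names and a toy transform rather than a real component:

```python
# Hedged illustration of the flattened wrapper; column names and the toy
# transform are hypothetical, not taken from a component spec.
import pandas as pd


def wrap_transform(transform, *, produced_columns):
    def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
        dataframe = transform(dataframe)
        # Select and order the declared output columns so the result matches
        # the `meta` frame handed to Dask's map_partitions.
        return dataframe[produced_columns]

    return wrapped_transform


wrapped = wrap_transform(
    lambda df: df.assign(width=df["width"] * 2),
    produced_columns=["caption", "width"],
)
df = pd.DataFrame(
    {"caption": ["a"], "width": [1], "extra": [0]},
    index=pd.Index(["0"], name="id"),
)
print(list(wrapped(df).columns))  # ['caption', 'width'] -> undeclared 'extra' is dropped
```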

@@ -552,11 +535,8 @@ def _execute_component(

# Create meta dataframe with expected format
meta_dict = {"id": pd.Series(dtype="object")}
for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
for field_name, field in self.spec.produces.items():
meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(field.type.value))
meta_df = pd.DataFrame(meta_dict).set_index("id")

wrapped_transform = self.wrap_transform(component.transform, spec=self.spec)
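The `meta` frame is flattened in the same way: one column per key in `produces`, typed with `pd.ArrowDtype` built from the field's pyarrow type, indexed by `id`. A sketch with two hypothetical field types (the real code pulls the types from the component spec; `pd.ArrowDtype` assumes pandas 2.x and pyarrow):

```python
# Sketch of building the flattened `meta` frame from hypothetical field types.
import pandas as pd
import pyarrow as pa

produces = {"caption": pa.string(), "width": pa.int32()}

meta_dict = {"id": pd.Series(dtype="object")}
for field_name, arrow_type in produces.items():
    meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(arrow_type))
meta_df = pd.DataFrame(meta_dict).set_index("id")

# Dask later infers the transform's output schema from this empty frame:
# dataframe.map_partitions(wrapped_transform, meta=meta_df)
```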
@@ -573,8 +553,10 @@

return dataframe

# TODO: fix in #244
def _infer_index_change(self) -> bool:
"""Infer if this component changes the index based on its component spec."""
"""
if not self.spec.accepts_additional_subsets:
return True
if not self.spec.outputs_additional_subsets:
@@ -585,6 +567,8 @@ def _infer_index_change(self) -> bool:
return any(
not subset.additional_fields for subset in self.spec.produces.values()
)
"""
return False


class DaskWriteExecutor(Executor[DaskWriteComponent]):
48 changes: 10 additions & 38 deletions src/fondant/core/manifest.py
@@ -4,7 +4,6 @@
import pkgutil
import types
import typing as t
from collections import OrderedDict
from dataclasses import asdict, dataclass
from pathlib import Path

Expand Down Expand Up @@ -146,7 +145,7 @@ def metadata(self) -> t.Dict[str, t.Any]:

@property
def index(self) -> Field:
return Field(name="Index", location=self._specification["index"]["location"])
return Field(name="id", location=self._specification["index"]["location"])

def update_metadata(self, key: str, value: t.Any) -> None:
self.metadata[key] = value
@@ -155,43 +154,16 @@ def update_metadata(self, key: str, value: t.Any) -> None:
def base_path(self) -> str:
return self.metadata["base_path"]

@property
def field_mapping(self) -> t.Mapping[str, t.List[str]]:
"""
Retrieve a mapping of field locations to corresponding field names.
A dictionary where keys are field locations and values are lists
of column names.

The method returns an immutable OrderedDict where the first dict element contains the
location of the dataframe with the index. This allows an efficient left join operation.

Example:
{
"/base_path/component_1": ["Name", "HP"],
"/base_path/component_2": ["Type 1", "Type 2"],
}
"""
field_mapping = {}
for field_name, field in {"id": self.index, **self.fields}.items():
location = (
f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}"
)
if location in field_mapping:
field_mapping[location].append(field_name)
else:
field_mapping[location] = [field_name]

# Sort field mapping that the first dataset contains the index
sorted_keys = sorted(
field_mapping.keys(),
key=lambda key: "id" in field_mapping[key],
reverse=True,
)
sorted_field_mapping = OrderedDict(
(key, field_mapping[key]) for key in sorted_keys
)
def get_field_location(self, field_name: str):
"""Return absolute path to the field location."""
if field_name == "id":
return f"{self.base_path}/{self.pipeline_name}/{self.run_id}{self.index.location}"
if field_name not in self.fields:
msg = f"Field {field_name} is not available in the manifest."
raise ValueError(msg)

return types.MappingProxyType(sorted_field_mapping)
field = self.fields[field_name]
return f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}"

@property
def run_id(self) -> str:
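In `manifest.py`, the ordered `field_mapping` property is replaced by `get_field_location`, which resolves a single field name to an absolute path of the form `{base_path}/{pipeline_name}/{run_id}{field.location}` and raises a `ValueError` for unknown fields; putting the index-bearing location first is now handled in `DataIO.load_dataframe` instead. A minimal sketch of the resolution logic, using hypothetical manifest values:

```python
# Minimal sketch of the path format returned by the new get_field_location,
# with hypothetical manifest values (base path, pipeline name, run id, and a
# single "images" field stored under "/component_1").
base_path, pipeline_name, run_id = "s3://bucket", "pipeline", "run-1"
field_locations = {"id": "/component_1", "images": "/component_1"}


def get_field_location(field_name: str) -> str:
    if field_name not in field_locations:
        msg = f"Field {field_name} is not available in the manifest."
        raise ValueError(msg)
    return f"{base_path}/{pipeline_name}/{run_id}{field_locations[field_name]}"


print(get_field_location("images"))  # s3://bucket/pipeline/run-1/component_1
```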