Refactor component package (#654)
Refactor component package as part of #643

---------

Co-authored-by: Robbe Sneyders <robbe.sneyders@gmail.com>
Co-authored-by: Philippe Moussalli <philippe.moussalli95@gmail.com>
3 people committed Nov 24, 2023
1 parent b4fe222 commit bb3b623
Showing 22 changed files with 421 additions and 282 deletions.
175 changes: 61 additions & 114 deletions src/fondant/component/data_io.py
@@ -1,16 +1,19 @@
import logging
import os
import typing as t
from collections import defaultdict

import dask.dataframe as dd
from dask.diagnostics import ProgressBar
from dask.distributed import Client

from fondant.core.component_spec import ComponentSpec, ComponentSubset
from fondant.core.component_spec import ComponentSpec
from fondant.core.manifest import Manifest

logger = logging.getLogger(__name__)

DEFAULT_INDEX_NAME = "id"


class DataIO:
def __init__(self, *, manifest: Manifest, component_spec: ComponentSpec) -> None:
@@ -82,73 +85,48 @@ def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame:

return dataframe

def _load_subset(self, subset_name: str, fields: t.List[str]) -> dd.DataFrame:
def load_dataframe(self) -> dd.DataFrame:
"""
Function that loads a subset from the manifest as a Dask dataframe.
Args:
subset_name: the name of the subset to load
fields: the fields to load from the subset
Function that loads the subsets defined in the component spec as a single Dask dataframe for
the user.
Returns:
The subset as a dask dataframe
The Dask dataframe with all columns defined in the manifest field mapping
"""
subset = self.manifest.subsets[subset_name]
remote_path = subset.location

logger.info(f"Loading subset {subset_name} with fields {fields}...")
dataframe = None
field_mapping = defaultdict(list)

subset_df = dd.read_parquet(
remote_path,
columns=fields,
calculate_divisions=True,
# Add index field to field mapping to guarantee start reading with the index dataframe
field_mapping[self.manifest.get_field_location(DEFAULT_INDEX_NAME)].append(
DEFAULT_INDEX_NAME,
)

# add subset prefix to columns
subset_df = subset_df.rename(
columns={col: subset_name + "_" + col for col in subset_df.columns},
)
for field_name in self.component_spec.consumes:
location = self.manifest.get_field_location(field_name)
field_mapping[location].append(field_name)

return subset_df

def _load_index(self) -> dd.DataFrame:
"""
Function that loads the index from the manifest as a Dask dataframe.
Returns:
The index as a dask dataframe
"""
# get index subset from the manifest
index = self.manifest.index
# get remote path
remote_path = index.location

# load index from parquet, expecting id and source columns
return dd.read_parquet(remote_path, calculate_divisions=True)

def load_dataframe(self) -> dd.DataFrame:
"""
Function that loads the subsets defined in the component spec as a single Dask dataframe for
the user.
for location, fields in field_mapping.items():
if DEFAULT_INDEX_NAME in fields:
fields.remove(DEFAULT_INDEX_NAME)

Returns:
The Dask dataframe with the field columns in the format (<subset>_<column_name>)
as well as the index columns.
"""
# load index into dataframe
dataframe = self._load_index()
for name, subset in self.component_spec.consumes.items():
fields = list(subset.fields.keys())
subset_df = self._load_subset(name, fields)
# left joins -> filter on index
dataframe = dd.merge(
dataframe,
subset_df,
left_index=True,
right_index=True,
how="left",
partial_df = dd.read_parquet(
location,
columns=fields,
index=DEFAULT_INDEX_NAME,
calculate_divisions=True,
)

if dataframe is None:
# ensure that the index is set correctly and divisions are known.
dataframe = partial_df
else:
dataframe = dataframe.merge(
partial_df,
how="left",
left_index=True,
right_index=True,
)

dataframe = self.partition_loaded_dataframe(dataframe)

logging.info(f"Columns of dataframe: {list(dataframe.columns)}")
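The new load path above replaces the per-subset _load_subset and _load_index helpers with a single loop over a location-to-fields mapping: every storage location is read once, and the partial frames are left-joined on the shared index. Below is a minimal, self-contained sketch of that pattern; the locations, field names, and in-memory frames are made up for illustration and stand in for the real dd.read_parquet(location, columns=fields, index=DEFAULT_INDEX_NAME) calls.

from collections import defaultdict

import dask.dataframe as dd
import pandas as pd

DEFAULT_INDEX_NAME = "id"

# field name -> absolute location, as a get_field_location() lookup would return it
field_locations = {
    DEFAULT_INDEX_NAME: "/base/component_1",
    "image_width": "/base/component_1",
    "caption": "/base/component_2",
}

# group the consumed fields by location so every location is read only once,
# starting with the location that holds the index
field_mapping = defaultdict(list)
field_mapping[field_locations[DEFAULT_INDEX_NAME]].append(DEFAULT_INDEX_NAME)
for field_name in ("image_width", "caption"):
    field_mapping[field_locations[field_name]].append(field_name)

# in-memory stand-ins for the parquet datasets stored at each location
datasets = {
    "/base/component_1": pd.DataFrame(
        {"id": [1, 2, 3], "image_width": [640, 512, 1024]},
    ).set_index("id"),
    "/base/component_2": pd.DataFrame(
        {"id": [1, 2, 3], "caption": ["a", "b", "c"]},
    ).set_index("id"),
}

dataframe = None
for location, fields in field_mapping.items():
    if DEFAULT_INDEX_NAME in fields:
        fields.remove(DEFAULT_INDEX_NAME)
    partial_df = dd.from_pandas(datasets[location][fields], npartitions=1)
    if dataframe is None:
        # the first frame read is the one holding the index
        dataframe = partial_df
    else:
        # later frames are filtered on the index via a left join
        dataframe = dataframe.merge(
            partial_df,
            how="left",
            left_index=True,
            right_index=True,
        )

print(dataframe.compute())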
@@ -170,79 +148,48 @@ def write_dataframe(
dataframe: dd.DataFrame,
dask_client: t.Optional[Client] = None,
) -> None:
write_tasks = []
columns_to_produce = [
column_name for column_name, field in self.component_spec.produces.items()
]

dataframe.index = dataframe.index.rename("id")
dataframe.index = dataframe.index.rename(DEFAULT_INDEX_NAME)

# Turn index into an empty dataframe so we can write it
index_df = dataframe.index.to_frame().drop(columns=["id"])
write_index_task = self._write_subset(
index_df,
subset_name="index",
subset_spec=self.component_spec.index,
)
write_tasks.append(write_index_task)
# validation that all columns are in the dataframe
self.validate_dataframe_columns(dataframe, columns_to_produce)

for subset_name, subset_spec in self.component_spec.produces.items():
subset_df = self._extract_subset_dataframe(
dataframe,
subset_name=subset_name,
subset_spec=subset_spec,
)
write_subset_task = self._write_subset(
subset_df,
subset_name=subset_name,
subset_spec=subset_spec,
)
write_tasks.append(write_subset_task)
dataframe = dataframe[columns_to_produce]
write_task = self._write_dataframe(dataframe)

with ProgressBar():
logging.info("Writing data...")
# alternative implementation possible: futures = client.compute(...)
dd.compute(*write_tasks, scheduler=dask_client)
dd.compute(write_task, scheduler=dask_client)

@staticmethod
def _extract_subset_dataframe(
dataframe: dd.DataFrame,
*,
subset_name: str,
subset_spec: ComponentSubset,
) -> dd.DataFrame:
"""Create subset dataframe to save with the original field name as the column name."""
# Create a new dataframe with only the columns needed for the output subset
subset_columns = [f"{subset_name}_{field}" for field in subset_spec.fields]
try:
subset_df = dataframe[subset_columns]
except KeyError as e:
def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):
"""Validates that all columns are available in the dataset."""
missing_fields = []
for col in columns:
if col not in dataframe.columns:
missing_fields.append(col)

if missing_fields:
msg = (
f"Field {e.args[0]} defined in output subset {subset_name} "
f"Fields {missing_fields} defined in output dataset "
f"but not found in dataframe"
)
raise ValueError(
msg,
)

# Remove the subset prefix from the column names
subset_df = subset_df.rename(
columns={col: col[(len(f"{subset_name}_")) :] for col in subset_columns},
def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
"""Create dataframe writing task."""
location = (
self.manifest.base_path + "/" + self.component_spec.component_folder_name
)

return subset_df

def _write_subset(
self,
dataframe: dd.DataFrame,
*,
subset_name: str,
subset_spec: ComponentSubset,
) -> dd.core.Scalar:
if subset_name == "index":
location = self.manifest.index.location
else:
location = self.manifest.subsets[subset_name].location

schema = {field.name: field.type.value for field in subset_spec.fields.values()}

schema = {
field.name: field.type.value
for field in self.component_spec.produces.values()
}
return self._create_write_task(dataframe, location=location, schema=schema)

@staticmethod
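The write path above collapses the previous per-subset write tasks into a single one: it validates that every column declared in produces is present, keeps only those columns, and computes one write task. A rough standalone sketch under those assumptions; the column names and the temporary output directory are illustrative.

import tempfile

import dask.dataframe as dd
import pandas as pd

columns_to_produce = ["caption", "image_width"]

dataframe = dd.from_pandas(
    pd.DataFrame(
        {
            "id": [1, 2],
            "caption": ["a", "b"],
            "image_width": [640, 512],
            "intermediate": [0, 1],
        },
    ).set_index("id"),
    npartitions=1,
)

# validate that all produced columns are available in the dataframe
missing_fields = [col for col in columns_to_produce if col not in dataframe.columns]
if missing_fields:
    msg = f"Fields {missing_fields} defined in output dataset but not found in dataframe"
    raise ValueError(msg)

# drop everything not declared in produces, then write once
dataframe = dataframe[columns_to_produce]
write_task = dataframe.to_parquet(tempfile.mkdtemp(), compute=False)
dd.compute(write_task)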
34 changes: 9 additions & 25 deletions src/fondant/component/executor.py
@@ -491,42 +491,25 @@ def optional_fondant_arguments() -> t.List[str]:
@staticmethod
def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable:
"""Factory that creates a function to wrap the component transform function. The wrapper:
- Converts the columns to hierarchical format before passing the dataframe to the
transform function
- Removes extra columns from the returned dataframe which are not defined in the component
spec `produces` section
- Sorts the columns from the returned dataframe according to the order in the component
spec `produces` section to match the order in the `meta` argument passed to Dask's
`map_partitions`.
- Flattens the returned dataframe columns.
Args:
transform: Transform method to wrap
spec: Component specification to base behavior on
"""

def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
# Switch to hierarchical columns
dataframe.columns = pd.MultiIndex.from_tuples(
tuple(column.split("_")) for column in dataframe.columns
)

# Call transform method
dataframe = transform(dataframe)

# Drop columns not in specification
columns = [
(subset_name, field)
for subset_name, subset in spec.produces.items()
for field in subset.fields
]
dataframe = dataframe[columns]

# Switch to flattened columns
dataframe.columns = [
"_".join(column) for column in dataframe.columns.to_flat_index()
]
return dataframe
columns = [name for name, field in spec.produces.items()]

return dataframe[columns]

return wrapped_transform
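The slimmed-down wrapper above no longer converts between flat and hierarchical (subset, field) columns; it only restricts and reorders the returned columns to the flat names in the spec's produces section, so they match the meta passed to Dask's map_partitions. A minimal sketch with an illustrative produces list and example transform:

import typing as t

import pandas as pd


def wrap_transform(transform: t.Callable, *, produces: t.List[str]) -> t.Callable:
    def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
        dataframe = transform(dataframe)
        # drop undeclared columns and enforce the declared order
        return dataframe[produces]

    return wrapped_transform


def add_caption_length(dataframe: pd.DataFrame) -> pd.DataFrame:
    dataframe["caption_length"] = dataframe["caption"].str.len()
    return dataframe


wrapped = wrap_transform(add_caption_length, produces=["caption", "caption_length"])
print(wrapped(pd.DataFrame({"caption": ["hello", "hi"], "extra": [1, 2]})))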

@@ -552,11 +535,8 @@ def _execute_component(

# Create meta dataframe with expected format
meta_dict = {"id": pd.Series(dtype="object")}
for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
for field_name, field in self.spec.produces.items():
meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(field.type.value))
meta_df = pd.DataFrame(meta_dict).set_index("id")

wrapped_transform = self.wrap_transform(component.transform, spec=self.spec)
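With subsets gone, the meta dataframe is assembled straight from the flat produces mapping of field name to pyarrow type. A small sketch with illustrative field names and types:

import pandas as pd
import pyarrow as pa

produces = {"caption": pa.string(), "image_width": pa.int32()}

meta_dict = {"id": pd.Series(dtype="object")}
for field_name, arrow_type in produces.items():
    meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(arrow_type))
meta_df = pd.DataFrame(meta_dict).set_index("id")

print(meta_df.dtypes)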
@@ -573,8 +553,10 @@

return dataframe

# TODO: fix in #244
def _infer_index_change(self) -> bool:
"""Infer if this component changes the index based on its component spec."""
"""
if not self.spec.accepts_additional_subsets:
return True
if not self.spec.outputs_additional_subsets:
@@ -585,6 +567,8 @@ def _infer_index_change(self) -> bool:
return any(
not subset.additional_fields for subset in self.spec.produces.values()
)
"""
return False


class DaskWriteExecutor(Executor[DaskWriteComponent]):
48 changes: 10 additions & 38 deletions src/fondant/core/manifest.py
@@ -4,7 +4,6 @@
import pkgutil
import types
import typing as t
from collections import OrderedDict
from dataclasses import asdict, dataclass
from pathlib import Path

@@ -146,7 +145,7 @@ def metadata(self) -> t.Dict[str, t.Any]:

@property
def index(self) -> Field:
return Field(name="Index", location=self._specification["index"]["location"])
return Field(name="id", location=self._specification["index"]["location"])

def update_metadata(self, key: str, value: t.Any) -> None:
self.metadata[key] = value
@@ -155,43 +154,16 @@ def update_metadata(self, key: str, value: t.Any) -> None:
def base_path(self) -> str:
return self.metadata["base_path"]

@property
def field_mapping(self) -> t.Mapping[str, t.List[str]]:
"""
Retrieve a mapping of field locations to corresponding field names.
A dictionary where keys are field locations and values are lists
of column names.
The method returns an immutable OrderedDict where the first dict element contains the
location of the dataframe with the index. This allows an efficient left join operation.
Example:
{
"/base_path/component_1": ["Name", "HP"],
"/base_path/component_2": ["Type 1", "Type 2"],
}
"""
field_mapping = {}
for field_name, field in {"id": self.index, **self.fields}.items():
location = (
f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}"
)
if location in field_mapping:
field_mapping[location].append(field_name)
else:
field_mapping[location] = [field_name]

# Sort field mapping that the first dataset contains the index
sorted_keys = sorted(
field_mapping.keys(),
key=lambda key: "id" in field_mapping[key],
reverse=True,
)
sorted_field_mapping = OrderedDict(
(key, field_mapping[key]) for key in sorted_keys
)
def get_field_location(self, field_name: str):
"""Return absolute path to the field location."""
if field_name == "id":
return f"{self.base_path}/{self.pipeline_name}/{self.run_id}{self.index.location}"
if field_name not in self.fields:
msg = f"Field {field_name} is not available in the manifest."
raise ValueError(msg)

return types.MappingProxyType(sorted_field_mapping)
field = self.fields[field_name]
return f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}"

@property
def run_id(self) -> str:
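The new get_field_location method replaces the old field_mapping property: the absolute location of a field is composed from the manifest's base path, pipeline name, run id, and the field's relative location, with the index handled as a special case. A minimal standalone sketch using illustrative values:

base_path = "gs://my-bucket"
pipeline_name = "my_pipeline"
run_id = "my_pipeline_20231124"
index_location = "/component_1"
fields = {"caption": "/component_2"}  # field name -> relative location


def get_field_location(field_name: str) -> str:
    """Return the absolute path to the field's location."""
    if field_name == "id":
        return f"{base_path}/{pipeline_name}/{run_id}{index_location}"
    if field_name not in fields:
        msg = f"Field {field_name} is not available in the manifest."
        raise ValueError(msg)
    return f"{base_path}/{pipeline_name}/{run_id}{fields[field_name]}"


print(get_field_location("caption"))
# gs://my-bucket/my_pipeline/my_pipeline_20231124/component_2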
