Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer consume operation if not present in dataset interface #859

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/fondant/component/component.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""This module defines interfaces which components should implement to be executed by fondant."""

import logging
import typing as t
from abc import abstractmethod

import dask.dataframe as dd
import pandas as pd

logger = logging.getLogger(__name__)
mrchtr marked this conversation as resolved.
Show resolved Hide resolved


class BaseComponent:
"""Base interface for each component, specifying only the constructor.
Expand Down
30 changes: 22 additions & 8 deletions src/fondant/core/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,27 @@ def consumes(self) -> t.Mapping[str, Field]:
},
)

@property
def consumes_additional_properties(self) -> bool:
    """Returns a boolean indicating whether the component consumes additional properties."""
    consumes_section = self._specification.get("consumes", {})
    return consumes_section.get("additionalProperties", False)

@property
def consumes_is_defined(self) -> bool:
    """Returns a boolean indicating whether the component consumes is defined."""
    return bool(self._specification.get("consumes", {}))

@property
def produces_additional_properties(self) -> bool:
    """Returns a boolean indicating whether the component produces additional properties."""
    produces_section = self._specification.get("produces", {})
    return produces_section.get("additionalProperties", False)

@property
def produces(self) -> t.Mapping[str, Field]:
"""The fields produced by the component as an immutable mapping."""
Expand Down Expand Up @@ -426,13 +447,6 @@ def _inner_mapping(self, name: str) -> t.Mapping[str, Field]:
if not isinstance(value, pa.DataType):
continue

if not self._component_spec.is_generic(name):
msg = (
f"Component {self._component_spec.name} does not allow specifying additional "
f"fields but received {key}."
)
raise InvalidPipelineDefinition(msg)

if key not in spec_mapping:
mapping[key] = Field(name=key, type=Type(value))
else:
Expand Down Expand Up @@ -471,7 +485,7 @@ def _outer_mapping(self, name: str) -> t.Mapping[str, Field]:
if not isinstance(value, str):
continue

if key in spec_mapping:
if key in spec_mapping: # TODO: additionalFields true?
mapping[value] = Field(name=value, type=mapping.pop(key).type)
else:
msg = (
Expand Down
81 changes: 23 additions & 58 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import pyarrow as pa

from fondant.component import BaseComponent, Component
from fondant.core.schema import Field, Type
from fondant.core.component_spec import ComponentSpec
from fondant.core.schema import Type
from fondant.pipeline.argument_inference import infer_arguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,68 +80,19 @@ def produces(cls) -> t.Optional[t.Dict[str, t.Any]]:
pass

@classmethod
def modify_spec_consumes(
cls,
spec_consumes: t.Dict[str, t.Any],
apply_consumes: t.Optional[t.Dict[str, pa.DataType]],
):
"""Modify fields based on the consumes argument in the 'apply' method."""
if apply_consumes:
for k, v in apply_consumes.items():
if isinstance(v, str):
spec_consumes[k] = spec_consumes.pop(v)
else:
msg = (
f"Invalid data type for field `{k}` in the `apply_consumes` "
f"argument. Only string types are allowed."
)
raise ValueError(
msg,
)
return spec_consumes

@classmethod
def get_spec_consumes(
cls,
dataset_fields: t.Mapping[str, Field],
apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
):
"""
Function that get the consumes spec for the component based on the dataset fields and
the apply_consumes argument.

Args:
dataset_fields: The fields of the dataset.
apply_consumes: The consumes argument in the apply method.

Returns:
The consumes spec for the component.
"""
def _get_spec_consumes(cls) -> t.Mapping[str, t.Union[str, pa.DataType, bool]]:
"""Get the consumes spec for the component."""
consumes = cls.consumes()

if consumes is None:
# Get consumes spec from the dataset
spec_consumes = {k: v.type.to_dict() for k, v in dataset_fields.items()}

spec_consumes = cls.modify_spec_consumes(spec_consumes, apply_consumes)

logger.warning(
"No consumes defined. Consumes will be inferred from the dataset."
" All field will be consumed which may lead to additional computation,"
" Consider defining consumes in the component.\n Consumes: %s",
spec_consumes,
)

else:
spec_consumes = {
k: (Type(v).to_dict() if k != "additionalProperties" else v)
for k, v in consumes.items()
}
return {"additionalProperties": True}

return spec_consumes
return {
k: (Type(v).to_dict() if k != "additionalProperties" else v)
for k, v in consumes.items()
}

@classmethod
def get_spec_produces(cls):
def _get_spec_produces(cls) -> t.Mapping[str, t.Union[str, pa.DataType, bool]]:
"""Get the produces spec for the component."""
produces = cls.produces()

Expand All @@ -151,6 +104,18 @@ def get_spec_produces(cls):
for k, v in produces.items()
}

@classmethod
def get_component_spec(cls) -> ComponentSpec:
    """Return the component spec for the component."""
    # Arguments are inferred from the component's signature.
    inferred_args = {
        name: argument.to_spec()
        for name, argument in infer_arguments(cls).items()
    }
    return ComponentSpec(
        name=cls.__name__,
        image=cls.image().base_image,
        description=cls.__doc__ or "lightweight component",
        consumes=cls._get_spec_consumes(),
        produces=cls._get_spec_produces(),
        args=inferred_args,
    )


def lightweight_component(
*args,
Expand Down
65 changes: 41 additions & 24 deletions src/fondant/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from fondant.core.manifest import Manifest
from fondant.core.schema import Field
from fondant.pipeline import Image, LightweightComponent
from fondant.pipeline.argument_inference import infer_arguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -185,25 +184,59 @@ def __init__(
self.resources = resources or Resources()

@classmethod
def from_component_yaml(cls, path, fields=None, **kwargs) -> "ComponentOp":
    """Create a ComponentOp from a component specification yaml file.

    Args:
        path: Path to a custom component directory, or the name of a
            component in the fondant registry.
        fields: Fields of the preceding dataset, used to infer the
            ``consumes`` mapping when it is not provided explicitly.
        **kwargs: Forwarded to the ComponentOp constructor
            (may contain ``consumes``).
    """
    component_dir = (
        Path(path)
        if cls._is_custom_component(path)
        else cls._get_registry_path(str(path))
    )

    component_spec = ComponentSpec.from_file(
        component_dir / cls.COMPONENT_SPEC_NAME,
    )

    # If consumes is not defined in the pipeline, we will try to infer it
    if kwargs.get("consumes", None) is None:
        kwargs["consumes"] = cls._infer_consumes(component_spec, fields)

    image = Image(
        base_image=component_spec.image,
    )

    return cls(
        image=image,
        component_spec=component_spec,
        component_dir=component_dir,
        **kwargs,
    )

@classmethod
def _infer_consumes(cls, component_spec, dataset_fields):
    """Infer the consumes section of the component spec.

    Args:
        component_spec: The component spec to infer the consumes mapping for.
        dataset_fields: Mapping of field name to field of the preceding
            dataset, or None if no dataset fields are available.

    Returns:
        A mapping of field name to type value to be consumed, or None when
        consumes cannot be inferred.
    """
    # Idiomatic truthiness check instead of `is False` comparison.
    if not component_spec.consumes_is_defined:
        logger.info(
            "The consumes section of the component spec is not defined. "
            "Cannot infer consumes.",
        )
        return None

    # Component has consumes and additionalProperties, we will load all dataset columns.
    # (consumes_is_defined is already guaranteed True here, so it is not re-checked.)
    if component_spec.consumes_additional_properties:
        if dataset_fields is None:
            logger.info(
                "The dataset fields are not defined. Cannot infer consumes.",
            )
            return None

        return {k: v.type.value for k, v in dataset_fields.items()}

    # Component has consumes and no additionalProperties, we will load only the columns
    # defined in the component spec
    return {k: v.type.value for k, v in component_spec.consumes.items()}

@classmethod
def from_ref(
cls,
Expand All @@ -223,31 +256,14 @@ def from_ref(
"""
if inspect.isclass(ref) and issubclass(ref, BaseComponent):
if issubclass(ref, LightweightComponent):
name = ref.__name__
image = ref.image()
description = ref.__doc__ or "lightweight component"
spec_produces = ref.get_spec_produces()

spec_consumes = (
ref.get_spec_consumes(fields, kwargs["consumes"])
if fields
else {"additionalProperties": True}
)
component_spec = ref.get_component_spec()

component_spec = ComponentSpec(
name,
image.base_image,
description=description,
consumes=spec_consumes,
produces=spec_produces,
args={
name: arg.to_spec()
for name, arg in infer_arguments(ref).items()
},
)
# If consumes is not defined in the pipeline, we will try to infer it
if kwargs.get("consumes", None) is None:
mrchtr marked this conversation as resolved.
Show resolved Hide resolved
kwargs["consumes"] = cls._infer_consumes(component_spec, fields)

operation = cls(
image,
ref.image(),
component_spec,
**kwargs,
)
Expand All @@ -259,6 +275,7 @@ def from_ref(
elif isinstance(ref, (str, Path)):
operation = cls.from_component_yaml(
ref,
fields,
**kwargs,
)
else:
Expand Down
6 changes: 0 additions & 6 deletions tests/core/test_manifest_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@
"images_array": "images_data",
},
},
"6": {
# Non-generic component that has a type in the produces mapping
"produces": {
"embedding_data": pa.list_(pa.float32()),
},
},
}


Expand Down
53 changes: 49 additions & 4 deletions tests/pipeline/test_lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,16 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"image": Image.resolve_fndnt_base_image(),
"description": "lightweight component",
"consumes": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
"additionalProperties": True,
},
"produces": {"x": {"type": "int32"}},
"args": {"n": {"type": "int"}},
},
"consumes": {},
"consumes": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
},
"produces": {},
}
pipeline._validate_pipeline_definition(run_id="dummy-run-id")
Expand All @@ -163,6 +165,7 @@ def test_consumes_mapping_all_fields(tmp_path_factory, load_pipeline):
extra_requires=[
"fondant[component]@git+https://github.com/ml6team/fondant@main",
],
consumes={"a": pa.int32()},
produces={"a": pa.int32()},
)
class AddN(PandasTransformComponent):
Expand Down Expand Up @@ -486,3 +489,45 @@ def test_fndnt_base_image_resolution():
mock_call.return_value = "0.9"
base_image_name = Image.resolve_fndnt_base_image()
assert base_image_name == "fndnt/fondant:0.9-py3.9"


def test_component_op_python_component__():
    """Building a pipeline from lightweight components should not raise."""

    @lightweight_component(
        base_image="python:3.8-slim-buster",
        extra_requires=["pandas", "dask"],
        produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
    )
    class Foo(DaskLoadComponent):
        def load(self) -> dd.DataFrame:
            frame = pd.DataFrame(
                {
                    "x": [1, 2, 3],
                    "y": [4, 5, 6],
                },
                index=pd.Index(["a", "b", "c"], name="id"),
            )
            return dd.from_pandas(frame, npartitions=1)

    @lightweight_component(
        base_image="python:3.8-slim-buster",
        extra_requires=["pandas", "dask"],
        consumes={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
        produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
    )
    class Bar(PandasTransformComponent):
        def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
            return dataframe

    pipeline = Pipeline(
        name="dummy-pipeline",
        base_path="./data",
    )

    # Chaining read and apply; reaching this point without an exception
    # is the assertion of this test.
    pipeline.read(ref=Foo).apply(ref=Bar)