Skip to content

Commit

Permalink
Infer consume operation if not present in dataset interface (#859)
Browse files Browse the repository at this point in the history
Basic implementation, I still have to add tests. I wanted to get some
feedback first.

- Added a method to the lightweight components to generate a
`ComponentSpec` based on the attributes.
- Added a method in the pipeline to infer the consumption based on the
`ComponentSpec`.
In cases where a user hasn't specified a `consume` in the pipeline
operations, we now infer this. If a component spec contains a `consumes`
section and `additionalProperties` are set to true, we load all columns.
If `additionalProperties` is set to false, we limit the columns defined
in the component spec.

Fix #836
  • Loading branch information
mrchtr authored Feb 20, 2024
1 parent 71020ef commit 0de3f5c
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 99 deletions.
28 changes: 21 additions & 7 deletions src/fondant/core/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,27 @@ def consumes(self) -> t.Mapping[str, Field]:
},
)

@property
def consumes_additional_properties(self) -> bool:
"""Returns a boolean indicating whether the component consumes additional properties."""
return self._specification.get("consumes", {}).get(
"additionalProperties",
False,
)

@property
def consumes_is_defined(self) -> bool:
"""Returns a boolean indicating whether the component consumes is defined."""
return bool(self._specification.get("consumes", {}))

@property
def produces_additional_properties(self) -> bool:
"""Returns a boolean indicating whether the component produces additional properties."""
return self._specification.get("produces", {}).get(
"additionalProperties",
False,
)

@property
def produces(self) -> t.Mapping[str, Field]:
"""The fields produced by the component as an immutable mapping."""
Expand Down Expand Up @@ -414,13 +435,6 @@ def _inner_mapping(self, name: str) -> t.Mapping[str, Field]:
if not isinstance(value, pa.DataType):
continue

if not self._component_spec.is_generic(name):
msg = (
f"Component {self._component_spec.name} does not allow specifying additional "
f"fields but received {key}."
)
raise InvalidPipelineDefinition(msg)

if key not in spec_mapping:
mapping[key] = Field(name=key, type=Type(value))
else:
Expand Down
81 changes: 23 additions & 58 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import pyarrow as pa

from fondant.component import BaseComponent, Component
from fondant.core.schema import Field, Type
from fondant.core.component_spec import ComponentSpec
from fondant.core.schema import Type
from fondant.pipeline.argument_inference import infer_arguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,68 +80,19 @@ def produces(cls) -> t.Optional[t.Dict[str, t.Any]]:
pass

@classmethod
def modify_spec_consumes(
cls,
spec_consumes: t.Dict[str, t.Any],
apply_consumes: t.Optional[t.Dict[str, pa.DataType]],
):
"""Modify fields based on the consumes argument in the 'apply' method."""
if apply_consumes:
for k, v in apply_consumes.items():
if isinstance(v, str):
spec_consumes[k] = spec_consumes.pop(v)
else:
msg = (
f"Invalid data type for field `{k}` in the `apply_consumes` "
f"argument. Only string types are allowed."
)
raise ValueError(
msg,
)
return spec_consumes

@classmethod
def get_spec_consumes(
cls,
dataset_fields: t.Mapping[str, Field],
apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
):
"""
Function that get the consumes spec for the component based on the dataset fields and
the apply_consumes argument.
Args:
dataset_fields: The fields of the dataset.
apply_consumes: The consumes argument in the apply method.
Returns:
The consumes spec for the component.
"""
def _get_spec_consumes(cls) -> t.Mapping[str, t.Union[str, pa.DataType, bool]]:
"""Get the consumes spec for the component."""
consumes = cls.consumes()

if consumes is None:
# Get consumes spec from the dataset
spec_consumes = {k: v.type.to_dict() for k, v in dataset_fields.items()}

spec_consumes = cls.modify_spec_consumes(spec_consumes, apply_consumes)

logger.warning(
"No consumes defined. Consumes will be inferred from the dataset."
" All field will be consumed which may lead to additional computation,"
" Consider defining consumes in the component.\n Consumes: %s",
spec_consumes,
)

else:
spec_consumes = {
k: (Type(v).to_dict() if k != "additionalProperties" else v)
for k, v in consumes.items()
}
return {"additionalProperties": True}

return spec_consumes
return {
k: (Type(v).to_dict() if k != "additionalProperties" else v)
for k, v in consumes.items()
}

@classmethod
def get_spec_produces(cls):
def _get_spec_produces(cls) -> t.Mapping[str, t.Union[str, pa.DataType, bool]]:
"""Get the produces spec for the component."""
produces = cls.produces()

Expand All @@ -151,6 +104,18 @@ def get_spec_produces(cls):
for k, v in produces.items()
}

@classmethod
def get_component_spec(cls) -> ComponentSpec:
"""Return the component spec for the component."""
return ComponentSpec(
name=cls.__name__,
image=cls.image().base_image,
description=cls.__doc__ or "lightweight component",
consumes=cls._get_spec_consumes(),
produces=cls._get_spec_produces(),
args={name: arg.to_spec() for name, arg in infer_arguments(cls).items()},
)


def lightweight_component(
*args,
Expand Down
66 changes: 42 additions & 24 deletions src/fondant/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from fondant.core.manifest import Manifest
from fondant.core.schema import Field
from fondant.pipeline import Image, LightweightComponent
from fondant.pipeline.argument_inference import infer_arguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -177,25 +176,60 @@ def __init__(
self.resources = resources or Resources()

@classmethod
def from_component_yaml(cls, path, **kwargs) -> "ComponentOp":
def from_component_yaml(cls, path, fields=None, **kwargs) -> "ComponentOp":
if cls._is_custom_component(path):
component_dir = Path(path)
else:
component_dir = cls._get_registry_path(str(path))

component_spec = ComponentSpec.from_file(
component_dir / cls.COMPONENT_SPEC_NAME,
)

# If consumes is not defined in the pipeline, we will try to infer it
if kwargs.get("consumes") is None:
kwargs["consumes"] = cls._infer_consumes(component_spec, fields)

image = Image(
base_image=component_spec.image,
)

return cls(
image=image,
component_spec=component_spec,
component_dir=component_dir,
**kwargs,
)

@classmethod
def _infer_consumes(cls, component_spec, dataset_fields):
"""Infer the consumes section of the component spec."""
if component_spec.consumes_is_defined is False:
msg = (
"The consumes section of the component spec is not defined. "
"Can not infer consumes of the OperationSpec. Please define a consumes section "
"in the dataset interface. "
)
logger.info(msg)
return None

# Component has consumes and additionalProperties, we will load all dataset columns
if (
component_spec.consumes_is_defined
and component_spec.consumes_additional_properties
):
if dataset_fields is None:
logger.info(
"The dataset fields are not defined. Cannot infer consumes.",
)
return None

return {k: v.type.value for k, v in dataset_fields.items()}

# Component has consumes and no additionalProperties, we will load only the columns defined
# in the component spec
return {k: v.type.value for k, v in component_spec.consumes.items()}

@classmethod
def from_ref(
cls,
Expand All @@ -215,31 +249,14 @@ def from_ref(
"""
if inspect.isclass(ref) and issubclass(ref, BaseComponent):
if issubclass(ref, LightweightComponent):
name = ref.__name__
image = ref.image()
description = ref.__doc__ or "lightweight component"
spec_produces = ref.get_spec_produces()

spec_consumes = (
ref.get_spec_consumes(fields, kwargs["consumes"])
if fields
else {"additionalProperties": True}
)
component_spec = ref.get_component_spec()

component_spec = ComponentSpec(
name,
image.base_image,
description=description,
consumes=spec_consumes,
produces=spec_produces,
args={
name: arg.to_spec()
for name, arg in infer_arguments(ref).items()
},
)
# If consumes is not defined in the pipeline, we will try to infer it
if kwargs.get("consumes") is None:
kwargs["consumes"] = cls._infer_consumes(component_spec, fields)

operation = cls(
image,
ref.image(),
component_spec,
**kwargs,
)
Expand All @@ -251,6 +268,7 @@ def from_ref(
elif isinstance(ref, (str, Path)):
operation = cls.from_component_yaml(
ref,
fields,
**kwargs,
)
else:
Expand Down
6 changes: 0 additions & 6 deletions tests/core/test_manifest_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@
"images_array": "images_data",
},
},
"6": {
# Non-generic component that has a type in the produces mapping
"produces": {
"embedding_data": pa.list_(pa.float32()),
},
},
}


Expand Down
Loading

0 comments on commit 0de3f5c

Please sign in to comment.