diff --git a/src/fondant/core/component_spec.py b/src/fondant/core/component_spec.py
index 2802ad69..86b6f91c 100644
--- a/src/fondant/core/component_spec.py
+++ b/src/fondant/core/component_spec.py
@@ -214,6 +214,27 @@ def consumes(self) -> t.Mapping[str, Field]:
             },
         )
 
+    @property
+    def consumes_additional_properties(self) -> bool:
+        """Returns a boolean indicating whether the component's consumes section allows additional properties."""
+        return self._specification.get("consumes", {}).get(
+            "additionalProperties",
+            False,
+        )
+
+    @property
+    def consumes_is_defined(self) -> bool:
+        """Returns a boolean indicating whether the component's consumes section is defined."""
+        return bool(self._specification.get("consumes", {}))
+
+    @property
+    def produces_additional_properties(self) -> bool:
+        """Returns a boolean indicating whether the component's produces section allows additional properties."""
+        return self._specification.get("produces", {}).get(
+            "additionalProperties",
+            False,
+        )
+
     @property
     def produces(self) -> t.Mapping[str, Field]:
         """The fields produced by the component as an immutable mapping."""
@@ -426,13 +447,6 @@ def _inner_mapping(self, name: str) -> t.Mapping[str, Field]:
             if not isinstance(value, pa.DataType):
                 continue
 
-            if not self._component_spec.is_generic(name):
-                msg = (
-                    f"Component {self._component_spec.name} does not allow specifying additional "
-                    f"fields but received {key}."
-                )
-                raise InvalidPipelineDefinition(msg)
-
             if key not in spec_mapping:
                 mapping[key] = Field(name=key, type=Type(value))
             else:
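The three new `ComponentSpec` properties are thin accessors over the underlying specification dict. A minimal standalone sketch of their semantics, using plain dicts as stand-ins for real Fondant specifications (illustrative only, runs without Fondant installed):

```python
# Plain-dict stand-ins for component specifications (hypothetical examples).
generic_spec = {"consumes": {"additionalProperties": True}}
explicit_spec = {"consumes": {"images_data": {"type": "binary"}}}
no_consumes_spec = {}


def consumes_is_defined(spec: dict) -> bool:
    # Mirrors ComponentSpec.consumes_is_defined: truthy when a consumes section exists.
    return bool(spec.get("consumes", {}))


def consumes_additional_properties(spec: dict) -> bool:
    # Mirrors ComponentSpec.consumes_additional_properties: defaults to False.
    return spec.get("consumes", {}).get("additionalProperties", False)


assert consumes_is_defined(generic_spec) and consumes_additional_properties(generic_spec)
assert consumes_is_defined(explicit_spec) and not consumes_additional_properties(explicit_spec)
assert not consumes_is_defined(no_consumes_spec)
```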
- """ + def _get_spec_consumes(cls) -> t.Mapping[str, t.Union[str, pa.DataType, bool]]: + """Get the consumes spec for the component.""" consumes = cls.consumes() - if consumes is None: - # Get consumes spec from the dataset - spec_consumes = {k: v.type.to_dict() for k, v in dataset_fields.items()} - - spec_consumes = cls.modify_spec_consumes(spec_consumes, apply_consumes) - - logger.warning( - "No consumes defined. Consumes will be inferred from the dataset." - " All field will be consumed which may lead to additional computation," - " Consider defining consumes in the component.\n Consumes: %s", - spec_consumes, - ) - - else: - spec_consumes = { - k: (Type(v).to_dict() if k != "additionalProperties" else v) - for k, v in consumes.items() - } + return {"additionalProperties": True} - return spec_consumes + return { + k: (Type(v).to_dict() if k != "additionalProperties" else v) + for k, v in consumes.items() + } @classmethod - def get_spec_produces(cls): + def _get_spec_produces(cls) -> t.Mapping[str, t.Union[str, pa.DataType, bool]]: """Get the produces spec for the component.""" produces = cls.produces() @@ -151,6 +104,18 @@ def get_spec_produces(cls): for k, v in produces.items() } + @classmethod + def get_component_spec(cls) -> ComponentSpec: + """Return the component spec for the component.""" + return ComponentSpec( + name=cls.__name__, + image=cls.image().base_image, + description=cls.__doc__ or "lightweight component", + consumes=cls._get_spec_consumes(), + produces=cls._get_spec_produces(), + args={name: arg.to_spec() for name, arg in infer_arguments(cls).items()}, + ) + def lightweight_component( *args, diff --git a/src/fondant/pipeline/pipeline.py b/src/fondant/pipeline/pipeline.py index 72f20061..cb39b4ba 100644 --- a/src/fondant/pipeline/pipeline.py +++ b/src/fondant/pipeline/pipeline.py @@ -26,7 +26,6 @@ from fondant.core.manifest import Manifest from fondant.core.schema import Field from fondant.pipeline import Image, LightweightComponent -from fondant.pipeline.argument_inference import infer_arguments logger = logging.getLogger(__name__) @@ -185,18 +184,24 @@ def __init__( self.resources = resources or Resources() @classmethod - def from_component_yaml(cls, path, **kwargs) -> "ComponentOp": + def from_component_yaml(cls, path, fields=None, **kwargs) -> "ComponentOp": if cls._is_custom_component(path): component_dir = Path(path) else: component_dir = cls._get_registry_path(str(path)) + component_spec = ComponentSpec.from_file( component_dir / cls.COMPONENT_SPEC_NAME, ) + # If consumes is not defined in the pipeline, we will try to infer it + if kwargs.get("consumes") is None: + kwargs["consumes"] = cls._infer_consumes(component_spec, fields) + image = Image( base_image=component_spec.image, ) + return cls( image=image, component_spec=component_spec, @@ -204,6 +209,35 @@ def from_component_yaml(cls, path, **kwargs) -> "ComponentOp": **kwargs, ) + @classmethod + def _infer_consumes(cls, component_spec, dataset_fields): + """Infer the consumes section of the component spec.""" + if component_spec.consumes_is_defined is False: + msg = ( + "The consumes section of the component spec is not defined. " + "Can not infer consumes of the OperationSpec. Please define a consumes section " + "in the dataset interface. 
" + ) + logger.info(msg) + return None + + # Component has consumes and additionalProperties, we will load all dataset columns + if ( + component_spec.consumes_is_defined + and component_spec.consumes_additional_properties + ): + if dataset_fields is None: + logger.info( + "The dataset fields are not defined. Cannot infer consumes.", + ) + return None + + return {k: v.type.value for k, v in dataset_fields.items()} + + # Component has consumes and no additionalProperties, we will load only the columns defined + # in the component spec + return {k: v.type.value for k, v in component_spec.consumes.items()} + @classmethod def from_ref( cls, @@ -223,31 +257,14 @@ def from_ref( """ if inspect.isclass(ref) and issubclass(ref, BaseComponent): if issubclass(ref, LightweightComponent): - name = ref.__name__ - image = ref.image() - description = ref.__doc__ or "lightweight component" - spec_produces = ref.get_spec_produces() - - spec_consumes = ( - ref.get_spec_consumes(fields, kwargs["consumes"]) - if fields - else {"additionalProperties": True} - ) + component_spec = ref.get_component_spec() - component_spec = ComponentSpec( - name, - image.base_image, - description=description, - consumes=spec_consumes, - produces=spec_produces, - args={ - name: arg.to_spec() - for name, arg in infer_arguments(ref).items() - }, - ) + # If consumes is not defined in the pipeline, we will try to infer it + if kwargs.get("consumes") is None: + kwargs["consumes"] = cls._infer_consumes(component_spec, fields) operation = cls( - image, + ref.image(), component_spec, **kwargs, ) @@ -259,6 +276,7 @@ def from_ref( elif isinstance(ref, (str, Path)): operation = cls.from_component_yaml( ref, + fields, **kwargs, ) else: diff --git a/tests/core/test_manifest_evolution.py b/tests/core/test_manifest_evolution.py index b8e03ef0..c06ac97a 100644 --- a/tests/core/test_manifest_evolution.py +++ b/tests/core/test_manifest_evolution.py @@ -52,12 +52,6 @@ "images_array": "images_data", }, }, - "6": { - # Non-generic component that has a type in the produces mapping - "produces": { - "embedding_data": pa.list_(pa.float32()), - }, - }, } diff --git a/tests/pipeline/test_lightweight_component.py b/tests/pipeline/test_lightweight_component.py index bb464453..900636bf 100644 --- a/tests/pipeline/test_lightweight_component.py +++ b/tests/pipeline/test_lightweight_component.py @@ -142,14 +142,16 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: "image": Image.resolve_fndnt_base_image(), "description": "lightweight component", "consumes": { - "x": {"type": "int32"}, - "y": {"type": "int32"}, - "z": {"type": "int32"}, + "additionalProperties": True, }, "produces": {"x": {"type": "int32"}}, "args": {"n": {"type": "int"}}, }, - "consumes": {}, + "consumes": { + "x": {"type": "int32"}, + "y": {"type": "int32"}, + "z": {"type": "int32"}, + }, "produces": {}, } pipeline._validate_pipeline_definition(run_id="dummy-run-id") @@ -163,6 +165,7 @@ def test_consumes_mapping_all_fields(tmp_path_factory, load_pipeline): extra_requires=[ "fondant[component]@git+https://github.com/ml6team/fondant@main", ], + consumes={"a": pa.int32()}, produces={"a": pa.int32()}, ) class AddN(PandasTransformComponent): @@ -486,3 +489,96 @@ def test_fndnt_base_image_resolution(): mock_call.return_value = "0.9" base_image_name = Image.resolve_fndnt_base_image() assert base_image_name == "fndnt/fondant:0.9-py3.9" + + +def test_infer_consumes_if_not_defined(load_pipeline): + """ + Test that the consumes mapping is inferred when not defined in dataset interface. 
diff --git a/tests/pipeline/test_lightweight_component.py b/tests/pipeline/test_lightweight_component.py
index bb464453..900636bf 100644
--- a/tests/pipeline/test_lightweight_component.py
+++ b/tests/pipeline/test_lightweight_component.py
@@ -142,14 +142,16 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
             "image": Image.resolve_fndnt_base_image(),
             "description": "lightweight component",
             "consumes": {
-                "x": {"type": "int32"},
-                "y": {"type": "int32"},
-                "z": {"type": "int32"},
+                "additionalProperties": True,
             },
             "produces": {"x": {"type": "int32"}},
             "args": {"n": {"type": "int"}},
         },
-        "consumes": {},
+        "consumes": {
+            "x": {"type": "int32"},
+            "y": {"type": "int32"},
+            "z": {"type": "int32"},
+        },
         "produces": {},
     }
     pipeline._validate_pipeline_definition(run_id="dummy-run-id")
@@ -163,6 +165,7 @@ def test_consumes_mapping_all_fields(tmp_path_factory, load_pipeline):
         extra_requires=[
             "fondant[component]@git+https://github.com/ml6team/fondant@main",
         ],
+        consumes={"a": pa.int32()},
         produces={"a": pa.int32()},
     )
     class AddN(PandasTransformComponent):
@@ -486,3 +489,96 @@ def test_fndnt_base_image_resolution():
         mock_call.return_value = "0.9"
         base_image_name = Image.resolve_fndnt_base_image()
         assert base_image_name == "fndnt/fondant:0.9-py3.9"
+
+
+def test_infer_consumes_if_not_defined(load_pipeline):
+    """
+    Test that the consumes mapping is inferred when it is not defined in the dataset
+    interface. All columns of the dataset are consumed.
+    """
+    _, dataset, _, _ = load_pipeline
+
+    @lightweight_component(
+        base_image="python:3.10-slim-buster",
+        extra_requires=["pandas", "dask"],
+        consumes={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
+        produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
+    )
+    class Bar(PandasTransformComponent):
+        def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+            return dataframe
+
+    dataset = dataset.apply(
+        ref=Bar,
+    )
+
+    operation_spec_dict = dataset.pipeline._graph["bar"][
+        "operation"
+    ].operation_spec.to_dict()
+    assert operation_spec_dict == {
+        "consumes": {
+            "x": {"type": "int32"},
+            "y": {"type": "int32"},
+            "z": {"type": "int32"},
+        },
+        "produces": {},
+        "specification": {
+            "consumes": {
+                "x": {"type": "int32"},
+                "y": {"type": "int32"},
+                "z": {"type": "int32"},
+            },
+            "description": "lightweight component",
+            "image": "python:3.10-slim-buster",
+            "name": "Bar",
+            "produces": {
+                "x": {"type": "int32"},
+                "y": {"type": "int32"},
+                "z": {"type": "int32"},
+            },
+        },
+    }
+
+
+def test_infer_consumes_if_additional_properties_true(load_pipeline):
+    """
+    Test that when additionalProperties is true (no consumes defined on the lightweight
+    component), the consumes section is inferred from the dataset interface.
+    """
+    _, dataset, _, _ = load_pipeline
+
+    @lightweight_component(
+        base_image="python:3.10-slim-buster",
+        extra_requires=["pandas", "dask"],
+        produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
+    )
+    class Bar(PandasTransformComponent):
+        def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+            return dataframe
+
+    dataset = dataset.apply(
+        ref=Bar,
+    )
+
+    operation_spec_dict = dataset.pipeline._graph["bar"][
+        "operation"
+    ].operation_spec.to_dict()
+    assert operation_spec_dict == {
+        "consumes": {
+            "x": {"type": "int32"},
+            "y": {"type": "int32"},
+            "z": {"type": "int32"},
+        },
+        "produces": {},
+        "specification": {
+            "consumes": {"additionalProperties": True},
+            "description": "lightweight component",
+            "image": "python:3.10-slim-buster",
+            "name": "Bar",
+            "produces": {
+                "x": {"type": "int32"},
+                "y": {"type": "int32"},
+                "z": {"type": "int32"},
+            },
+        },
+    }
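The assertions above distinguish two layers: `specification.consumes` is what the component itself declares, while the operation spec's top-level `consumes` is what inference resolved for this concrete pipeline step. A minimal illustration with hypothetical values:

```python
# The component declares a generic interface...
declared = {"additionalProperties": True}
# ...while the operation records the concrete columns resolved at .apply() time.
resolved = {
    "x": {"type": "int32"},
    "y": {"type": "int32"},
    "z": {"type": "int32"},
}
operation_spec = {"specification": {"consumes": declared}, "consumes": resolved}
assert operation_spec["specification"]["consumes"] == {"additionalProperties": True}
assert set(operation_spec["consumes"]) == {"x", "y", "z"}
```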
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index 6d08f77f..61889f63 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -590,3 +590,87 @@ def test_invoked_field_schema_raise_exception():
 
     with pytest.raises(InvalidPipelineDefinition, match=expected_error_msg):
         pipeline.validate("pipeline-id")
+
+
+@pytest.mark.parametrize(
+    "valid_pipeline_example",
+    [
+        (
+            "example_1",
+            [
+                "first_component",
+                "second_component",
+                "third_component",
+                "fourth_component",
+            ],
+        ),
+    ],
+)
+def test_infer_consumes_if_not_defined(
+    default_pipeline_args,
+    valid_pipeline_example,
+    tmp_path,
+    monkeypatch,
+):
+    """Test that the consumes section is inferred when it is not defined in the dataset interface."""
+    example_dir, component_names = valid_pipeline_example
+    component_args = {"storage_args": "a dummy string arg"}
+    components_path = Path(valid_pipeline_path / example_dir)
+
+    pipeline = Pipeline(**default_pipeline_args)
+
+    # Override the default package_path with a temporary path to avoid the creation of artifacts
+    monkeypatch.setattr(pipeline, "package_path", str(tmp_path / "test_pipeline.tgz"))
+
+    dataset = pipeline.read(
+        Path(components_path / component_names[0]),
+        arguments=component_args,
+        produces={"images_array": pa.binary()},
+    )
+
+    # Empty consumes & additionalProperties=False -> infer the columns defined in the component spec
+    assert list(dataset.fields.keys()) == ["images_array"]
+    dataset = dataset.apply(
+        Path(components_path / component_names[1]),
+        arguments=component_args,
+    )
+
+    assert dataset.pipeline._graph["second_component"][
+        "operation"
+    ].operation_spec.to_dict()["consumes"] == {
+        "images_data": {"type": "binary"},
+    }
+
+    # Empty consumes, additionalProperties=False, two consumes fields defined in the component spec
+    assert list(dataset.fields.keys()) == ["images_array", "embeddings_data"]
+    dataset = dataset.apply(
+        Path(components_path / component_names[2]),
+        arguments=component_args,
+    )
+
+    assert dataset.pipeline._graph["third_component"][
+        "operation"
+    ].operation_spec.to_dict()["consumes"] == {
+        "images_data": {"type": "binary"},
+        "embeddings_data": {"items": {"type": "float32"}, "type": "array"},
+    }
+
+    # additionalProperties is true, no consumes field in dataset apply
+    # -> infer operation spec, load all columns of the dataset (images_data, embeddings_data)
+    assert list(dataset.fields.keys()) == [
+        "images_array",
+        "embeddings_data",
+        "images_data",
+    ]
+    dataset = dataset.apply(
+        Path(components_path / component_names[3]),
+        arguments=component_args,
+    )
+
+    assert dataset.pipeline._graph["fourth_component"][
+        "operation"
+    ].operation_spec.to_dict()["consumes"] == {
+        "images_data": {"type": "binary"},
+        "images_array": {"type": "binary"},
+        "embeddings_data": {"items": {"type": "float32"}, "type": "array"},
+    }
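Taken together, the user-facing effect of this change is that `consumes=` becomes optional on `pipeline.read()` and `dataset.apply()`. A hedged end-to-end sketch; the `Pipeline` constructor arguments and component directories are illustrative placeholders, not taken from this diff:

```python
import pyarrow as pa

from fondant.pipeline import Pipeline

# Hypothetical pipeline and component directories, for illustration only.
pipeline = Pipeline(name="demo-pipeline", base_path="/tmp/fondant-artifacts")

dataset = pipeline.read(
    "components/load_images",                # placeholder custom component dir
    produces={"images_array": pa.binary()},
)

# No consumes passed: ComponentOp._infer_consumes resolves it from the component
# spec (explicit fields) or from the dataset columns (additionalProperties: true).
dataset = dataset.apply("components/embed_images")  # placeholder component dir
```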