diff --git a/src/fondant/component/executor.py b/src/fondant/component/executor.py index 95c67d170..3e2b8deab 100644 --- a/src/fondant/component/executor.py +++ b/src/fondant/component/executor.py @@ -514,7 +514,11 @@ def optional_fondant_arguments() -> t.List[str]: return ["input_manifest_path", "input_partition_rows"] @staticmethod - def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable: + def wrap_transform( + transform: t.Callable, + *, + operation_spec: OperationSpec, + ) -> t.Callable: """Factory that creates a function to wrap the component transform function. The wrapper: - Removes extra columns from the returned dataframe which are not defined in the component spec `produces` section @@ -524,7 +528,7 @@ def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable: Args: transform: Transform method to wrap - spec: Component specification to base behavior on + operation_spec: Operation specification to base behavior on """ def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame: @@ -532,7 +536,7 @@ def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame: dataframe = transform(dataframe) # Drop columns not in specification - columns = [name for name, field in spec.produces.items()] + columns = [name for name, field in operation_spec.inner_produces.items()] return dataframe[columns] @@ -560,11 +564,14 @@ def _execute_component( # Create meta dataframe with expected format meta_dict = {"id": pd.Series(dtype="object")} - for field_name, field in self.spec.produces.items(): + for field_name, field in self.operation_spec.inner_produces.items(): meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(field.type.value)) meta_df = pd.DataFrame(meta_dict).set_index("id") - wrapped_transform = self.wrap_transform(component.transform, spec=self.spec) + wrapped_transform = self.wrap_transform( + component.transform, + operation_spec=self.operation_spec, + ) # Call the component transform method for each partition dataframe = dataframe.map_partitions( diff --git a/tests/component/test_component.py b/tests/component/test_component.py index 6464ee89a..42e9b6726 100644 --- a/tests/component/test_component.py +++ b/tests/component/test_component.py @@ -21,7 +21,7 @@ ExecutorFactory, PandasTransformExecutor, ) -from fondant.core.component_spec import ComponentSpec +from fondant.core.component_spec import ComponentSpec, OperationSpec from fondant.core.manifest import Manifest, Metadata from fondant.pipeline import ComponentOp @@ -403,6 +403,7 @@ def test_wrap_transform(): }, }, "produces": { + "additionalProperties": True, "caption_text": { "type": "string", }, @@ -430,7 +431,15 @@ def transform(dataframe: pd.DataFrame) -> pd.DataFrame: ] return dataframe - wrapped_transform = PandasTransformExecutor.wrap_transform(transform, spec=spec) + overwrite_produces = { + "caption_text": pa.string(), + "image_height": pa.int16(), + } + + wrapped_transform = PandasTransformExecutor.wrap_transform( + transform, + operation_spec=OperationSpec(spec, produces=overwrite_produces), + ) output_df = wrapped_transform(input_df) # Check column flattening, trimming, and ordering