From 3d359dc2dd9890546fd512eecb23c88c32a3e85b Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Tue, 13 Feb 2024 17:32:54 +0100 Subject: [PATCH 1/6] Move dask client configuration to component class and set local cluster as default --- src/fondant/component/component.py | 33 +++++++++++-- src/fondant/component/executor.py | 48 +------------------ src/fondant/core/component_spec.py | 12 ----- src/fondant/pipeline/pipeline.py | 28 ----------- tests/component/test_component.py | 1 - .../component_specs/kubeflow_component.yaml | 12 ----- 6 files changed, 29 insertions(+), 105 deletions(-) diff --git a/src/fondant/component/component.py b/src/fondant/component/component.py index 3cc43cd5..f3e5b0ac 100644 --- a/src/fondant/component/component.py +++ b/src/fondant/component/component.py @@ -1,10 +1,12 @@ """This module defines interfaces which components should implement to be executed by fondant.""" - +import os import typing as t from abc import abstractmethod +import dask import dask.dataframe as dd import pandas as pd +from dask.distributed import Client, LocalCluster class BaseComponent: @@ -27,7 +29,28 @@ def teardown(self) -> None: """Method called after the component has been executed.""" -class DaskLoadComponent(BaseComponent): +class DaskComponent(BaseComponent): + """Component built on Dask.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dask_client() + + def dask_client(self) -> Client: + dask.config.set({"dataframe.convert-string": False}) + # worker.daemon is set to false because creating a worker process in daemon + # mode is not possible in our docker container setup. + dask.config.set({"distributed.worker.daemon": False}) + + local_cluster = LocalCluster( + processes=True, + n_workers=os.cpu_count(), + threads_per_worker=1, + ) + return Client(local_cluster) + + +class DaskLoadComponent(DaskComponent): """Component that loads data and returns a Dask DataFrame.""" @abstractmethod @@ -35,7 +58,7 @@ def load(self) -> dd.DataFrame: pass -class DaskTransformComponent(BaseComponent): +class DaskTransformComponent(DaskComponent): """Component that transforms an incoming Dask DataFrame.""" @abstractmethod @@ -49,7 +72,7 @@ def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame: """ -class DaskWriteComponent(BaseComponent): +class DaskWriteComponent(DaskComponent): """Component that accepts a Dask DataFrame and writes its contents.""" @abstractmethod @@ -57,7 +80,7 @@ def write(self, dataframe: dd.DataFrame) -> None: pass -class PandasTransformComponent(BaseComponent): +class PandasTransformComponent(DaskComponent): """Component that transforms the incoming dataset partition per partition as a pandas DataFrame. 
""" diff --git a/src/fondant/component/executor.py b/src/fondant/component/executor.py index 271e1fc1..9fee0990 100644 --- a/src/fondant/component/executor.py +++ b/src/fondant/component/executor.py @@ -7,16 +7,13 @@ import argparse import json import logging -import os import typing as t from abc import abstractmethod from distutils.util import strtobool from pathlib import Path -import dask import dask.dataframe as dd import pandas as pd -from dask.distributed import Client, LocalCluster from fsspec import open as fs_open from fondant.component import ( @@ -30,7 +27,6 @@ from fondant.core.component_spec import Argument, OperationSpec from fondant.core.manifest import Manifest, Metadata -dask.config.set({"dataframe.convert-string": False}) logger = logging.getLogger(__name__) @@ -49,9 +45,6 @@ class Executor(t.Generic[Component]): partition of dataframe. Partitions are divided based on this number (n rows per partition). Set to None for no row limit. - cluster_type: The type of cluster to use for distributed execution - (default is "local"). - client_kwargs: Additional keyword arguments dict which will be used to initialise the dask client, allowing for advanced configuration. previous_index: The name of the index column of the previous component. Used to remove all previous fields if the component changes the index @@ -67,8 +60,6 @@ def __init__( metadata: t.Dict[str, t.Any], user_arguments: t.Dict[str, t.Any], input_partition_rows: int, - cluster_type: t.Optional[str] = None, - client_kwargs: t.Optional[dict] = None, previous_index: t.Optional[str] = None, ) -> None: self.operation_spec = operation_spec @@ -80,33 +71,6 @@ def __init__( self.input_partition_rows = input_partition_rows self.previous_index = previous_index - if cluster_type == "local": - client_kwargs = client_kwargs or { - "processes": True, - "n_workers": os.cpu_count(), - "threads_per_worker": 1, - } - - logger.info(f"Initialize local dask cluster with arguments {client_kwargs}") - - # Additional dask configuration have to be set before initialising the client - # worker.daemon is set to false because creating a worker process in daemon - # mode is not possible in our docker container setup. - dask.config.set({"distributed.worker.daemon": False}) - - local_cluster = LocalCluster(**client_kwargs, silence_logs=logging.ERROR) - self.client = Client(local_cluster) - - elif cluster_type == "distributed": - msg = "The usage of the Dask distributed client is not supported yet." - raise NotImplementedError(msg) - else: - logger.info( - "Dask default local mode will be used for further executions." 
- "Our current supported options are limited to 'local' and 'default'.", - ) - self.client = None - @classmethod def from_args(cls) -> "Executor": """Create an executor from a passed argument containing the specification as a dict.""" @@ -114,8 +78,6 @@ def from_args(cls) -> "Executor": parser.add_argument("--operation_spec", type=json.loads) parser.add_argument("--cache", type=lambda x: bool(strtobool(x))) parser.add_argument("--input_partition_rows", type=int) - parser.add_argument("--cluster_type", type=str) - parser.add_argument("--client_kwargs", type=json.loads) args, _ = parser.parse_known_args() if "operation_spec" not in args: @@ -128,8 +90,6 @@ def from_args(cls) -> "Executor": operation_spec, cache=args.cache, input_partition_rows=args.input_partition_rows, - cluster_type=args.cluster_type, - client_kwargs=args.client_kwargs, ) @classmethod @@ -139,8 +99,6 @@ def from_spec( *, cache: bool, input_partition_rows: int, - cluster_type: t.Optional[str], - client_kwargs: t.Optional[dict], ) -> "Executor": """Create an executor from a component spec.""" args_dict = vars(cls._add_and_parse_args(operation_spec)) @@ -149,8 +107,6 @@ def from_spec( "operation_spec", "input_partition_rows", "cache", - "cluster_type", - "client_kwargs", "consumes", "produces", ]: @@ -169,8 +125,6 @@ def from_spec( metadata=metadata, user_arguments=args_dict, input_partition_rows=input_partition_rows, - cluster_type=cluster_type, - client_kwargs=client_kwargs, previous_index=operation_spec.previous_index, ) @@ -262,7 +216,7 @@ def _write_data( operation_spec=self.operation_spec, ) - data_writer.write_dataframe(dataframe, self.client) + data_writer.write_dataframe(dataframe) def _get_cache_reference_content(self) -> t.Union[str, None]: """ diff --git a/src/fondant/core/component_spec.py b/src/fondant/core/component_spec.py index 2802ad69..5cf7d05b 100644 --- a/src/fondant/core/component_spec.py +++ b/src/fondant/core/component_spec.py @@ -290,18 +290,6 @@ def default_arguments(self) -> t.Dict[str, Argument]: type=bool, default=True, ), - "cluster_type": Argument( - name="cluster_type", - description="The cluster type to use for the execution", - type=str, - default="default", - ), - "client_kwargs": Argument( - name="client_kwargs", - description="Keyword arguments to pass to the Dask client", - type=dict, - default={}, - ), "metadata": Argument( name="metadata", description="Metadata arguments containing the run id and base path", diff --git a/src/fondant/pipeline/pipeline.py b/src/fondant/pipeline/pipeline.py index 72f20061..4ec40abd 100644 --- a/src/fondant/pipeline/pipeline.py +++ b/src/fondant/pipeline/pipeline.py @@ -126,8 +126,6 @@ class ComponentOp: input_partition_rows: The number of rows to load per partition. Set to override the automatic partitioning resources: The resources to assign to the operation. - cluster_type: The type of cluster to use for distributed execution (default is "local"). - client_kwargs: Keyword arguments used to initialise the dask client. 
Note: - A Fondant Component operation is created by defining a Fondant Component and its input @@ -146,8 +144,6 @@ def __init__( arguments: t.Optional[t.Dict[str, t.Any]] = None, input_partition_rows: t.Optional[t.Union[str, int]] = None, cache: t.Optional[bool] = True, - cluster_type: t.Optional[str] = "default", - client_kwargs: t.Optional[dict] = None, resources: t.Optional[Resources] = None, component_dir: t.Optional[Path] = None, ) -> None: @@ -155,8 +151,6 @@ def __init__( self.component_spec = component_spec self.input_partition_rows = input_partition_rows self.cache = self._configure_caching_from_image_tag(cache) - self.cluster_type = cluster_type - self.client_kwargs = client_kwargs self.component_dir = component_dir self.operation_spec = OperationSpec( @@ -172,8 +166,6 @@ def __init__( for key, value in { "input_partition_rows": input_partition_rows, "cache": self.cache, - "cluster_type": cluster_type, - "client_kwargs": client_kwargs, "operation_spec": self.operation_spec.to_json(), }.items() if value is not None @@ -367,8 +359,6 @@ def get_nested_dict_hash(input_dict): "number_of_accelerators": self.resources.accelerator_number, "accelerator_name": self.resources.accelerator_name, "node_pool_name": self.resources.node_pool_name, - "cluster_type": self.cluster_type, - "client_kwargs": self.client_kwargs, } if previous_component_cache is not None: @@ -429,8 +419,6 @@ def read( input_partition_rows: t.Optional[t.Union[int, str]] = None, resources: t.Optional[Resources] = None, cache: t.Optional[bool] = True, - cluster_type: t.Optional[str] = "default", - client_kwargs: t.Optional[dict] = None, ) -> "Dataset": """ Read data using the provided component. @@ -447,8 +435,6 @@ def read( automatic partitioning resources: The resources to assign to the operation. cache: Set to False to disable caching, True by default. - cluster_type: The type of cluster to use for distributed execution (default is "local"). - client_kwargs: Keyword arguments used to initialise the Dask client. Returns: An intermediate dataset. @@ -466,8 +452,6 @@ def read( input_partition_rows=input_partition_rows, resources=resources, cache=cache, - cluster_type=cluster_type, - client_kwargs=client_kwargs, ) manifest = Manifest.create( pipeline_name=self.name, @@ -650,8 +634,6 @@ def apply( input_partition_rows: t.Optional[t.Union[int, str]] = None, resources: t.Optional[Resources] = None, cache: t.Optional[bool] = True, - cluster_type: t.Optional[str] = "default", - client_kwargs: t.Optional[dict] = None, ) -> "Dataset": """ Apply the provided component on the dataset. @@ -736,8 +718,6 @@ def apply( automatic partitioning resources: The resources to assign to the operation. cache: Set to False to disable caching, True by default. - cluster_type: The type of cluster to use for distributed execution (default is "local"). - client_kwargs: Keyword arguments used to initialise the Dask client. Returns: An intermediate dataset. @@ -751,8 +731,6 @@ def apply( input_partition_rows=input_partition_rows, resources=resources, cache=cache, - cluster_type=cluster_type, - client_kwargs=client_kwargs, ) return self._apply(operation) @@ -766,8 +744,6 @@ def write( input_partition_rows: t.Optional[t.Union[int, str]] = None, resources: t.Optional[Resources] = None, cache: t.Optional[bool] = True, - cluster_type: t.Optional[str] = "default", - client_kwargs: t.Optional[dict] = None, ) -> None: """ Write the dataset using the provided component. 
@@ -784,8 +760,6 @@ def write( automatic partitioning resources: The resources to assign to the operation. cache: Set to False to disable caching, True by default. - cluster_type: The type of cluster to use for distributed execution (default is "local"). - client_kwargs: Keyword arguments used to initialise the Dask client. Returns: An intermediate dataset. @@ -798,7 +772,5 @@ def write( input_partition_rows=input_partition_rows, resources=resources, cache=cache, - cluster_type=cluster_type, - client_kwargs=client_kwargs, ) self._apply(operation) diff --git a/tests/component/test_component.py b/tests/component/test_component.py index 8a146e9f..18917a4d 100644 --- a/tests/component/test_component.py +++ b/tests/component/test_component.py @@ -187,7 +187,6 @@ def test_run_with_cache(metadata, monkeypatch): "3.14", "--override_default_arg_with_none", "None", - "--cluster_type" "local" "--client_kwargs" "{}", ] class MyExecutor(Executor): diff --git a/tests/core/examples/component_specs/kubeflow_component.yaml b/tests/core/examples/component_specs/kubeflow_component.yaml index f19f5adc..73c9226c 100644 --- a/tests/core/examples/component_specs/kubeflow_component.yaml +++ b/tests/core/examples/component_specs/kubeflow_component.yaml @@ -19,14 +19,6 @@ components: parameterType: BOOLEAN description: Set to False to disable caching, True by default. defaultValue: true - cluster_type: - parameterType: STRING - description: The cluster type to use for the execution - defaultValue: default - client_kwargs: - parameterType: STRUCT - description: Keyword arguments to pass to the Dask client - defaultValue: {} metadata: parameterType: STRING description: Metadata arguments containing the run id and base path @@ -65,10 +57,6 @@ root: componentInputParameter: input_partition_rows cache: componentInputParameter: cache - cluster_type: - componentInputParameter: cluster_type - client_kwargs: - componentInputParameter: client_kwargs metadata: componentInputParameter: metadata output_manifest_path: From 43eb1f536ff6a420fd9bf32b30d19ab6eab965f6 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Wed, 14 Feb 2024 10:52:54 +0100 Subject: [PATCH 2/6] Update components --- src/fondant/component/component.py | 17 ++++++++--------- .../components/caption_images/src/main.py | 1 + src/fondant/components/chunk_text/src/main.py | 1 + src/fondant/components/crop_images/src/main.py | 1 + .../components/download_images/src/main.py | 4 +--- src/fondant/components/embed_text/src/main.py | 1 + .../filter_image_resolution/src/main.py | 1 + .../components/filter_language/src/main.py | 1 + .../components/filter_text_length/src/main.py | 1 + .../components/generate_minhash/src/main.py | 1 + .../components/index_aws_opensearch/src/main.py | 1 + src/fondant/components/index_qdrant/src/main.py | 1 + .../components/index_weaviate/src/main.py | 1 + .../components/load_from_csv/src/main.py | 1 + .../components/load_from_files/src/main.py | 1 + .../components/load_from_hf_hub/src/main.py | 1 + .../components/load_from_parquet/src/main.py | 1 + .../components/load_from_pdf/src/main.py | 1 + .../components/resize_images/src/main.py | 1 + .../retrieve_laion_by_embedding/src/main.py | 1 + .../retrieve_laion_by_prompt/src/main.py | 1 + .../components/segment_images/src/main.py | 1 + .../components/write_to_file/src/main.py | 1 + .../components/write_to_hf_hub/src/main.py | 1 + 24 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/fondant/component/component.py b/src/fondant/component/component.py index f3e5b0ac..b66bb697 
100644 --- a/src/fondant/component/component.py +++ b/src/fondant/component/component.py @@ -18,10 +18,7 @@ class BaseComponent: **kwargs: The provided user arguments are passed in as keyword arguments """ - def __init__( - self, - **kwargs, - ): + def __init__(self): self.consumes = None self.produces = None @@ -33,21 +30,23 @@ class DaskComponent(BaseComponent): """Component built on Dask.""" def __init__(self, **kwargs): - super().__init__(**kwargs) - self.dask_client() + super().__init__() - def dask_client(self) -> Client: + # don't assume every object is a string dask.config.set({"dataframe.convert-string": False}) # worker.daemon is set to false because creating a worker process in daemon # mode is not possible in our docker container setup. dask.config.set({"distributed.worker.daemon": False}) - local_cluster = LocalCluster( + self.dask_client() + + def dask_client(self) -> Client: + cluster = LocalCluster( processes=True, n_workers=os.cpu_count(), threads_per_worker=1, ) - return Client(local_cluster) + return Client(cluster) class DaskLoadComponent(DaskComponent): diff --git a/src/fondant/components/caption_images/src/main.py b/src/fondant/components/caption_images/src/main.py index 90d4cccf..234aa28a 100644 --- a/src/fondant/components/caption_images/src/main.py +++ b/src/fondant/components/caption_images/src/main.py @@ -77,6 +77,7 @@ def __init__( batch_size: int, max_new_tokens: int, ): + super().__init__() self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Device: {self.device}") diff --git a/src/fondant/components/chunk_text/src/main.py b/src/fondant/components/chunk_text/src/main.py index f4a4b11e..7695257d 100644 --- a/src/fondant/components/chunk_text/src/main.py +++ b/src/fondant/components/chunk_text/src/main.py @@ -59,6 +59,7 @@ def __init__( code_splitter for more information on supported languages. """ + super().__init__() self.chunk_strategy = chunk_strategy self.chunk_kwargs = chunk_kwargs self.chunker = self._get_chunker_class(chunk_strategy) diff --git a/src/fondant/components/crop_images/src/main.py b/src/fondant/components/crop_images/src/main.py index 35e5aa29..8661bfaa 100644 --- a/src/fondant/components/crop_images/src/main.py +++ b/src/fondant/components/crop_images/src/main.py @@ -41,6 +41,7 @@ def __init__( cropping_threshold (int): threshold parameter used for detecting borders padding (int): padding for the image cropping. 
""" + super().__init__() self.cropping_threshold = cropping_threshold self.padding = padding diff --git a/src/fondant/components/download_images/src/main.py b/src/fondant/components/download_images/src/main.py index 69aaca60..049946b6 100644 --- a/src/fondant/components/download_images/src/main.py +++ b/src/fondant/components/download_images/src/main.py @@ -7,7 +7,6 @@ import logging import typing as t -import dask import httpx import pandas as pd from fondant.component import PandasTransformComponent @@ -15,8 +14,6 @@ logger = logging.getLogger(__name__) -dask.config.set(scheduler="processes") - class DownloadImagesComponent(PandasTransformComponent): """Component that downloads images based on URLs.""" @@ -50,6 +47,7 @@ def __init__( Returns: Dask dataframe """ + super().__init__() self.timeout = timeout self.retries = retries self.n_connections = n_connections diff --git a/src/fondant/components/embed_text/src/main.py b/src/fondant/components/embed_text/src/main.py index 8f4e520a..f6971865 100644 --- a/src/fondant/components/embed_text/src/main.py +++ b/src/fondant/components/embed_text/src/main.py @@ -30,6 +30,7 @@ def __init__( api_keys: dict, auth_kwargs: dict, ): + super().__init__() to_env_vars(api_keys) self.embedding_model = self.get_embedding_model( diff --git a/src/fondant/components/filter_image_resolution/src/main.py b/src/fondant/components/filter_image_resolution/src/main.py index 099a1d44..befb0a82 100644 --- a/src/fondant/components/filter_image_resolution/src/main.py +++ b/src/fondant/components/filter_image_resolution/src/main.py @@ -24,6 +24,7 @@ def __init__( min_image_dim: minimum image dimension. max_aspect_ratio: maximum aspect ratio. """ + super().__init__() self.min_image_dim = min_image_dim self.max_aspect_ratio = max_aspect_ratio diff --git a/src/fondant/components/filter_language/src/main.py b/src/fondant/components/filter_language/src/main.py index 8c9f52cb..e305bdbb 100644 --- a/src/fondant/components/filter_language/src/main.py +++ b/src/fondant/components/filter_language/src/main.py @@ -21,6 +21,7 @@ def __init__(self, language: str): Args: language (str): language to filter on """ + super().__init__() self.language = language self.model = fasttext.load_model(MODEL_PATH) diff --git a/src/fondant/components/filter_text_length/src/main.py b/src/fondant/components/filter_text_length/src/main.py index 5c9c438c..4d88ee10 100644 --- a/src/fondant/components/filter_text_length/src/main.py +++ b/src/fondant/components/filter_text_length/src/main.py @@ -18,6 +18,7 @@ def __init__(self, *, min_characters_length: int, min_words_length: int): min_characters_length: minimum number of characters min_words_length: minimum number of words. """ + super().__init__() self.min_characters_length = min_characters_length self.min_words_length = min_words_length diff --git a/src/fondant/components/generate_minhash/src/main.py b/src/fondant/components/generate_minhash/src/main.py index 0d56f024..1955edf1 100644 --- a/src/fondant/components/generate_minhash/src/main.py +++ b/src/fondant/components/generate_minhash/src/main.py @@ -39,6 +39,7 @@ def __init__(self, *, shingle_ngram_size: int): Args: shingle_ngram_size: Defines size of ngram used for the shingle generation. 
""" + super().__init__() self.shingle_ngram_size = shingle_ngram_size def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: diff --git a/src/fondant/components/index_aws_opensearch/src/main.py b/src/fondant/components/index_aws_opensearch/src/main.py index 4719a9d2..9650121e 100644 --- a/src/fondant/components/index_aws_opensearch/src/main.py +++ b/src/fondant/components/index_aws_opensearch/src/main.py @@ -23,6 +23,7 @@ def __init__( verify_certs: Optional[bool], pool_maxsize: Optional[int], ): + super().__init__() session = boto3.Session() credentials = session.get_credentials() auth = AWSV4SignerAuth(credentials, region) diff --git a/src/fondant/components/index_qdrant/src/main.py b/src/fondant/components/index_qdrant/src/main.py index 0a14642d..227e0dae 100644 --- a/src/fondant/components/index_qdrant/src/main.py +++ b/src/fondant/components/index_qdrant/src/main.py @@ -28,6 +28,7 @@ def __init__( force_disable_check_same_thread: bool = False, ): """Initialize the IndexQdrantComponent with the component parameters.""" + super().__init__() self.client = QdrantClient( location=location, url=url, diff --git a/src/fondant/components/index_weaviate/src/main.py b/src/fondant/components/index_weaviate/src/main.py index e566d8c7..c30dc7bc 100644 --- a/src/fondant/components/index_weaviate/src/main.py +++ b/src/fondant/components/index_weaviate/src/main.py @@ -24,6 +24,7 @@ def __init__( vectorizer: t.Optional[str], module_config: t.Optional[dict], ): + super().__init__() self.client = weaviate.Client( url=weaviate_url, additional_config=additional_config if additional_config else None, diff --git a/src/fondant/components/load_from_csv/src/main.py b/src/fondant/components/load_from_csv/src/main.py index 9635f7c2..c0e69542 100644 --- a/src/fondant/components/load_from_csv/src/main.py +++ b/src/fondant/components/load_from_csv/src/main.py @@ -31,6 +31,7 @@ def __init__( if not specified a default globally unique index will be set. """ + super().__init__() self.dataset_uri = dataset_uri self.column_separator = column_separator self.column_name_mapping = column_name_mapping diff --git a/src/fondant/components/load_from_files/src/main.py b/src/fondant/components/load_from_files/src/main.py index 62c9f110..183ebfcd 100644 --- a/src/fondant/components/load_from_files/src/main.py +++ b/src/fondant/components/load_from_files/src/main.py @@ -291,6 +291,7 @@ class LoadFromFiles(DaskLoadComponent): """Component that loads datasets from files.""" def __init__(self, *_, directory_uri: str) -> None: + super().__init__() self.directory_uri = directory_uri def load(self) -> dd.DataFrame: diff --git a/src/fondant/components/load_from_hf_hub/src/main.py b/src/fondant/components/load_from_hf_hub/src/main.py index fde320ff..93b1881e 100644 --- a/src/fondant/components/load_from_hf_hub/src/main.py +++ b/src/fondant/components/load_from_hf_hub/src/main.py @@ -33,6 +33,7 @@ def __init__( index_column: Column to set index to in the load component, if not specified a default globally unique index will be set. 
""" + super().__init__() self.dataset_name = dataset_name self.column_name_mapping = column_name_mapping self.image_column_names = image_column_names diff --git a/src/fondant/components/load_from_parquet/src/main.py b/src/fondant/components/load_from_parquet/src/main.py index f7ff0095..2faba245 100644 --- a/src/fondant/components/load_from_parquet/src/main.py +++ b/src/fondant/components/load_from_parquet/src/main.py @@ -31,6 +31,7 @@ def __init__( index_column: Column to set index to in the load component, if not specified a default globally unique index will be set. """ + super().__init__() self.dataset_uri = dataset_uri self.column_name_mapping = column_name_mapping self.n_rows_to_load = n_rows_to_load diff --git a/src/fondant/components/load_from_pdf/src/main.py b/src/fondant/components/load_from_pdf/src/main.py index 5190e076..05fa9140 100644 --- a/src/fondant/components/load_from_pdf/src/main.py +++ b/src/fondant/components/load_from_pdf/src/main.py @@ -31,6 +31,7 @@ def __init__( of partitions will be equal to the number of CPU cores. Set to high values if the data is large and the pipeline is running out of memory. """ + super().__init__() self.pdf_path = pdf_path self.n_rows_to_load = n_rows_to_load self.index_column = index_column diff --git a/src/fondant/components/resize_images/src/main.py b/src/fondant/components/resize_images/src/main.py index 2587c441..0c29732c 100644 --- a/src/fondant/components/resize_images/src/main.py +++ b/src/fondant/components/resize_images/src/main.py @@ -12,6 +12,7 @@ class ResizeImagesComponent(PandasTransformComponent): """Component that resizes images based on a given width and height.""" def __init__(self, *, resize_width: int, resize_height: int) -> None: + super().__init__() self.resize_width = resize_width self.resize_height = resize_height diff --git a/src/fondant/components/retrieve_laion_by_embedding/src/main.py b/src/fondant/components/retrieve_laion_by_embedding/src/main.py index 9e895a29..5a9ff07c 100644 --- a/src/fondant/components/retrieve_laion_by_embedding/src/main.py +++ b/src/fondant/components/retrieve_laion_by_embedding/src/main.py @@ -30,6 +30,7 @@ def __init__( aesthetic_weight: weight of the aesthetic embedding to add to the query, between 0 and 1. """ + super().__init__() self.client = ClipClient( url="https://knn.laion.ai/knn-service", indice_name="laion5B-L-14", diff --git a/src/fondant/components/retrieve_laion_by_prompt/src/main.py b/src/fondant/components/retrieve_laion_by_prompt/src/main.py index d7f37956..6d4a7d52 100644 --- a/src/fondant/components/retrieve_laion_by_prompt/src/main.py +++ b/src/fondant/components/retrieve_laion_by_prompt/src/main.py @@ -32,6 +32,7 @@ def __init__( between 0 and 1. url: The url of the backend clip retrieval service, defaults to the public clip url. """ + super().__init__() self.client = ClipClient( url=url, indice_name="laion5B-L-14", diff --git a/src/fondant/components/segment_images/src/main.py b/src/fondant/components/segment_images/src/main.py index c43212f9..a77b866a 100644 --- a/src/fondant/components/segment_images/src/main.py +++ b/src/fondant/components/segment_images/src/main.py @@ -115,6 +115,7 @@ def __init__( model_id: id of the model on the Hugging Face hub batch_size: batch size to use. 
""" + super().__init__() self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Device: {self.device}") diff --git a/src/fondant/components/write_to_file/src/main.py b/src/fondant/components/write_to_file/src/main.py index 34e45d5c..3bdb7117 100644 --- a/src/fondant/components/write_to_file/src/main.py +++ b/src/fondant/components/write_to_file/src/main.py @@ -5,6 +5,7 @@ class WriteToFile(DaskWriteComponent): def __init__(self, *, path: str, format: str): """Initialize the write to file component.""" + super().__init__() self.path = path self.format = format diff --git a/src/fondant/components/write_to_hf_hub/src/main.py b/src/fondant/components/write_to_hf_hub/src/main.py index 3af66ade..e8668be3 100644 --- a/src/fondant/components/write_to_hf_hub/src/main.py +++ b/src/fondant/components/write_to_hf_hub/src/main.py @@ -52,6 +52,7 @@ def __init__( column_name_mapping: Mapping of the consumed fondant column names to the written hub column names. """ + super().__init__() huggingface_hub.login(token=hf_token) repo_id = f"{username}/{dataset_name}" From e577fd5a5d22fe277faba76574c05848b4a23b91 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Wed, 14 Feb 2024 10:53:39 +0100 Subject: [PATCH 3/6] Update embed_images component to use multiple gpu --- .../components/embed_images/src/main.py | 137 ++++++++++-------- 1 file changed, 76 insertions(+), 61 deletions(-) diff --git a/src/fondant/components/embed_images/src/main.py b/src/fondant/components/embed_images/src/main.py index e2c6d219..1fd3f8e4 100644 --- a/src/fondant/components/embed_images/src/main.py +++ b/src/fondant/components/embed_images/src/main.py @@ -7,6 +7,8 @@ import numpy as np import pandas as pd import torch +from dask.distributed import Client, get_worker +from dask_cuda import LocalCUDACluster from fondant.component import PandasTransformComponent from PIL import Image from transformers import BatchEncoding, CLIPProcessor, CLIPVisionModelWithProjection @@ -16,53 +18,6 @@ os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1" -def process_image_batch( - images: np.ndarray, - *, - processor: CLIPProcessor, - device: str, -) -> t.List[torch.Tensor]: - """ - Process image in batches to a list of tensors. - - Args: - images: The input images as a numpy array containing byte strings. - processor: The processor object for transforming the image. - device: The device to move the transformed image to. - """ - - def load(img: bytes) -> Image: - """Load the bytestring as an image.""" - bytes_ = io.BytesIO(img) - return Image.open(bytes_).convert("RGB") - - def transform(img: Image) -> BatchEncoding: - """Transform the image to a tensor using a clip processor and move it to the specified - device. 
- """ - # Edge case: https://github.com/huggingface/transformers/issues/21638 - if img.width == 1 or img.height == 1: - img = img.resize((224, 224)) - - return processor(images=img, return_tensors="pt").to(device) - - return [transform(load(image))["pixel_values"] for image in images] - - -@torch.no_grad() -def embed_image_batch( - image_batch: t.List[torch.Tensor], - *, - model: CLIPVisionModelWithProjection, - index: pd.Series, -) -> pd.Series: - """Embed a batch of images.""" - input_batch = torch.cat(image_batch) - output_batch = model(input_batch) - embeddings_batch = output_batch.image_embeds.cpu().tolist() - return pd.Series(embeddings_batch, index=index) - - class EmbedImagesComponent(PandasTransformComponent): """Component that embeds images using a CLIP model from the Hugging Face hub.""" @@ -76,19 +31,82 @@ def __init__( Args: model_id: id of the model on the Hugging Face hub batch_size: batch size to use. - kwargs: Unhandled keyword arguments passed in by Fondant. """ + self.model_id = model_id + self.batch_size = batch_size + self.device = "cuda" if torch.cuda.is_available() else "cpu" - logger.info("device used is %s", self.device) + logger.info("Using device '%s'", self.device) - logger.info("Initialize model '%s'", model_id) - self.processor = CLIPProcessor.from_pretrained(model_id) - self.model = CLIPVisionModelWithProjection.from_pretrained(model_id).to( - self.device, - ) - logger.info("Model initialized") + super().__init__() - self.batch_size = batch_size + def dask_client(self) -> Client: + if self.device == "cuda": + cluster = LocalCUDACluster() + return Client(cluster) + + return super().dask_client() + + def process_image_batch(self, images: np.ndarray) -> t.List[torch.Tensor]: + """ + Process image in batches to a list of tensors. + + Args: + images: The input images as a numpy array containing byte strings. + """ + worker = get_worker() + + if hasattr(worker, "processor"): + processor = worker.processor + else: + logger.info( + "Initializing processor for '%s' on worker '%s", + self.model_id, + worker, + ) + processor = CLIPProcessor.from_pretrained(self.model_id) + worker.processor = processor + + def load(img: bytes) -> Image: + """Load the bytestring as an image.""" + bytes_ = io.BytesIO(img) + return Image.open(bytes_).convert("RGB") + + def transform(img: Image) -> BatchEncoding: + """Transform the image to a tensor using a clip processor and move it to the specified + device. 
+ """ + # Edge case: https://github.com/huggingface/transformers/issues/21638 + if img.width == 1 or img.height == 1: + img = img.resize((224, 224)) + + return processor(images=img, return_tensors="pt").to(self.device) + + return [transform(load(image))["pixel_values"] for image in images] + + @torch.no_grad() + def embed_image_batch( + self, + image_batch: t.List[torch.Tensor], + *, + index: pd.Series, + ) -> pd.Series: + """Embed a batch of images.""" + worker = get_worker() + + if hasattr(worker, "model"): + model = worker.model + else: + logger.info("Initializing model '%s' on worker '%s", self.model_id, worker) + model = CLIPVisionModelWithProjection.from_pretrained(self.model_id).to( + self.device, + ) + worker.model = model + + input_batch = torch.cat(image_batch) + output_batch = model(input_batch) + embeddings_batch = output_batch.image_embeds.cpu().tolist() + return pd.Series(embeddings_batch, index=index) def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: images = dataframe["image"] @@ -99,14 +117,11 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: np.arange(self.batch_size, len(images), self.batch_size), ): if not batch.empty: - image_tensors = process_image_batch( + image_tensors = self.process_image_batch( batch, - processor=self.processor, - device=self.device, ) - embeddings = embed_image_batch( + embeddings = self.embed_image_batch( image_tensors, - model=self.model, index=batch.index, ).T results.append(embeddings) From 737da5b78085631fabe57a91b0aef30c23c76542 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Wed, 14 Feb 2024 10:54:09 +0100 Subject: [PATCH 4/6] Add docstring --- src/fondant/component/component.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fondant/component/component.py b/src/fondant/component/component.py index b66bb697..a1036d8a 100644 --- a/src/fondant/component/component.py +++ b/src/fondant/component/component.py @@ -41,6 +41,7 @@ def __init__(self, **kwargs): self.dask_client() def dask_client(self) -> Client: + """Initialize the dask client to use for this component.""" cluster = LocalCluster( processes=True, n_workers=os.cpu_count(), From 429ab71ba7cdfe5a862a6a78b0d0509c9d13c510 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Thu, 15 Feb 2024 09:53:14 +0100 Subject: [PATCH 5/6] Add docs --- docs/components/components.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/docs/components/components.md b/docs/components/components.md index 914c75f3..c8f74e63 100644 --- a/docs/components/components.md +++ b/docs/components/components.md @@ -1,3 +1,5 @@ +from distributed import Client + # Components Fondant makes it easy to build data preparation pipelines leveraging reusable components. Fondant @@ -65,6 +67,36 @@ this data can be accessed using `dataframe["image"]`. The `transform` method should return a single dataframe, with the columns complying to the schema defined by the `produces` section of the component specification. 
### Configuring Dask

You can configure the [Dask client](https://docs.dask.org/en/stable/scheduling.html) based on the
needs of your component by overriding the `dask_client` method:

```python
import os

from dask.distributed import Client, LocalCluster
from fondant.component import PandasTransformComponent

class Component(PandasTransformComponent):

    def dask_client(self) -> Client:
        """Initialize the dask client to use for this component."""
        cluster = LocalCluster(
            processes=True,
            n_workers=os.cpu_count(),
            threads_per_worker=1,
        )
        return Client(cluster)
```

The default configuration uses a `LocalCluster` with processes, as many workers as there are
logical CPUs, and one thread per worker.

Some components might perform better using threads or a different combination of threads and
processes. To use multiple GPUs, you can use a
[`LocalCUDACluster`](https://docs.rapids.ai/api/dask-cuda/stable/quickstart/#localcudacluster),
as the `embed_images` component does.

## Component types

We can distinguish two different types of components:

From 41adcbb7d4672ed1cc09167a4fd63bc77310964e Mon Sep 17 00:00:00 2001
From: Robbe Sneyders
Date: Thu, 15 Feb 2024 10:30:20 +0100
Subject: [PATCH 6/6] Add dask-cuda to embed_images requirements

---
 src/fondant/components/embed_images/requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/fondant/components/embed_images/requirements.txt b/src/fondant/components/embed_images/requirements.txt
index f8ded479..57d15da7 100644
--- a/src/fondant/components/embed_images/requirements.txt
+++ b/src/fondant/components/embed_images/requirements.txt
@@ -1,2 +1,3 @@
 Pillow==10.0.1
-transformers==4.28.0
\ No newline at end of file
+transformers==4.28.0
+dask-cuda==23.12.0
\ No newline at end of file