diff --git a/CHANGELOG.md b/CHANGELOG.md
index 78d10f4fac..770aab9f32 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,14 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

-### Added
+### Added

+- Add `max_retries` and `num_threads` parameters to `rg.log` to run data logging requests concurrently with a backoff retry policy. See [#2458](https://github.com/argilla-io/argilla/issues/2458) and [#2533](https://github.com/argilla-io/argilla/issues/2533)
 - `rg.load` accepts `exclude_vectors` and `exclude_metrics` when loading data. Closes [#2398](https://github.com/argilla-io/argilla/issues/2398)

 ### Changed

 - Argilla quickstart image dependencies are externalized into `quickstart.requirements.txt`. See [#2666](https://github.com/argilla-io/argilla/pull/2666)
 - bulk endpoints will upsert data when record `id` is present. Closes [#2535](https://github.com/argilla-io/argilla/issues/2535)
+- `rg.log` now computes all batches and raises an error reporting all failed batches.
+- The default batch size for `rg.log` is now 100.
+
+### Deprecated
+
+- The `rg.log_async` function is deprecated and will be removed in the next minor release.
+

 ## [1.6.0](https://github.com/argilla-io/argilla/compare/v1.5.1...v1.6.0)
diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py
index 3f95c3e8ad..368122f25c 100644
--- a/src/argilla/client/api.py
+++ b/src/argilla/client/api.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
+import warnings
 from asyncio import Future
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

@@ -112,10 +114,12 @@ def log(
     workspace: Optional[str] = None,
     tags: Optional[Dict[str, str]] = None,
     metadata: Optional[Dict[str, Any]] = None,
-    batch_size: int = 500,
+    batch_size: int = 100,
     verbose: bool = True,
     background: bool = False,
     chunk_size: Optional[int] = None,
+    num_threads: int = 0,
+    max_retries: int = 3,
 ) -> Union[BulkResponse, Future]:
     """Logs Records to argilla.

@@ -133,6 +137,10 @@
         background: If True, we will NOT wait for the logging process to finish and return
             an ``asyncio.Future`` object. You probably want to set ``verbose`` to False in that case.
         chunk_size: DEPRECATED! Use `batch_size` instead.
+        num_threads: If > 0, uses `num_threads` separate threads to send the batches, logging data concurrently.
+            Defaults to `0`, which means no threading at all.
+        max_retries: Number of retries when logging a batch of records if a `httpx.TransportError` occurs.
+            Defaults to `3`.

     Returns:
         Summary of the response from the REST API.
@@ -162,6 +170,8 @@
         verbose=verbose,
         background=background,
         chunk_size=chunk_size,
+        num_threads=num_threads,
+        max_retries=max_retries,
     )


@@ -171,7 +181,7 @@
     workspace: Optional[str] = None,
     tags: Optional[Dict[str, str]] = None,
     metadata: Optional[Dict[str, Any]] = None,
-    batch_size: int = 500,
+    batch_size: int = 100,
     verbose: bool = True,
     chunk_size: Optional[int] = None,
 ) -> BulkResponse:
@@ -201,7 +211,14 @@
         ...     rg.log_async(my_records, dataset_name), loop
         ... )
     """
-    return await ArgillaSingleton.get().log_async(
+
+    warnings.warn(
+        "`log_async` is deprecated and will be removed in the next release. "
" + "Please, use `log` with `background=True` instead", + DeprecationWarning, + ) + + future = ArgillaSingleton.get().log( records=records, name=name, workspace=workspace, @@ -210,8 +227,11 @@ async def log_async( batch_size=batch_size, verbose=verbose, chunk_size=chunk_size, + background=True, ) + return await asyncio.wrap_future(future) + def load( name: str, diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index dbfa280953..cf87a5aeb6 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio + import logging import os import re import warnings from asyncio import Future +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import backoff +import httpx from rich import print as rprint from rich.progress import Progress @@ -46,10 +49,10 @@ TokenClassificationRecord, ) from argilla.client.sdk.client import AuthenticatedClient -from argilla.client.sdk.commons.api import async_bulk +from argilla.client.sdk.commons.api import bulk from argilla.client.sdk.commons.errors import ( AlreadyExistsApiError, - ApiCompatibilityError, + BaseClientError, InputValueError, NotFoundApiError, ) @@ -88,24 +91,6 @@ _LOGGER = logging.getLogger(__name__) -class _ArgillaLogAgent: - def __init__(self, api: "Argilla"): - self.__api__ = api - self.__loop__, self.__thread__ = setup_loop_in_thread() - - @staticmethod - async def __log_internal__(api: "Argilla", *args, **kwargs): - try: - return await api.log_async(*args, **kwargs) - except Exception as ex: - dataset = kwargs["name"] - _LOGGER.error(f"\nCannot log data in dataset '{dataset}'\nError: {type(ex).__name__}\nDetails: {ex}") - raise ex - - def log(self, *args, **kwargs) -> Future: - return asyncio.run_coroutine_threadsafe(self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__) - - class Argilla: """ The main argilla client. @@ -119,7 +104,7 @@ def __init__( api_url: Optional[str] = None, api_key: Optional[str] = None, workspace: Optional[str] = None, - timeout: int = 60, + timeout: int = 120, extra_headers: Optional[Dict[str, str]] = None, ): """ @@ -153,13 +138,9 @@ def __init__( self._user: User = users_api.whoami(client=self._client) self.set_workspace(workspace or self._user.username) - self._agent = _ArgillaLogAgent(self) - def __del__(self): if hasattr(self, "_client"): del self._client - if hasattr(self, "_agent"): - del self._agent @property def client(self) -> AuthenticatedClient: @@ -239,10 +220,7 @@ def copy(self, dataset: str, name_of_copy: str, workspace: str = None): datasets_api.copy_dataset( client=self._client, name=dataset, - json_body=CopyDatasetRequest( - name=name_of_copy, - target_workspace=workspace, - ), + json_body=CopyDatasetRequest(name=name_of_copy, target_workspace=workspace), ) def delete(self, name: str, workspace: Optional[str] = None): @@ -263,9 +241,11 @@ def log( workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, - batch_size: int = 500, + batch_size: int = 100, verbose: bool = True, background: bool = False, + num_threads: int = 0, + max_retries: int = 3, chunk_size: Optional[int] = None, ) -> Union[BulkResponse, Future]: """Logs Records to argilla. @@ -282,6 +262,10 @@ def log( background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future`` object. 
                 You probably want to set ``verbose`` to False in that case.
+            num_threads: If > 0, uses `num_threads` separate threads to send the batches, logging data concurrently.
+                Defaults to `0`, which means no threading at all.
+            max_retries: Number of retries when logging a batch of records if a `httpx.TransportError` occurs.
+                Defaults to `3`.
             chunk_size: DEPRECATED! Use `batch_size` instead.

         Returns:
@@ -290,52 +274,24 @@ def log(
             will be returned instead.

         """
-        if workspace is not None:
-            self.set_workspace(workspace)
-        future = self._agent.log(
-            records=records,
-            name=name,
-            tags=tags,
-            metadata=metadata,
-            batch_size=batch_size,
-            verbose=verbose,
-            chunk_size=chunk_size,
-        )
         if background:
-            return future
-
-        try:
-            return future.result()
-        finally:
-            future.cancel()
-
-    async def log_async(
-        self,
-        records: Union[Record, Iterable[Record], Dataset],
-        name: str,
-        workspace: Optional[str] = None,
-        tags: Optional[Dict[str, str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        batch_size: int = 500,
-        verbose: bool = True,
-        chunk_size: Optional[int] = None,
-    ) -> BulkResponse:
-        """Logs Records to argilla with asyncio.
-
-        Args:
-            records: The record, an iterable of records, or a dataset to log.
-            name: The dataset name.
-            tags: A dictionary of tags related to the dataset.
-            metadata: A dictionary of extra info for the dataset.
-            batch_size: The batch size for a data bulk.
-            verbose: If True, shows a progress bar and prints out a quick summary at the end.
-            chunk_size: DEPRECATED! Use `batch_size` instead.
+            executor = ThreadPoolExecutor(max_workers=1)

-        Returns:
-            Summary of the response from the REST API
+            return executor.submit(
+                self.log,
+                records=records,
+                name=name,
+                workspace=workspace,
+                tags=tags,
+                metadata=metadata,
+                batch_size=batch_size,
+                verbose=verbose,
+                chunk_size=chunk_size,
+                num_threads=num_threads,
+                max_retries=max_retries,
             )

-        """
         tags = tags or {}
         metadata = metadata or {}

@@ -389,27 +345,40 @@ async def log_async(
         else:
             raise InputValueError(f"Unknown record type {record_type}. Available values are {Record.__args__}")

-        processed, failed = 0, 0
+        results = []
         with Progress() as progress_bar:
             task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)

-            for i in range(0, len(records), batch_size):
-                batch = records[i : i + batch_size]
+            batches = [records[i : i + batch_size] for i in range(0, len(records), batch_size)]
+
+            @backoff.on_exception(
+                backoff.expo,
+                exception=httpx.TransportError,
+                max_tries=max_retries,
+                backoff_log_level=logging.DEBUG,
+            )
+            def log_batch(batch_info: Tuple[int, list]) -> Union[Tuple[int, int]]:
+                batch_id, batch = batch_info

-                response = await async_bulk(
+                bulk_result = bulk(
                     client=self._client,
                     name=name,
                     json_body=bulk_class(
-                        tags=tags,
-                        metadata=metadata,
-                        records=[creation_class.from_client(r) for r in batch],
+                        tags=tags, metadata=metadata, records=[creation_class.from_client(r) for r in batch]
                     ),
                 )

-                processed += response.parsed.processed
-                failed += response.parsed.failed
-                progress_bar.update(task, advance=len(batch))
+                return bulk_result.processed, bulk_result.failed
+
+            if num_threads >= 1:
+                with ThreadPoolExecutor(max_workers=num_threads) as executor:
+                    results.extend(list(executor.map(log_batch, enumerate(batches))))
+            else:
+                results.extend(list(map(log_batch, enumerate(batches))))
+
+        processed_batches, failed_batches = zip(*results)
+        processed, failed = sum(processed_batches), sum(failed_batches)

         # TODO: improve logging policy in library
         if verbose:
diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py
index 5c96eff3f8..527180fa14 100644
--- a/src/argilla/client/sdk/commons/api.py
+++ b/src/argilla/client/sdk/commons/api.py
@@ -63,18 +63,11 @@ def bulk(
     client: AuthenticatedClient,
     name: str,
     json_body: Union[TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData],
-) -> Response[BulkResponse]:
+) -> BulkResponse:
     url = f"{client.base_url}/api/datasets/{name}/{_TASK_TO_ENDPOINT[type(json_body)]}:bulk"

-    response = httpx.post(
-        url=url,
-        headers=client.get_headers(),
-        cookies=client.get_cookies(),
-        timeout=client.get_timeout(),
-        json=json_body.dict(by_alias=True),
-    )
-
-    return build_bulk_response(response, name=name, body=json_body)
+    response = client.post(path=url, json=json_body.dict(by_alias=True))
+    return BulkResponse.parse_obj(response)


 async def async_bulk(
diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py
index 554a03611c..2b3604c572 100644
--- a/tests/client/sdk/commons/api.py
+++ b/tests/client/sdk/commons/api.py
@@ -48,8 +48,7 @@ def test_textclass_bulk(sdk_client, mocked_client, bulk_textclass_data, monkeypa
     mocked_client.delete(f"/api/datasets/{dataset_name}")
     response = bulk(sdk_client, name=dataset_name, json_body=bulk_textclass_data)

-    assert response.status_code == 200
-    assert isinstance(response.parsed, BulkResponse)
+    assert isinstance(response, BulkResponse)


 def test_tokenclass_bulk(sdk_client, mocked_client, bulk_tokenclass_data, monkeypatch):
@@ -59,8 +58,7 @@
     mocked_client.delete(f"/api/datasets/{dataset_name}")
     response = bulk(sdk_client, name=dataset_name, json_body=bulk_tokenclass_data)

-    assert response.status_code == 200
-    assert isinstance(response.parsed, BulkResponse)
+    assert isinstance(response, BulkResponse)


 @pytest.mark.parametrize(
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 3440aa0717..033774ec0c 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -49,6 +49,7 @@
 )
 from argilla.server.contexts import accounts
 from argilla.server.security.model import WorkspaceCreate, WorkspaceUserCreate
+from httpx import ConnectError
 from sqlalchemy.orm import Session

 from tests.helpers import SecuredClient
@@ -295,10 +296,10 @@ def test_log_background_with_error(mocked_client: SecuredClient, monkeypatch: An
     def raise_http_error(*args, **kwargs):
         raise httpx.ConnectError("Mock error", request=None)

-    monkeypatch.setattr(httpx.AsyncClient, "post", raise_http_error)
+    monkeypatch.setattr(api.active_client().http_client, "post", raise_http_error)

     future = api.log(rg.TextClassificationRecord(text=sample_text), name=dataset_name, background=True)
-    with pytest.raises(BaseClientError):
+    with pytest.raises(ConnectError):
         try:
             future.result()
         finally:
diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py
index ceab2bee9d..08e7048e58 100644
--- a/tests/functional_tests/test_log_for_token_classification.py
+++ b/tests/functional_tests/test_log_for_token_classification.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import argilla
 import pytest
 from argilla import TokenClassificationRecord
@@ -19,6 +18,7 @@
 from argilla.client.sdk.commons.errors import NotFoundApiError
 from argilla.metrics import __all__ as ALL_METRICS
 from argilla.metrics import entity_consistency
+from datasets import load_dataset
 from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
 from tests.helpers import SecuredClient

@@ -504,6 +504,21 @@ def test_log_data_with_vectors_and_update_ok(mocked_client: SecuredClient, api):
     assert ds[0].id == 3


+def test_logging_data_with_concurrency(mocked_client):
+    from datasets import load_dataset
+
+    dataset = "test_logging_data_with_concurrency"
+    dataset_ds = load_dataset("rubrix/gutenberg_spacy-ner", split="train")
+
+    dataset_rb = argilla.read_datasets(dataset_ds, task="TokenClassification")
+
+    api.delete(dataset)
+    api.log(name=dataset, records=dataset_rb, batch_size=int(len(dataset_ds) / 4), num_threads=4)
+
+    ds = api.load(name=dataset)
+    assert len(dataset_ds) == len(ds)
+
+
 @pytest.mark.skipif(
     condition=not SUPPORTED_VECTOR_SEARCH,
     reason="Vector search not supported",
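
For reference, a minimal usage sketch of the behaviour this diff introduces. It assumes an already-initialised client (`rg.init(...)`); the dataset name and record contents are illustrative placeholders, not part of the change itself:

```python
import argilla as rg

# Illustrative records; any supported record type is batched the same way.
records = [rg.TextClassificationRecord(text=f"sample text {i}") for i in range(1_000)]

# Batches of 100 records (the new default) are sent over 4 threads; a batch that
# fails with httpx.TransportError is retried up to 3 times with exponential backoff.
rg.log(
    records=records,
    name="my-dataset",  # placeholder dataset name
    batch_size=100,
    num_threads=4,
    max_retries=3,
)

# Replacement for the deprecated rg.log_async: log in the background and
# collect the BulkResponse from the returned future.
future = rg.log(records=records, name="my-dataset", background=True, verbose=False)
bulk_response = future.result()
```

Leaving `num_threads=0` keeps the previous sequential behaviour; the only difference is the smaller default batch size.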