Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve rg.log function #2640

Merged
merged 21 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
## Added
frascuchon marked this conversation as resolved.
Show resolved Hide resolved

- add `max_retries` and `num_threads` parameters to `rg.log` to run data logging request concurrently with backoff retry policy. See [#2458](https://github.com/argilla-io/argilla/issues/2458) and [#2533](https://github.com/argilla-io/argilla/issues/2533)

## Changed

- Argilla quickstart image dependencies are externalized into `quickstart.requirements.txt`. See [#2666](https://github.com/argilla-io/argilla/pull/2666)
- The `rg.log_async` is deprecated and will be removed in next minor release.
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
- The `rg.log` will compute all batches and raise an error for all failed batches.

frascuchon marked this conversation as resolved.
Show resolved Hide resolved
## [1.6.0](https://github.com/argilla-io/argilla/compare/v1.5.1...v1.6.0)

Expand Down
26 changes: 23 additions & 3 deletions src/argilla/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import warnings
from asyncio import Future
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -112,10 +114,12 @@ def log(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
batch_size: int = 500,
batch_size: int = 100,
verbose: bool = True,
background: bool = False,
chunk_size: Optional[int] = None,
num_threads: int = 0,
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
max_retries: int = 3,
) -> Union[BulkResponse, Future]:
"""Logs Records to argilla.

Expand All @@ -133,6 +137,10 @@ def log(
background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future``
object. You probably want to set ``verbose`` to False in that case.
chunk_size: DEPRECATED! Use `batch_size` instead.
num_threads: If > 0, will use num_thread separate number threads to batches, sending data concurrently.
Default to `0`, which means no threading at all.
max_retries: Number of retries when logging a batch of records if a `httpx.TransportError` occurs.
Default `3`.

Returns:
Summary of the response from the REST API.
Expand Down Expand Up @@ -162,6 +170,8 @@ def log(
verbose=verbose,
background=background,
chunk_size=chunk_size,
num_threads=num_threads,
max_retries=max_retries,
)


Expand All @@ -171,7 +181,7 @@ async def log_async(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
batch_size: int = 500,
batch_size: int = 100,
verbose: bool = True,
chunk_size: Optional[int] = None,
) -> BulkResponse:
Expand Down Expand Up @@ -201,7 +211,14 @@ async def log_async(
... rg.log_async(my_records, dataset_name), loop
... )
"""
return await ArgillaSingleton.get().log_async(

warnings.warn(
"`log_async` is deprecated and will be removed in next release. "
"Please, use `log` with `background=True` instead",
DeprecationWarning,
)

future = ArgillaSingleton.get().log(
records=records,
name=name,
workspace=workspace,
Expand All @@ -210,8 +227,11 @@ async def log_async(
batch_size=batch_size,
verbose=verbose,
chunk_size=chunk_size,
background=True,
)

return await asyncio.wrap_future(future)


def load(
name: str,
Expand Down
137 changes: 55 additions & 82 deletions src/argilla/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio

import logging
import os
import re
import warnings
from asyncio import Future
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import backoff
import httpx
from rich import print as rprint
from rich.progress import Progress

Expand All @@ -46,10 +49,10 @@
TokenClassificationRecord,
)
from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.api import async_bulk
from argilla.client.sdk.commons.api import bulk
from argilla.client.sdk.commons.errors import (
AlreadyExistsApiError,
ApiCompatibilityError,
BaseClientError,
InputValueError,
NotFoundApiError,
)
Expand Down Expand Up @@ -88,24 +91,6 @@
_LOGGER = logging.getLogger(__name__)


class _ArgillaLogAgent:
def __init__(self, api: "Argilla"):
self.__api__ = api
self.__loop__, self.__thread__ = setup_loop_in_thread()

@staticmethod
async def __log_internal__(api: "Argilla", *args, **kwargs):
try:
return await api.log_async(*args, **kwargs)
except Exception as ex:
dataset = kwargs["name"]
_LOGGER.error(f"\nCannot log data in dataset '{dataset}'\nError: {type(ex).__name__}\nDetails: {ex}")
raise ex

def log(self, *args, **kwargs) -> Future:
return asyncio.run_coroutine_threadsafe(self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__)


class Argilla:
"""
The main argilla client.
Expand All @@ -119,7 +104,7 @@ def __init__(
api_url: Optional[str] = None,
api_key: Optional[str] = None,
workspace: Optional[str] = None,
timeout: int = 60,
timeout: int = 120,
extra_headers: Optional[Dict[str, str]] = None,
):
"""
Expand Down Expand Up @@ -153,13 +138,9 @@ def __init__(
self._user: User = users_api.whoami(client=self._client)
self.set_workspace(workspace or self._user.username)

self._agent = _ArgillaLogAgent(self)

def __del__(self):
if hasattr(self, "_client"):
del self._client
if hasattr(self, "_agent"):
del self._agent

@property
def client(self) -> AuthenticatedClient:
Expand Down Expand Up @@ -239,10 +220,7 @@ def copy(self, dataset: str, name_of_copy: str, workspace: str = None):
datasets_api.copy_dataset(
client=self._client,
name=dataset,
json_body=CopyDatasetRequest(
name=name_of_copy,
target_workspace=workspace,
),
json_body=CopyDatasetRequest(name=name_of_copy, target_workspace=workspace),
)

def delete(self, name: str, workspace: Optional[str] = None):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also like to see the backoff variables here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should apply changes step by step. We can consider adding a backoff mechanism to another method in a separate PR. Otherwise, a lot of changes will be included in the same PR, which can be a great bug farm. :-)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it to rg.load which seam the most relevant to me.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For rg.load, things could be a bit different. For instance, we should decrease the batch size, or we should prefetch some data before splitting and parallelizing the data loading. But yes. we can have a similar approach to improve also that method

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decreasing batch size on failure seems very smart for rg.load in particular.

Expand All @@ -263,10 +241,12 @@ def log(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
batch_size: int = 500,
batch_size: int = 100,
verbose: bool = True,
background: bool = False,
chunk_size: Optional[int] = None,
num_threads: int = 0,
max_retries: int = 3,
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
) -> Union[BulkResponse, Future]:
"""Logs Records to argilla.

Expand All @@ -283,59 +263,38 @@ def log(
an ``asyncio.Future`` object. You probably want to set ``verbose`` to False
in that case.
chunk_size: DEPRECATED! Use `batch_size` instead.
num_threads: If > 0, will use num_thread separate number threads to batches, sending data concurrently.
Default to `0`, which means no threading at all.
max_retries: Number of retries when logging a batch of records if a `httpx.TransportError` occurs.
Default `3`
frascuchon marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Summary of the response from the REST API.
If the ``background`` argument is set to True, an ``asyncio.Future``
will be returned instead.

"""
if workspace is not None:
self.set_workspace(workspace)

future = self._agent.log(
records=records,
name=name,
tags=tags,
metadata=metadata,
batch_size=batch_size,
verbose=verbose,
chunk_size=chunk_size,
)
if background:
return future

try:
return future.result()
finally:
future.cancel()
executor = ThreadPoolExecutor(max_workers=1)

async def log_async(
self,
records: Union[Record, Iterable[Record], Dataset],
name: str,
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
batch_size: int = 500,
verbose: bool = True,
chunk_size: Optional[int] = None,
) -> BulkResponse:
"""Logs Records to argilla with asyncio.

Args:
records: The record, an iterable of records, or a dataset to log.
name: The dataset name.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
batch_size: The batch size for a data bulk.
verbose: If True, shows a progress bar and prints out a quick summary at the end.
chunk_size: DEPRECATED! Use `batch_size` instead.
return executor.submit(
self.log,
records=records,
name=name,
workspace=workspace,
tags=tags,
metadata=metadata,
batch_size=batch_size,
verbose=verbose,
chunk_size=chunk_size,
num_threads=num_threads,
max_retries=max_retries,
)

Returns:
Summary of the response from the REST API
if workspace is not None:
self.set_workspace(workspace)

tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
"""
tags = tags or {}
metadata = metadata or {}

Expand Down Expand Up @@ -389,27 +348,41 @@ async def log_async(
else:
raise InputValueError(f"Unknown record type {record_type}. Available values are {Record.__args__}")

processed, failed = 0, 0
with Progress() as progress_bar:
task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)

for i in range(0, len(records), batch_size):
batch = records[i : i + batch_size]
batches = [records[i : i + batch_size] for i in range(0, len(records), batch_size)]

@backoff.on_exception(
backoff.expo,
exception=httpx.TransportError,
max_tries=max_retries,
backoff_log_level=logging.DEBUG,
)
def log_batch(batch_info: Tuple[int, list]) -> Union[Tuple[int, int]]:
batch_id, batch = batch_info

response = await async_bulk(
bulk_result = bulk(
client=self._client,
name=name,
json_body=bulk_class(
tags=tags,
metadata=metadata,
records=[creation_class.from_client(r) for r in batch],
tags=tags, metadata=metadata, records=[creation_class.from_client(r) for r in batch]
),
)

processed += response.parsed.processed
failed += response.parsed.failed

progress_bar.update(task, advance=len(batch))
return bulk_result.processed, bulk_result.failed

if num_threads >= 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
results = list(executor.map(log_batch, enumerate(batches)))
else:
results = list(map(log_batch, enumerate(batches)))

processed, failed = 0, 0
for processed_batch, failed_batch in results:
processed += processed_batch
failed += failed_batch
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved

# TODO: improve logging policy in library
if verbose:
Expand Down
13 changes: 3 additions & 10 deletions src/argilla/client/sdk/commons/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,11 @@ def bulk(
client: AuthenticatedClient,
name: str,
json_body: Union[TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData],
) -> Response[BulkResponse]:
) -> BulkResponse:
url = f"{client.base_url}/api/datasets/{name}/{_TASK_TO_ENDPOINT[type(json_body)]}:bulk"

response = httpx.post(
url=url,
headers=client.get_headers(),
cookies=client.get_cookies(),
timeout=client.get_timeout(),
json=json_body.dict(by_alias=True),
)

return build_bulk_response(response, name=name, body=json_body)
response = client.post(path=url, json=json_body.dict(by_alias=True))
return BulkResponse.parse_obj(response)


async def async_bulk(
Expand Down
6 changes: 2 additions & 4 deletions tests/client/sdk/commons/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def test_textclass_bulk(sdk_client, mocked_client, bulk_textclass_data, monkeypa
mocked_client.delete(f"/api/datasets/{dataset_name}")
response = bulk(sdk_client, name=dataset_name, json_body=bulk_textclass_data)

assert response.status_code == 200
assert isinstance(response.parsed, BulkResponse)
assert isinstance(response, BulkResponse)


def test_tokenclass_bulk(sdk_client, mocked_client, bulk_tokenclass_data, monkeypatch):
Expand All @@ -59,8 +58,7 @@ def test_tokenclass_bulk(sdk_client, mocked_client, bulk_tokenclass_data, monkey
mocked_client.delete(f"/api/datasets/{dataset_name}")
response = bulk(sdk_client, name=dataset_name, json_body=bulk_tokenclass_data)

assert response.status_code == 200
assert isinstance(response.parsed, BulkResponse)
assert isinstance(response, BulkResponse)


@pytest.mark.parametrize(
Expand Down
5 changes: 3 additions & 2 deletions tests/client/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from argilla.server.contexts import accounts
from argilla.server.security.model import WorkspaceCreate, WorkspaceUserCreate
from httpx import ConnectError
from sqlalchemy.orm import Session

from tests.helpers import SecuredClient
Expand Down Expand Up @@ -295,10 +296,10 @@ def test_log_background_with_error(mocked_client: SecuredClient, monkeypatch: An
def raise_http_error(*args, **kwargs):
raise httpx.ConnectError("Mock error", request=None)

monkeypatch.setattr(httpx.AsyncClient, "post", raise_http_error)
monkeypatch.setattr(api.active_client().http_client, "post", raise_http_error)

future = api.log(rg.TextClassificationRecord(text=sample_text), name=dataset_name, background=True)
with pytest.raises(BaseClientError):
with pytest.raises(ConnectError):
try:
future.result()
finally:
Expand Down
Loading