Skip to content

Commit

Permalink
fix: import errors when importing from argilla.feedback (#3471)
Browse files Browse the repository at this point in the history
# Description

This PRs fixes the `ModuleNotFoundError` and `ImportError` that occurred
when trying to import something from `argilla.feedback` module.

The first error was caused because in #3336 the telemetry was included
in the `ArgillaTrainer`, but in the `argilla.utils.telemetry` module
some optional dependencies used by the server were being imported.

The second one was caused because the module in which
`HuggingFaceDatasetMixin` (and from which `FeedbackDataset` is
inheriting) class lives was importing classes from the
`argilla.client.feedback.config` module, which was importing `pyyaml` in
its root causing the `ImportError`.

Closes #3468 

**Type of change**

- [x] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

I've created a wheel of this branch, installed in a new virtual
environment and I was able to import something `argilla.feedback` module
without errors.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <francis@argilla.io>
  • Loading branch information
gabrielmbmb and frascuchon authored Jul 27, 2023
1 parent 2d0029a commit d37ea7e
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 81 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ These are the section headers that we use:

## [Unreleased]

## [1.13.3](https://github.com/argilla-io/argilla/compare/v1.13.2...v1.13.3)

### Fixed

- Fixed `ModuleNotFoundError` caused because the `argilla.utils.telemetry` module used in the `ArgillaTrainer` was importing an optional dependency not installed by default ([#3471](https://github.com/argilla-io/argilla/pull/3471)).
- Fixed `ImportError` caused because the `argilla.client.feedback.config` module was importing `pyyaml` optional dependency not installed by default ([#3471](https://github.com/argilla-io/argilla/pull/3471)).

## [1.13.2](https://github.com/argilla-io/argilla/compare/v1.13.1...v1.13.2)

### Fixed
Expand Down
10 changes: 9 additions & 1 deletion src/argilla/client/feedback/integrations/huggingface/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from packaging.version import parse as parse_version

from argilla.client.feedback.config import DatasetConfig, DeprecatedDatasetConfig
from argilla.client.feedback.constants import FIELD_TYPE_TO_PYTHON_TYPE
from argilla.client.feedback.schemas import FeedbackRecord
from argilla.client.feedback.types import AllowedQuestionTypes
Expand Down Expand Up @@ -188,6 +187,9 @@ def push_to_huggingface(
import huggingface_hub
from huggingface_hub import DatasetCardData, HfApi

# https://github.com/argilla-io/argilla/issues/3468
from argilla.client.feedback.config import DatasetConfig

if parse_version(huggingface_hub.__version__) < parse_version("0.14.0"):
_LOGGER.warning(
"Recommended `huggingface_hub` version is 0.14.0 or higher, and you have"
Expand Down Expand Up @@ -261,6 +263,12 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

# https://github.com/argilla-io/argilla/issues/3468
from argilla.client.feedback.config import (
DatasetConfig,
DeprecatedDatasetConfig,
)

if parse_version(huggingface_hub.__version__) < parse_version("0.14.0"):
_LOGGER.warning(
"Recommended `huggingface_hub` version is 0.14.0 or higher, and you have"
Expand Down
24 changes: 19 additions & 5 deletions src/argilla/server/errors/api_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict

from fastapi import HTTPException, Request
from fastapi.exception_handlers import http_exception_handler
from pydantic import BaseModel

from argilla.server.errors.adapter import exception_to_argilla_error
from argilla.server.errors.base_errors import ServerError
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.utils import telemetry

_LOGGER = logging.getLogger("argilla")


class ErrorDetail(BaseModel):
code: str
Expand All @@ -41,10 +43,22 @@ def __init__(self, error: ServerError):


class APIErrorHandler:
@staticmethod
async def track_error(error: ServerError, request: Request):
data = {
"code": error.code,
"user-agent": request.headers.get("user-agent"),
"accept-language": request.headers.get("accept-language"),
}
if isinstance(error, (GenericServerError, EntityNotFoundError, EntityAlreadyExistsError)):
data["type"] = error.type

telemetry.get_telemetry_client().track_data(action="ServerErrorFound", data=data)

@staticmethod
async def common_exception_handler(request: Request, error: Exception):
"""Wraps errors as custom generic error"""
argilla_error = exception_to_argilla_error(error)
await telemetry.track_error(argilla_error, request=request)
await APIErrorHandler.track_error(argilla_error, request=request)

return await http_exception_handler(request, ServerHTTPException(argilla_error))
27 changes: 6 additions & 21 deletions src/argilla/utils/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,14 @@
import logging
import platform
import uuid
from typing import Any, Dict, Optional

from fastapi import Request
from typing import TYPE_CHECKING, Any, Dict, Optional

from argilla.server.commons.models import TaskType
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.server.settings import settings

if TYPE_CHECKING:
from fastapi import Request

try:
from analytics import Client # This import works only for version 2.2.0
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -89,25 +84,15 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo
_CLIENT = TelemetryClient()


def _process_request_info(request: Request):
def _process_request_info(request: "Request"):
return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]}


async def track_error(error: ServerError, request: Request):
data = {"code": error.code}
if isinstance(error, (GenericServerError, EntityNotFoundError, EntityAlreadyExistsError)):
data["type"] = error.type

data.update(_process_request_info(request))

_CLIENT.track_data(action="ServerErrorFound", data=data)


async def track_bulk(task: TaskType, records: int):
_CLIENT.track_data(action="LogRecordsRequested", data={"task": task, "records": records})


async def track_login(request: Request, username: str):
async def track_login(request: "Request", username: str):
_CLIENT.track_data(
action="UserInfoRequested",
data={
Expand Down
54 changes: 0 additions & 54 deletions tests/server/commons/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@

import pytest
from argilla.server.commons.models import TaskType
from argilla.server.errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.server.schemas.datasets import Dataset
from argilla.utils import telemetry
from argilla.utils.telemetry import TelemetryClient, get_telemetry_client
from fastapi import Request
Expand Down Expand Up @@ -57,50 +50,3 @@ async def test_track_bulk(test_telemetry):

await telemetry.track_bulk(task=task, records=records)
test_telemetry.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records})


@pytest.mark.asyncio
@pytest.mark.parametrize(
["error", "expected_event"],
[
(
EntityNotFoundError(name="mock-name", type="MockType"),
{
"accept-language": None,
"code": "argilla.api.errors::EntityNotFoundError",
"type": "MockType",
"user-agent": None,
},
),
(
EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"),
{
"accept-language": None,
"code": "argilla.api.errors::EntityAlreadyExistsError",
"type": "Dataset",
"user-agent": None,
},
),
(
GenericServerError(RuntimeError("This is a mock error")),
{
"accept-language": None,
"code": "argilla.api.errors::GenericServerError",
"type": "builtins.RuntimeError",
"user-agent": None,
},
),
(
ServerError(),
{
"accept-language": None,
"code": "argilla.api.errors::ServerError",
"user-agent": None,
},
),
],
)
async def test_track_error(test_telemetry, error, expected_event):
await telemetry.track_error(error, request=mock_request)

test_telemetry.assert_called_once_with("ServerErrorFound", expected_event)
13 changes: 13 additions & 0 deletions tests/server/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
75 changes: 75 additions & 0 deletions tests/server/errors/test_api_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from argilla.server.errors.api_errors import APIErrorHandler
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.server.schemas.datasets import Dataset
from fastapi import Request

mock_request = Request(scope={"type": "http", "headers": {}})


@pytest.mark.asyncio
class TestAPIErrorHandler:
@pytest.mark.asyncio
@pytest.mark.parametrize(
["error", "expected_event"],
[
(
EntityNotFoundError(name="mock-name", type="MockType"),
{
"accept-language": None,
"code": "argilla.api.errors::EntityNotFoundError",
"type": "MockType",
"user-agent": None,
},
),
(
EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"),
{
"accept-language": None,
"code": "argilla.api.errors::EntityAlreadyExistsError",
"type": "Dataset",
"user-agent": None,
},
),
(
GenericServerError(RuntimeError("This is a mock error")),
{
"accept-language": None,
"code": "argilla.api.errors::GenericServerError",
"type": "builtins.RuntimeError",
"user-agent": None,
},
),
(
ServerError(),
{
"accept-language": None,
"code": "argilla.api.errors::ServerError",
"user-agent": None,
},
),
],
)
async def test_track_error(self, test_telemetry, error, expected_event):
await APIErrorHandler.track_error(error, request=mock_request)

test_telemetry.assert_called_once_with("ServerErrorFound", expected_event)

0 comments on commit d37ea7e

Please sign in to comment.