Skip to content

Commit

Permalink
Merge pull request #143 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: `rc` field support for `RequestErrorResponse`; feat: better `__eq__` and `__hash__` implementations where appropriate
  • Loading branch information
tazlin authored Feb 17, 2024
2 parents 525e419 + 39eebd2 commit 43317ab
Show file tree
Hide file tree
Showing 42 changed files with 666 additions and 263 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.2.1
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
136 changes: 1 addition & 135 deletions docs/examples.md
Original file line number Diff line number Diff line change
@@ -1,137 +1,3 @@
# Example Clients

See `examples/` for a complete list. These examples are all made in mind with your current working directory as `horde_sdk` (e.g., `cd horde_sdk`).

## Simple Client (sync) Example
From `examples/ai_horde_client/aihorde_simple_client_example.py`:

``` python
from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPISimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGeneration


def simple_generate_example() -> None:
simple_client = AIHordeAPISimpleClient()

generations: list[ImageGeneration] = simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=ANON_API_KEY,
prompt="A cat in a hat",
models=["Deliberate"],
),
)

image = simple_client.generation_to_image(generations[0])

image.save("cat_in_hat.webp")

if __name__ == "__main__":
simple_generate_example()
```



```python
import argparse
import asyncio
from collections.abc import Coroutine
from pathlib import Path

import aiohttp
from PIL.Image import Image

from horde_sdk import ANON_API_KEY, RequestErrorResponse
from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncSimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusResponse
from horde_sdk.ai_horde_api.fields import JobID


async def async_one_image_generate_example(
simple_client: AIHordeAPIAsyncSimpleClient,
apikey: str = ANON_API_KEY,
) -> None:
single_generation_response: ImageGenerateStatusResponse
job_id: JobID

single_generation_response, job_id = await simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=apikey,
prompt="A cat in a hat",
models=["Deliberate"],
),
)

if isinstance(single_generation_response, RequestErrorResponse):
print(f"Error: {single_generation_response.message}")
else:
single_image, _ = await simple_client.download_image_from_generation(single_generation_response.generations[0])

example_path = Path("examples/requested_images")
example_path.mkdir(exist_ok=True, parents=True)

single_image.save(example_path / f"{job_id}_simple_async_example.webp")


async def async_multi_image_generate_example(
simple_client: AIHordeAPIAsyncSimpleClient,
apikey: str = ANON_API_KEY,
) -> None:
multi_generation_responses: tuple[
tuple[ImageGenerateStatusResponse, JobID],
tuple[ImageGenerateStatusResponse, JobID],
]
multi_generation_responses = await asyncio.gather(
simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=apikey,
prompt="A cat in a blue hat",
models=["Deliberate"],
),
),
simple_client.image_generate_request(
ImageGenerateAsyncRequest(
apikey=apikey,
prompt="A cat in a red hat",
models=["Deliberate"],
),
),
)

download_image_from_generation_calls: list[Coroutine[None, None, tuple[Image, JobID]]] = []

for status_response, _ in multi_generation_responses:
download_image_from_generation_calls.append(
simple_client.download_image_from_generation(status_response.generations[0]),
)

downloaded_images: list[tuple[Image, JobID]] = await asyncio.gather(*download_image_from_generation_calls)

example_path = Path("examples/requested_images")
example_path.mkdir(exist_ok=True, parents=True)

for image, job_id in downloaded_images:
image.save(example_path / f"{job_id}_simple_async_example.webp")


async def async_simple_generate_example(apikey: str = ANON_API_KEY) -> None:
async with aiohttp.ClientSession() as aiohttp_session:
simple_client = AIHordeAPIAsyncSimpleClient(aiohttp_session)

await async_one_image_generate_example(simple_client, apikey)
await async_multi_image_generate_example(simple_client, apikey)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AI Horde API Manual Client Example")
parser.add_argument(
"--apikey",
type=str,
default=ANON_API_KEY,
help="The API key to use. Defaults to the anon key.",
)
args = parser.parse_args()

# Run the example.
asyncio.run(async_simple_generate_example(args.apikey))

```
See `examples/` (https://github.com/Haidra-Org/horde-sdk/tree/main/examples) for a complete list. These examples are all made in mind with your current working directory as `horde_sdk` (e.g., `cd horde_sdk`).
1 change: 1 addition & 0 deletions horde_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Any model or helper useful for creating or interacting with a horde API."""

# isort: off
# We import dotenv first so that we can use it to load environment variables before importing anything else.
import dotenv
Expand Down
1 change: 1 addition & 0 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Definitions to help interact with the AI-Horde API."""

from __future__ import annotations

import asyncio
Expand Down
1 change: 1 addition & 0 deletions horde_sdk/ai_horde_api/apimodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""All requests, responses and API models defined for the AI Horde API."""

from horde_sdk.ai_horde_api.apimodels._find_user import (
ContributionsDetails,
FindUserRequest,
Expand Down
24 changes: 15 additions & 9 deletions horde_sdk/ai_horde_api/apimodels/_find_user.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from datetime import datetime

from pydantic import BaseModel, Field
from pydantic import Field
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.base import BaseAIHordeRequest
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeResponseBaseModel
from horde_sdk.generic_api.apimodels import APIKeyAllowedInRequestMixin, HordeAPIDataObject, HordeResponseBaseModel


class ContributionsDetails(BaseModel):
class ContributionsDetails(HordeAPIDataObject):
fulfillments: int | None = Field(default=None, description="How many images this user has generated.")
megapixelsteps: float | None = Field(default=None, description="How many megapixelsteps this user has generated.")


class UserKudosDetails(BaseModel):
class UserKudosDetails(HordeAPIDataObject):
accumulated: float | None = Field(0, description="The amount of Kudos accumulated or used for generating images.")
admin: float | None = Field(0, description="The amount of Kudos this user has been given by the Horde admins.")
awarded: float | None = Field(
Expand All @@ -29,33 +29,33 @@ class UserKudosDetails(BaseModel):
)


class MonthlyKudos(BaseModel):
class MonthlyKudos(HordeAPIDataObject):
amount: int | None = Field(default=None, description="How much recurring Kudos this user receives monthly.")
last_received: datetime | None = Field(default=None, description="Last date this user received monthly Kudos.")


class UserThingRecords(BaseModel):
class UserThingRecords(HordeAPIDataObject):
megapixelsteps: float | None = Field(
0,
description="How many megapixelsteps this user has generated or requested.",
)
tokens: int | None = Field(0, description="How many token this user has generated or requested.")


class UserAmountRecords(BaseModel):
class UserAmountRecords(HordeAPIDataObject):
image: int | None = Field(0, description="How many images this user has generated or requested.")
interrogation: int | None = Field(0, description="How many texts this user has generated or requested.")
text: int | None = Field(0, description="How many texts this user has generated or requested.")


class UserRecords(BaseModel):
class UserRecords(HordeAPIDataObject):
contribution: UserThingRecords | None = None
fulfillment: UserAmountRecords | None = None
request: UserAmountRecords | None = None
usage: UserThingRecords | None = None


class UsageDetails(BaseModel):
class UsageDetails(HordeAPIDataObject):
megapixelsteps: float | None = Field(default=None, description="How many megapixelsteps this user has requested.")
requests: int | None = Field(default=None, description="How many images this user has requested.")

Expand Down Expand Up @@ -183,6 +183,12 @@ def get_api_model_name(cls) -> str | None:
"""Whether this user has been invited to join a worker to the horde and how many of them.
When 0, this user cannot add (new) workers to the horde."""

def __eq__(self, other: object) -> bool:
raise NotImplementedError("TODO")

def __hash__(self) -> int:
raise NotImplementedError("TODO")


class FindUserRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
@override
Expand Down
6 changes: 6 additions & 0 deletions horde_sdk/ai_horde_api/apimodels/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def get_timeframe(self, timeframe: StatsModelsTimeframe) -> dict[str, int]:

raise ValueError(f"Invalid timeframe: {timeframe}")

def __eq__(self, other: object) -> bool:
raise NotImplementedError("Cannot compare StatsModelsResponse objects")

def __hash__(self) -> int:
raise NotImplementedError("Cannot hash StatsModelsResponse objects")


class StatsImageModelsRequest(BaseAIHordeRequest):
"""Represents the data needed to make a request to the `/v2/stats/img/models` endpoint."""
Expand Down
11 changes: 6 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/alchemy/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.parse

from loguru import logger
from pydantic import BaseModel, field_validator
from pydantic import field_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.alchemy._status import AlchemyDeleteRequest, AlchemyStatusRequest
Expand All @@ -17,6 +17,7 @@
from horde_sdk.generic_api.apimodels import (
APIKeyAllowedInRequestMixin,
ContainsMessageResponseMixin,
HordeAPIDataObject,
HordeResponse,
HordeResponseBaseModel,
ResponseRequiringFollowUpMixin,
Expand Down Expand Up @@ -63,14 +64,14 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[AlchemyDeleteRequest
return AlchemyDeleteRequest


class AlchemyAsyncRequestFormItem(BaseModel):
class AlchemyAsyncRequestFormItem(HordeAPIDataObject):
name: KNOWN_ALCHEMY_TYPES | str

@field_validator("name")
def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str:
if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or (
not isinstance(v, KNOWN_ALCHEMY_TYPES)
):
if isinstance(v, KNOWN_ALCHEMY_TYPES):
return v
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?")
return v

Expand Down
35 changes: 34 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from loguru import logger
from pydantic import Field, field_validator
from pydantic import Field, field_validator, model_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.alchemy._submit import AlchemyJobSubmitRequest
Expand Down Expand Up @@ -51,6 +53,8 @@ def get_api_model_name(cls) -> str | None:

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, KNOWN_ALCHEMY_TYPES):
return v
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}")
return v
Expand Down Expand Up @@ -130,12 +134,41 @@ def get_follow_up_returned_params(self, *, as_python_field_name: bool = False) -

return all_ids

@model_validator(mode="after")
def coerce_list_order(self) -> AlchemyPopResponse:
if self.forms is not None:
logger.debug("Sorting forms by id")
self.forms.sort(key=lambda form: form.id_)

return self

@override
@classmethod
def get_follow_up_request_types(cls) -> list[type[AlchemyJobSubmitRequest]]: # type: ignore[override]
"""Return a list of all the possible follow up request types for this response."""
return [AlchemyJobSubmitRequest]

def __eq__(self, other: object) -> bool:
if not isinstance(other, AlchemyPopResponse):
return False

forms_match = True
skipped_match = True

if self.forms is not None and other.forms is not None:
forms_match = all(form in other.forms for form in self.forms)

if self.skipped is not None:
skipped_match = self.skipped == other.skipped

return forms_match and skipped_match

def __hash__(self) -> int:
if self.forms is None:
return hash(self.skipped)

return hash((tuple([form.id_ for form in self.forms]), self.skipped))


class AlchemyPopRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
"""Represents the data needed to make a request to the `/v2/interrogate/pop` endpoint.
Expand Down
Loading

0 comments on commit 43317ab

Please sign in to comment.