Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support more API fields; be more tolerant of API model value changes #115

Merged
merged 11 commits into from
Jan 13, 2024
Merged
1 change: 1 addition & 0 deletions .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
TESTS_ONGOING: 1
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
build:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
TESTS_ONGOING: 1
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
runs-on: ubuntu-latest
strategy:
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.5
rev: v0.1.13
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.6.1'
rev: 'v1.8.0'
hooks:
- id: mypy
additional_dependencies: [pydantic, types-requests, types-pytz, types-setuptools, types-urllib3, StrEnum]
2 changes: 1 addition & 1 deletion docs/build_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def api_to_sdk_map_create_markdown() -> None:
for http_status_code, _sdk_response_type in http_status_code_map.items():
f.write(
f"| {api_endpoint} | {http_status_code} | "
"[{sdk_response_type.split('.')[-1]}][{sdk_response_type}] |\n",
f"[{_sdk_response_type.split('.')[-1]}][{_sdk_response_type}] |\n",
)


Expand Down
8 changes: 8 additions & 0 deletions docs/response_field_names_and_descriptions.json
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@
]
],
"FindUserResponse": [
[
"admin_comment",
"(Privileged) Comments from the horde admins about this user."
],
[
"account_age",
"How many seconds since this account was created."
Expand Down Expand Up @@ -198,6 +202,10 @@
"sharedkey_ids",
null
],
[
"service",
"This user is a Horde service account and can provide the `proxied_user` field."
],
[
"special",
"(Privileged) This user has been given the Special role."
Expand Down
9 changes: 9 additions & 0 deletions horde_sdk/ai_horde_api/apimodels/_find_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class FindUserResponse(HordeResponseBaseModel):
def get_api_model_name(cls) -> str | None:
return "UserDetails"

admin_comment: str | None = Field(
default=None,
description="(Privileged) Comments from the horde admins about this user.",
)
account_age: int | None = Field(
default=None,
description="How many seconds since this account was created.",
Expand Down Expand Up @@ -126,6 +130,11 @@ def get_api_model_name(cls) -> str | None:
"""How many images, texts, megapixelsteps and tokens this user has generated or requested."""
sharedkey_ids: list[str] | None = None
"""The IDs of the shared keys this user has access to."""
service: bool | None = Field(
default=None,
description="This user is a Horde service account and can provide the `proxied_user` field.",
examples=[False],
)
special: bool | None = Field(
default=None,
description="(Privileged) This user has been given the Special role.",
Expand Down
9 changes: 8 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import urllib.parse

from loguru import logger
from pydantic import BaseModel, field_validator
from typing_extensions import override

Expand Down Expand Up @@ -63,7 +64,13 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[AlchemyDeleteRequest


class AlchemyAsyncRequestFormItem(BaseModel):
name: KNOWN_ALCHEMY_TYPES
name: KNOWN_ALCHEMY_TYPES | str

@field_validator("name")
def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?")
return v


class AlchemyAsyncRequest(
Expand Down
11 changes: 9 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from loguru import logger
from pydantic import Field
from pydantic import Field, field_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.alchemy._submit import AlchemyJobSubmitRequest
Expand Down Expand Up @@ -43,11 +43,18 @@ class AlchemyPopFormPayload(HordeAPIObject, JobRequestMixin):
def get_api_model_name(cls) -> str | None:
return "InterrogationPopFormPayload"

form: KNOWN_ALCHEMY_TYPES = Field(
form: KNOWN_ALCHEMY_TYPES | str = Field(
None,
description="The name of this interrogation form",
examples=["caption"],
)

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}")
return v

payload: AlchemyFormPayloadStable | None = None
r2_upload: str | None = Field(None, description="The URL in which the post-processed image can be uploaded.")
source_image: str | None = Field(None, description="The URL From which the source image can be downloaded.")
Expand Down
8 changes: 7 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,16 @@ class AlchemyInterrogationResult(BaseModel):
class AlchemyFormStatus(BaseModel):
"""Represents the status of a form in an interrogation job."""

form: KNOWN_ALCHEMY_TYPES
form: KNOWN_ALCHEMY_TYPES | str
state: GENERATION_STATE
result: AlchemyInterrogationDetails | AlchemyNSFWResult | AlchemyCaptionResult | AlchemyUpscaleResult | None = None

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}. Is your SDK out of date or did the API change?")
return v

@property
def done(self) -> bool:
"""Return whether the form is done."""
Expand Down
23 changes: 20 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The base classes for all AI Horde API requests/responses."""
from __future__ import annotations

import os
import random
import uuid

Expand Down Expand Up @@ -112,7 +113,9 @@ class ImageGenerateParamMixin(BaseModel):
v2 API Model: `ModelPayloadStable`
"""

model_config = ConfigDict(frozen=True) # , extra="forbid")
model_config = (
ConfigDict(frozen=True) if not os.getenv("TESTS_ONGOING") else ConfigDict(frozen=True, extra="forbid")
)

sampler_name: KNOWN_SAMPLERS | str = KNOWN_SAMPLERS.k_lms
"""The sampler to use for this generation. Defaults to `KNOWN_SAMPLERS.k_lms`."""
Expand Down Expand Up @@ -208,9 +211,23 @@ class GenMetadataEntry(BaseModel):
v2 API Model: `GenerationMetadataStable`
"""

type_: METADATA_TYPE = Field(alias="type")
type_: METADATA_TYPE | str = Field(alias="type")
"""The relevance of the metadata field."""
value: METADATA_VALUE = Field()
value: METADATA_VALUE | str = Field()
"""The value of the metadata field."""
ref: str | None = Field(default=None, max_length=255)
"""Optionally a reference for the metadata (e.g. a lora ID)"""

@field_validator("type_")
def validate_type(cls, v: str | METADATA_TYPE) -> str | METADATA_TYPE:
"""Ensure that the type is in this list of supported types."""
if v not in METADATA_TYPE.__members__:
logger.warning(f"Unknown metadata type {v}. Is your SDK out of date or did the API change?")
return v

@field_validator("value")
def validate_value(cls, v: str | METADATA_VALUE) -> str | METADATA_VALUE:
"""Ensure that the value is in this list of supported values."""
if v not in METADATA_VALUE.__members__:
logger.warning(f"Unknown metadata value {v}. Is your SDK out of date or did the API change?")
return v
2 changes: 2 additions & 0 deletions horde_sdk/ai_horde_api/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class METADATA_TYPE(StrEnum):
censorship = auto()
source_image = auto()
source_mask = auto()
batch_index = auto()


class METADATA_VALUE(StrEnum):
Expand All @@ -207,3 +208,4 @@ class METADATA_VALUE(StrEnum):
baseline_mismatch = auto()
csam = auto()
nsfw = auto()
see_ref = auto()
5 changes: 4 additions & 1 deletion horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import abc
import os
import uuid

from loguru import logger
Expand Down Expand Up @@ -46,7 +47,9 @@ class HordeResponse(HordeAPIMessage):


class HordeResponseBaseModel(HordeResponse, BaseModel):
model_config = ConfigDict(frozen=True) # , extra="forbid")
model_config = (
ConfigDict(frozen=True) if not os.getenv("TESTS_ONGOING") else ConfigDict(frozen=True, extra="forbid")
)


class ResponseRequiringFollowUpMixin(abc.ABC):
Expand Down
2 changes: 1 addition & 1 deletion horde_sdk/ratings_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class SelectableReturnFormats(StrEnum):
class BaseSelectableReturnTypeRequest(BaseModel):
"""Mix-in class to describe an endpoint for which you can select the return data format."""

format: SelectableReturnFormats # noqa: A003
format: SelectableReturnFormats
"""The format to request the response payload in, typically json."""


Expand Down
2 changes: 1 addition & 1 deletion horde_sdk/ratings_api/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RatingsAPIQueryFields(GenericQueryFields):
artifacts = auto()
artifacts_comparison = auto()
min_ratings = auto()
format = "format" # type: ignore # noqa: A003 (shadows 'format' built-in)
format = "format" # type: ignore

minutes = auto()

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ exclude = ["codegen"]

[tool.ruff.per-file-ignores]
"__init__.py" = ["E402"]
"conftest.py" = ["E402"]

[tool.black]
line-length = 119
Expand Down
12 changes: 6 additions & 6 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pytest==7.4.3
mypy==1.6.1
black==23.11.0
ruff==0.1.5
tox~=4.11.3
pre-commit~=3.5.0
pytest==7.4.4
mypy==1.8.0
black==23.12.1
ruff==0.1.12
tox~=4.12.0
pre-commit~=3.6.0
build>=0.10.0
coverage>=7.2.7

Expand Down
22 changes: 21 additions & 1 deletion tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
FindUserResponse,
UsageDetails,
)
from horde_sdk.ai_horde_api.apimodels.base import GenMetadataEntry
from horde_sdk.ai_horde_api.apimodels.generate._async import (
ImageGenerateAsyncRequest,
ImageGenerationInputPayload,
Expand All @@ -15,7 +16,13 @@
WorkerDetailItem,
WorkerKudosDetails,
)
from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING, WORKER_TYPE
from horde_sdk.ai_horde_api.consts import (
KNOWN_SAMPLERS,
KNOWN_SOURCE_PROCESSING,
METADATA_TYPE,
METADATA_VALUE,
WORKER_TYPE,
)


def test_api_endpoint() -> None:
Expand Down Expand Up @@ -278,3 +285,16 @@ def test_FindUserResponse() -> None:
worker_invited=False,
vpn=False,
)


def test_GenMetadataEntry() -> None:
GenMetadataEntry(
type=METADATA_TYPE.batch_index,
value=METADATA_VALUE.see_ref,
ref="1",
)

GenMetadataEntry(
type="test key",
value="test value",
)
8 changes: 5 additions & 3 deletions tests/ai_horde_api/test_ai_horde_generate_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
ImageGenerateAsyncRequest,
ImageGenerateAsyncResponse,
ImageGenerateStatusResponse,
ImageGeneration,
ImageGenerationInputPayload,
LorasPayloadEntry,
)
Expand Down Expand Up @@ -410,7 +409,10 @@ async def _submit_request() -> tuple[ImageGenerateStatusResponse, JobID] | None:

# Run 5 concurrent requests using asyncio
tasks = [asyncio.create_task(_submit_request()) for _ in range(5)]
all_generations: list[list[ImageGeneration]] = await asyncio.gather(*tasks, self.delayed_cancel(tasks[0]))
all_generations: list[tuple[ImageGenerateStatusResponse, JobID] | None] = await asyncio.gather(
*tasks,
self.delayed_cancel(tasks[0]),
)

# Check that all requests were successful
assert len([generations for generations in all_generations if generations]) == 4
Expand Down Expand Up @@ -439,7 +441,7 @@ async def submit_request() -> ImageGenerateStatusResponse | None:
# Run 5 concurrent requests using asyncio
tasks = [asyncio.create_task(submit_request()) for _ in range(5)]
cancel_tasks = [asyncio.create_task(self.delayed_cancel(task)) for task in tasks]
all_generations: list[ImageGenerateStatusResponse] = await asyncio.gather(*tasks, *cancel_tasks)
all_generations: list[ImageGenerateStatusResponse | None] = await asyncio.gather(*tasks, *cancel_tasks)

# Check that all requests were successful
assert len([generations for generations in all_generations if generations]) == 0
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@

import pytest

os.environ["TESTS_ONGOING"] = "1"

from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerationInputPayload
from horde_sdk.generic_api.consts import ANON_API_KEY


@pytest.fixture(scope="session", autouse=True)
def check_tests_ongoing_env_var() -> None:
"""Checks that the TESTS_ONGOING environment variable is set."""
assert os.getenv("TESTS_ONGOING", None) is not None, "TESTS_ONGOING environment variable not set"


@pytest.fixture(scope="session")
def ai_horde_api_key() -> str:
dev_key = os.getenv("AI_HORDE_DEV_APIKEY", None)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"prompt": "a",
"params": {
"sampler_name": "lcm",
"sampler_name": "k_dpm_fast",
"cfg_scale": 7.5,
"denoising_strength": 0.75,
"seed": "The little seed that could",
Expand Down Expand Up @@ -61,5 +61,6 @@
"shared": false,
"replacement_filter": true,
"dry_run": false,
"proxied_account": ""
"proxied_account": "",
"disable_batching": false
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"models": [
"aaa"
],
"bridge_agent": "AI Horde Worker:24:https://github.com/db0/AI-Horde-Worker",
"bridge_agent": "AI Horde Worker reGen:4.1.0:https://github.com/Haidra-Org/horde-worker-reGen",
"threads": 1,
"require_upfront_kudos": false,
"amount": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@
""
],
"dry_run": false,
"proxied_account": ""
"proxied_account": "",
"disable_batching": false
}
Loading