Skip to content

Commit

Permalink
Merge pull request #115 from Haidra-Org/main
Browse files Browse the repository at this point in the history
feat: support more API fields; be more tolerant of API model value changes
  • Loading branch information
tazlin authored Jan 13, 2024
2 parents 074d78a + 4732ee4 commit 8389e28
Show file tree
Hide file tree
Showing 28 changed files with 128 additions and 33 deletions.
1 change: 1 addition & 0 deletions .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
TESTS_ONGOING: 1
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
build:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
TESTS_ONGOING: 1
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
runs-on: ubuntu-latest
strategy:
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.5
rev: v0.1.13
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.6.1'
rev: 'v1.8.0'
hooks:
- id: mypy
additional_dependencies: [pydantic, types-requests, types-pytz, types-setuptools, types-urllib3, StrEnum]
2 changes: 1 addition & 1 deletion docs/build_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def api_to_sdk_map_create_markdown() -> None:
for http_status_code, _sdk_response_type in http_status_code_map.items():
f.write(
f"| {api_endpoint} | {http_status_code} | "
"[{sdk_response_type.split('.')[-1]}][{sdk_response_type}] |\n",
f"[{_sdk_response_type.split('.')[-1]}][{_sdk_response_type}] |\n",
)


Expand Down
8 changes: 8 additions & 0 deletions docs/response_field_names_and_descriptions.json
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@
]
],
"FindUserResponse": [
[
"admin_comment",
"(Privileged) Comments from the horde admins about this user."
],
[
"account_age",
"How many seconds since this account was created."
Expand Down Expand Up @@ -198,6 +202,10 @@
"sharedkey_ids",
null
],
[
"service",
"This user is a Horde service account and can provide the `proxied_user` field."
],
[
"special",
"(Privileged) This user has been given the Special role."
Expand Down
9 changes: 9 additions & 0 deletions horde_sdk/ai_horde_api/apimodels/_find_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class FindUserResponse(HordeResponseBaseModel):
def get_api_model_name(cls) -> str | None:
return "UserDetails"

admin_comment: str | None = Field(
default=None,
description="(Privileged) Comments from the horde admins about this user.",
)
account_age: int | None = Field(
default=None,
description="How many seconds since this account was created.",
Expand Down Expand Up @@ -126,6 +130,11 @@ def get_api_model_name(cls) -> str | None:
"""How many images, texts, megapixelsteps and tokens this user has generated or requested."""
sharedkey_ids: list[str] | None = None
"""The IDs of the shared keys this user has access to."""
service: bool | None = Field(
default=None,
description="This user is a Horde service account and can provide the `proxied_user` field.",
examples=[False],
)
special: bool | None = Field(
default=None,
description="(Privileged) This user has been given the Special role.",
Expand Down
9 changes: 8 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import urllib.parse

from loguru import logger
from pydantic import BaseModel, field_validator
from typing_extensions import override

Expand Down Expand Up @@ -63,7 +64,13 @@ def get_follow_up_failure_cleanup_request_type(cls) -> type[AlchemyDeleteRequest


class AlchemyAsyncRequestFormItem(BaseModel):
name: KNOWN_ALCHEMY_TYPES
name: KNOWN_ALCHEMY_TYPES | str

@field_validator("name")
def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?")
return v


class AlchemyAsyncRequest(
Expand Down
11 changes: 9 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/alchemy/_pop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from loguru import logger
from pydantic import Field
from pydantic import Field, field_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels.alchemy._submit import AlchemyJobSubmitRequest
Expand Down Expand Up @@ -43,11 +43,18 @@ class AlchemyPopFormPayload(HordeAPIObject, JobRequestMixin):
def get_api_model_name(cls) -> str | None:
return "InterrogationPopFormPayload"

form: KNOWN_ALCHEMY_TYPES = Field(
form: KNOWN_ALCHEMY_TYPES | str = Field(
None,
description="The name of this interrogation form",
examples=["caption"],
)

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}")
return v

payload: AlchemyFormPayloadStable | None = None
r2_upload: str | None = Field(None, description="The URL in which the post-processed image can be uploaded.")
source_image: str | None = Field(None, description="The URL From which the source image can be downloaded.")
Expand Down
8 changes: 7 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,16 @@ class AlchemyInterrogationResult(BaseModel):
class AlchemyFormStatus(BaseModel):
"""Represents the status of a form in an interrogation job."""

form: KNOWN_ALCHEMY_TYPES
form: KNOWN_ALCHEMY_TYPES | str
state: GENERATION_STATE
result: AlchemyInterrogationDetails | AlchemyNSFWResult | AlchemyCaptionResult | AlchemyUpscaleResult | None = None

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}. Is your SDK out of date or did the API change?")
return v

@property
def done(self) -> bool:
"""Return whether the form is done."""
Expand Down
23 changes: 20 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The base classes for all AI Horde API requests/responses."""
from __future__ import annotations

import os
import random
import uuid

Expand Down Expand Up @@ -112,7 +113,9 @@ class ImageGenerateParamMixin(BaseModel):
v2 API Model: `ModelPayloadStable`
"""

model_config = ConfigDict(frozen=True) # , extra="forbid")
model_config = (
ConfigDict(frozen=True) if not os.getenv("TESTS_ONGOING") else ConfigDict(frozen=True, extra="forbid")
)

sampler_name: KNOWN_SAMPLERS | str = KNOWN_SAMPLERS.k_lms
"""The sampler to use for this generation. Defaults to `KNOWN_SAMPLERS.k_lms`."""
Expand Down Expand Up @@ -208,9 +211,23 @@ class GenMetadataEntry(BaseModel):
v2 API Model: `GenerationMetadataStable`
"""

type_: METADATA_TYPE = Field(alias="type")
type_: METADATA_TYPE | str = Field(alias="type")
"""The relevance of the metadata field."""
value: METADATA_VALUE = Field()
value: METADATA_VALUE | str = Field()
"""The value of the metadata field."""
ref: str | None = Field(default=None, max_length=255)
"""Optionally a reference for the metadata (e.g. a lora ID)"""

@field_validator("type_")
def validate_type(cls, v: str | METADATA_TYPE) -> str | METADATA_TYPE:
"""Ensure that the type is in this list of supported types."""
if v not in METADATA_TYPE.__members__:
logger.warning(f"Unknown metadata type {v}. Is your SDK out of date or did the API change?")
return v

@field_validator("value")
def validate_value(cls, v: str | METADATA_VALUE) -> str | METADATA_VALUE:
"""Ensure that the value is in this list of supported values."""
if v not in METADATA_VALUE.__members__:
logger.warning(f"Unknown metadata value {v}. Is your SDK out of date or did the API change?")
return v
2 changes: 2 additions & 0 deletions horde_sdk/ai_horde_api/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class METADATA_TYPE(StrEnum):
censorship = auto()
source_image = auto()
source_mask = auto()
batch_index = auto()


class METADATA_VALUE(StrEnum):
Expand All @@ -207,3 +208,4 @@ class METADATA_VALUE(StrEnum):
baseline_mismatch = auto()
csam = auto()
nsfw = auto()
see_ref = auto()
5 changes: 4 additions & 1 deletion horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import abc
import os
import uuid

from loguru import logger
Expand Down Expand Up @@ -46,7 +47,9 @@ class HordeResponse(HordeAPIMessage):


class HordeResponseBaseModel(HordeResponse, BaseModel):
model_config = ConfigDict(frozen=True) # , extra="forbid")
model_config = (
ConfigDict(frozen=True) if not os.getenv("TESTS_ONGOING") else ConfigDict(frozen=True, extra="forbid")
)


class ResponseRequiringFollowUpMixin(abc.ABC):
Expand Down
2 changes: 1 addition & 1 deletion horde_sdk/ratings_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class SelectableReturnFormats(StrEnum):
class BaseSelectableReturnTypeRequest(BaseModel):
"""Mix-in class to describe an endpoint for which you can select the return data format."""

format: SelectableReturnFormats # noqa: A003
format: SelectableReturnFormats
"""The format to request the response payload in, typically json."""


Expand Down
2 changes: 1 addition & 1 deletion horde_sdk/ratings_api/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RatingsAPIQueryFields(GenericQueryFields):
artifacts = auto()
artifacts_comparison = auto()
min_ratings = auto()
format = "format" # type: ignore # noqa: A003 (shadows 'format' built-in)
format = "format" # type: ignore

minutes = auto()

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ exclude = ["codegen"]

[tool.ruff.per-file-ignores]
"__init__.py" = ["E402"]
"conftest.py" = ["E402"]

[tool.black]
line-length = 119
Expand Down
12 changes: 6 additions & 6 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pytest==7.4.3
mypy==1.6.1
black==23.11.0
ruff==0.1.5
tox~=4.11.3
pre-commit~=3.5.0
pytest==7.4.4
mypy==1.8.0
black==23.12.1
ruff==0.1.12
tox~=4.12.0
pre-commit~=3.6.0
build>=0.10.0
coverage>=7.2.7

Expand Down
22 changes: 21 additions & 1 deletion tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
FindUserResponse,
UsageDetails,
)
from horde_sdk.ai_horde_api.apimodels.base import GenMetadataEntry
from horde_sdk.ai_horde_api.apimodels.generate._async import (
ImageGenerateAsyncRequest,
ImageGenerationInputPayload,
Expand All @@ -15,7 +16,13 @@
WorkerDetailItem,
WorkerKudosDetails,
)
from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING, WORKER_TYPE
from horde_sdk.ai_horde_api.consts import (
KNOWN_SAMPLERS,
KNOWN_SOURCE_PROCESSING,
METADATA_TYPE,
METADATA_VALUE,
WORKER_TYPE,
)


def test_api_endpoint() -> None:
Expand Down Expand Up @@ -278,3 +285,16 @@ def test_FindUserResponse() -> None:
worker_invited=False,
vpn=False,
)


def test_GenMetadataEntry() -> None:
GenMetadataEntry(
type=METADATA_TYPE.batch_index,
value=METADATA_VALUE.see_ref,
ref="1",
)

GenMetadataEntry(
type="test key",
value="test value",
)
8 changes: 5 additions & 3 deletions tests/ai_horde_api/test_ai_horde_generate_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
ImageGenerateAsyncRequest,
ImageGenerateAsyncResponse,
ImageGenerateStatusResponse,
ImageGeneration,
ImageGenerationInputPayload,
LorasPayloadEntry,
)
Expand Down Expand Up @@ -410,7 +409,10 @@ async def _submit_request() -> tuple[ImageGenerateStatusResponse, JobID] | None:

# Run 5 concurrent requests using asyncio
tasks = [asyncio.create_task(_submit_request()) for _ in range(5)]
all_generations: list[list[ImageGeneration]] = await asyncio.gather(*tasks, self.delayed_cancel(tasks[0]))
all_generations: list[tuple[ImageGenerateStatusResponse, JobID] | None] = await asyncio.gather(
*tasks,
self.delayed_cancel(tasks[0]),
)

# Check that all requests were successful
assert len([generations for generations in all_generations if generations]) == 4
Expand Down Expand Up @@ -439,7 +441,7 @@ async def submit_request() -> ImageGenerateStatusResponse | None:
# Run 5 concurrent requests using asyncio
tasks = [asyncio.create_task(submit_request()) for _ in range(5)]
cancel_tasks = [asyncio.create_task(self.delayed_cancel(task)) for task in tasks]
all_generations: list[ImageGenerateStatusResponse] = await asyncio.gather(*tasks, *cancel_tasks)
all_generations: list[ImageGenerateStatusResponse | None] = await asyncio.gather(*tasks, *cancel_tasks)

# Check that all requests were successful
assert len([generations for generations in all_generations if generations]) == 0
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@

import pytest

os.environ["TESTS_ONGOING"] = "1"

from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerationInputPayload
from horde_sdk.generic_api.consts import ANON_API_KEY


@pytest.fixture(scope="session", autouse=True)
def check_tests_ongoing_env_var() -> None:
"""Checks that the TESTS_ONGOING environment variable is set."""
assert os.getenv("TESTS_ONGOING", None) is not None, "TESTS_ONGOING environment variable not set"


@pytest.fixture(scope="session")
def ai_horde_api_key() -> str:
dev_key = os.getenv("AI_HORDE_DEV_APIKEY", None)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"prompt": "a",
"params": {
"sampler_name": "lcm",
"sampler_name": "k_dpm_fast",
"cfg_scale": 7.5,
"denoising_strength": 0.75,
"seed": "The little seed that could",
Expand Down Expand Up @@ -61,5 +61,6 @@
"shared": false,
"replacement_filter": true,
"dry_run": false,
"proxied_account": ""
"proxied_account": "",
"disable_batching": false
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"models": [
"aaa"
],
"bridge_agent": "AI Horde Worker:24:https://github.com/db0/AI-Horde-Worker",
"bridge_agent": "AI Horde Worker reGen:4.1.0:https://github.com/Haidra-Org/horde-worker-reGen",
"threads": 1,
"require_upfront_kudos": false,
"amount": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@
""
],
"dry_run": false,
"proxied_account": ""
"proxied_account": "",
"disable_batching": false
}
Loading

0 comments on commit 8389e28

Please sign in to comment.