Skip to content

Commit

Permalink
Merge pull request #125 from Haidra-Org/main
Browse files Browse the repository at this point in the history
feat: more props. for ImageGenerateJobPopResponse; show more warnings on unknown field values
  • Loading branch information
tazlin authored Jan 23, 2024
2 parents 35869ce + 69926b7 commit 9db4e2e
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
rev: v0.1.14
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
5 changes: 5 additions & 0 deletions horde_sdk/ai_horde_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AIHordeRequestError,
AIHordeServerException,
)
from horde_sdk.ai_horde_api.fields import ImageID, JobID, TeamID, WorkerID

__all__ = [
"AIHordeAPIManualClient",
Expand All @@ -50,4 +51,8 @@
"AIHordeGenerationTimedOutError",
"AIHordeServerException",
"AIHordePayloadValidationError",
"ImageID",
"JobID",
"TeamID",
"WorkerID",
]
49 changes: 44 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, METADATA_TYPE, METADATA_VALUE, POST_PROCESSOR_ORDER_TYPE
from horde_sdk.ai_horde_api.consts import (
KNOWN_CONTROLNETS,
KNOWN_FACEFIXERS,
KNOWN_MISC_POST_PROCESSORS,
KNOWN_SAMPLERS,
KNOWN_UPSCALERS,
METADATA_TYPE,
METADATA_VALUE,
POST_PROCESSOR_ORDER_TYPE,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL
from horde_sdk.ai_horde_api.fields import JobID, WorkerID
from horde_sdk.generic_api.apimodels import HordeRequest, HordeResponseBaseModel
Expand Down Expand Up @@ -131,7 +140,9 @@ class ImageGenerateParamMixin(BaseModel):
"""The desired output image width."""
seed_variation: int | None = Field(default=None, ge=1, le=1000)
"""Deprecated."""
post_processing: list[str] = Field(default_factory=list)
post_processing: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS] = Field(
default_factory=list,
)
"""A list of post-processing models to use."""
post_processing_order: POST_PROCESSOR_ORDER_TYPE = POST_PROCESSOR_ORDER_TYPE.facefixers_first
"""The order in which to apply post-processing models.
Expand All @@ -144,7 +155,7 @@ class ImageGenerateParamMixin(BaseModel):
"""Set to True if you want to use the hires fix."""
clip_skip: int = Field(default=1, ge=1, le=12)
"""The number of clip layers to skip."""
control_type: str | None = None
control_type: str | KNOWN_CONTROLNETS | None = None
"""The type of control net type to use."""
image_is_control: bool | None = None
"""Set to True if the image is a control image."""
Expand Down Expand Up @@ -185,8 +196,36 @@ def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMP
def random_seed_if_none(cls, v: str | None) -> str | None:
"""If the seed is None, generate a random seed."""
if v is None:
logger.debug("Generating random seed")
return str(random.randint(1, 1000000000))
random_seed = str(random.randint(1, 1000000000))
logger.debug(f"Using random seed ({random_seed})")
return random_seed

return v

@field_validator("post_processing")
def post_processors_must_be_known(
cls,
v: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS],
) -> list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS]:
"""Ensure that the post processors are in this list of supported post processors."""
for post_processor in v:
if (
post_processor not in KNOWN_UPSCALERS.__members__
and post_processor not in KNOWN_FACEFIXERS.__members__
and post_processor not in KNOWN_MISC_POST_PROCESSORS.__members__
):
logger.warning(
f"Unknown post processor {post_processor}. Is your SDK out of date or did the API change?",
)
return v

@field_validator("control_type")
def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str | KNOWN_CONTROLNETS | None:
"""Ensure that the control type is in this list of supported control types."""
if v is None:
return None
if v not in KNOWN_CONTROLNETS.__members__:
logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?")
return v


Expand Down
23 changes: 22 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
ImageGenerateParamMixin,
)
from horde_sdk.ai_horde_api.apimodels.generate._submit import ImageGenerationJobSubmitRequest
from horde_sdk.ai_horde_api.consts import GENERATION_STATE, KNOWN_SOURCE_PROCESSING
from horde_sdk.ai_horde_api.consts import (
GENERATION_STATE,
KNOWN_FACEFIXERS,
KNOWN_SOURCE_PROCESSING,
KNOWN_UPSCALERS,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.ai_horde_api.fields import JobID
from horde_sdk.consts import HTTPMethod
Expand Down Expand Up @@ -158,6 +163,22 @@ def ignore_failure(self) -> bool:

return super().ignore_failure()

@property
def has_upscaler(self) -> bool:
"""Whether or not this image generation has an upscaler."""
if len(self.payload.post_processing) == 0:
return False

return any(post_processing in KNOWN_UPSCALERS.__members__ for post_processing in self.payload.post_processing)

@property
def has_facefixer(self) -> bool:
"""Whether or not this image generation has a facefixer."""
if len(self.payload.post_processing) == 0:
return False

return any(post_processing in KNOWN_FACEFIXERS.__members__ for post_processing in self.payload.post_processing)


class ImageGenerateJobPopRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
"""Represents the data needed to make a job request from a worker to the /v2/generate/pop endpoint.
Expand Down
4 changes: 2 additions & 2 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pytest==7.4.4
mypy==1.8.0
black==23.12.1
ruff==0.1.12
tox~=4.12.0
ruff==0.1.14
tox~=4.12.1
pre-commit~=3.6.0
build>=0.10.0
coverage>=7.2.7
Expand Down
76 changes: 76 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Unit tests for AI-Horde API models."""
from uuid import UUID

from horde_sdk.ai_horde_api.apimodels._find_user import (
ContributionsDetails,
FindUserRequest,
Expand All @@ -10,19 +12,27 @@
ImageGenerateAsyncRequest,
ImageGenerationInputPayload,
)
from horde_sdk.ai_horde_api.apimodels.generate._pop import (
ImageGenerateJobPopPayload,
ImageGenerateJobPopResponse,
ImageGenerateJobPopSkippedStatus,
)
from horde_sdk.ai_horde_api.apimodels.workers._workers_all import (
AllWorkersDetailsResponse,
TeamDetailsLite,
WorkerDetailItem,
WorkerKudosDetails,
)
from horde_sdk.ai_horde_api.consts import (
KNOWN_FACEFIXERS,
KNOWN_SAMPLERS,
KNOWN_SOURCE_PROCESSING,
KNOWN_UPSCALERS,
METADATA_TYPE,
METADATA_VALUE,
WORKER_TYPE,
)
from horde_sdk.ai_horde_api.fields import JobID


def test_api_endpoint() -> None:
Expand Down Expand Up @@ -298,3 +308,69 @@ def test_GenMetadataEntry() -> None:
type="test key",
value="test value",
)


def test_ImageGenerateJobPopResponse() -> None:
test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=[KNOWN_UPSCALERS.RealESRGAN_x2plus],
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.id_ is None
assert test_response.has_upscaler is True
assert test_response.has_facefixer is False

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.has_upscaler is False
assert test_response.has_facefixer is False

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=[KNOWN_FACEFIXERS.CodeFormers],
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.has_upscaler is False
assert test_response.has_facefixer is True

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=[KNOWN_FACEFIXERS.CodeFormers, KNOWN_UPSCALERS.RealESRGAN_x2plus],
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.has_upscaler is True
assert test_response.has_facefixer is True

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=["unknown post processor"],
control_type="unknown control type",
sampler_name="unknown sampler",
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

0 comments on commit 9db4e2e

Please sign in to comment.