Skip to content

Commit

Permalink
feat: has_facefixer property for ImageGenerateJobPopResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Jan 23, 2024
1 parent 9739ef8 commit 221ab43
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
15 changes: 14 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
ImageGenerateParamMixin,
)
from horde_sdk.ai_horde_api.apimodels.generate._submit import ImageGenerationJobSubmitRequest
from horde_sdk.ai_horde_api.consts import GENERATION_STATE, KNOWN_SOURCE_PROCESSING, KNOWN_UPSCALERS
from horde_sdk.ai_horde_api.consts import (
GENERATION_STATE,
KNOWN_FACEFIXERS,
KNOWN_SOURCE_PROCESSING,
KNOWN_UPSCALERS,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_ENDPOINT_SUBPATH
from horde_sdk.ai_horde_api.fields import JobID
from horde_sdk.consts import HTTPMethod
Expand Down Expand Up @@ -166,6 +171,14 @@ def has_upscaler(self) -> bool:

return any(post_processing in KNOWN_UPSCALERS.__members__ for post_processing in self.payload.post_processing)

@property
def has_facefixer(self) -> bool:
"""Whether or not this image generation has a facefixer."""
if len(self.payload.post_processing) == 0:
return False

return any(post_processing in KNOWN_FACEFIXERS.__members__ for post_processing in self.payload.post_processing)


class ImageGenerateJobPopRequest(BaseAIHordeRequest, APIKeyAllowedInRequestMixin):
"""Represents the data needed to make a job request from a worker to the /v2/generate/pop endpoint.
Expand Down
18 changes: 16 additions & 2 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def test_ImageGenerateJobPopResponse() -> None:

assert test_response.id_ is None
assert test_response.has_upscaler is True
assert test_response.has_facefixer is False

test_response = ImageGenerateJobPopResponse(
id=None,
Expand All @@ -333,8 +334,8 @@ def test_ImageGenerateJobPopResponse() -> None:
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.id_ is None
assert test_response.has_upscaler is False
assert test_response.has_facefixer is False

test_response = ImageGenerateJobPopResponse(
id=None,
Expand All @@ -346,5 +347,18 @@ def test_ImageGenerateJobPopResponse() -> None:
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.id_ is None
assert test_response.has_upscaler is False
assert test_response.has_facefixer is True

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=[KNOWN_FACEFIXERS.CodeFormers, KNOWN_UPSCALERS.RealESRGAN_x2plus],
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.has_upscaler is True
assert test_response.has_facefixer is True

0 comments on commit 221ab43

Please sign in to comment.