Skip to content

Commit

Permalink
Merge pull request #136 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: correct logic for `KNOWN_CONTROLNET` check
  • Loading branch information
tazlin authored Feb 3, 2024
2 parents 4a69c3a + f443ae1 commit 525e419
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
8 changes: 6 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,12 @@ def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str |
"""Ensure that the control type is in this list of supported control types."""
if v is None:
return None
if (isinstance(v, str) and v not in KNOWN_CONTROLNETS.__members__) or (not isinstance(v, KNOWN_CONTROLNETS)):
logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?")
if isinstance(v, KNOWN_CONTROLNETS):
return v
if v in KNOWN_CONTROLNETS.__members__:
return v

logger.warning(f"Unknown control type {v}. Is your SDK out of date or did the API change?")
return v


Expand Down
24 changes: 24 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
WorkerKudosDetails,
)
from horde_sdk.ai_horde_api.consts import (
KNOWN_CONTROLNETS,
KNOWN_FACEFIXERS,
KNOWN_SAMPLERS,
KNOWN_SOURCE_PROCESSING,
Expand Down Expand Up @@ -375,6 +376,29 @@ def test_ImageGenerateJobPopResponse() -> None:
),
skipped=ImageGenerateJobPopSkippedStatus(),
)
test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=["unknown post processor"],
control_type=KNOWN_CONTROLNETS.canny,
sampler_name="unknown sampler",
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)
test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=["unknown post processor"],
control_type="canny",
sampler_name="unknown sampler",
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
Expand Down

0 comments on commit 525e419

Please sign in to comment.