diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index ae5f291..7b66216 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -228,8 +228,12 @@ def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str | """Ensure that the control type is in this list of supported control types.""" if v is None: return None - if (isinstance(v, str) and v not in KNOWN_CONTROLNETS.__members__) or (not isinstance(v, KNOWN_CONTROLNETS)): - logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?") + if isinstance(v, KNOWN_CONTROLNETS): + return v + if v in KNOWN_CONTROLNETS.__members__: + return v + + logger.warning(f"Unknown control type {v}. Is your SDK out of date or did the API change?") return v diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py index eb18f1e..2f8c258 100644 --- a/tests/ai_horde_api/test_ai_horde_api_models.py +++ b/tests/ai_horde_api/test_ai_horde_api_models.py @@ -24,6 +24,7 @@ WorkerKudosDetails, ) from horde_sdk.ai_horde_api.consts import ( + KNOWN_CONTROLNETS, KNOWN_FACEFIXERS, KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING, @@ -375,6 +376,29 @@ def test_ImageGenerateJobPopResponse() -> None: ), skipped=ImageGenerateJobPopSkippedStatus(), ) + test_response = ImageGenerateJobPopResponse( + id=None, + ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))], + payload=ImageGenerateJobPopPayload( + post_processing=["unknown post processor"], + control_type=KNOWN_CONTROLNETS.canny, + sampler_name="unknown sampler", + prompt="A cat in a hat", + ), + skipped=ImageGenerateJobPopSkippedStatus(), + ) + test_response = ImageGenerateJobPopResponse( + id=None, + ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))], + payload=ImageGenerateJobPopPayload( + post_processing=["unknown post processor"], + control_type="canny", + sampler_name="unknown sampler", + prompt="A cat in a hat", + ), + skipped=ImageGenerateJobPopSkippedStatus(), + ) + test_response = ImageGenerateJobPopResponse( id=None, ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],