Skip to content

Commit

Permalink
fix: handle enum validation better
Browse files Browse the repository at this point in the history
- Check for expected types more rigorously with enum values
- Fixes the corner case where the enum member name doesn't match the API name (as with `4x_AnimeSharp`)
  • Loading branch information
tazlin committed Jan 25, 2024
1 parent 69926b7 commit d4c7bde
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 10 deletions.
4 changes: 3 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ class AlchemyAsyncRequestFormItem(BaseModel):

@field_validator("name")
def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or (
not isinstance(v, KNOWN_ALCHEMY_TYPES)
):
logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?")
return v

Expand Down
4 changes: 3 additions & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ class AlchemyFormStatus(BaseModel):

@field_validator("form", mode="before")
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or (
not isinstance(v, KNOWN_ALCHEMY_TYPES)
):
logger.warning(f"Unknown form type {v}. Is your SDK out of date or did the API change?")
return v

Expand Down
18 changes: 10 additions & 8 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
METADATA_TYPE,
METADATA_VALUE,
POST_PROCESSOR_ORDER_TYPE,
_all_valid_post_processors_names_and_values,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL
from horde_sdk.ai_horde_api.fields import JobID, WorkerID
Expand Down Expand Up @@ -182,8 +183,9 @@ def width_divisible_by_64(cls, value: int) -> int:
@field_validator("sampler_name")
def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS:
"""Ensure that the sampler name is in this list of supported samplers."""
if v not in KNOWN_SAMPLERS.__members__:
if (isinstance(v, str) and v not in KNOWN_SAMPLERS.__members__) or (not isinstance(v, KNOWN_SAMPLERS)):
logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?")

return v

# @model_validator(mode="after")
Expand All @@ -208,11 +210,11 @@ def post_processors_must_be_known(
v: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS],
) -> list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS]:
"""Ensure that the post processors are in this list of supported post processors."""

_valid_types: list[type] = [str, KNOWN_UPSCALERS, KNOWN_FACEFIXERS, KNOWN_MISC_POST_PROCESSORS]
for post_processor in v:
if (
post_processor not in KNOWN_UPSCALERS.__members__
and post_processor not in KNOWN_FACEFIXERS.__members__
and post_processor not in KNOWN_MISC_POST_PROCESSORS.__members__
if post_processor not in _all_valid_post_processors_names_and_values or (
type(post_processor) not in _valid_types
):
logger.warning(
f"Unknown post processor {post_processor}. Is your SDK out of date or did the API change?",
Expand All @@ -224,7 +226,7 @@ def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str |
"""Ensure that the control type is in this list of supported control types."""
if v is None:
return None
if v not in KNOWN_CONTROLNETS.__members__:
if (isinstance(v, str) and v not in KNOWN_CONTROLNETS.__members__) or (not isinstance(v, KNOWN_CONTROLNETS)):
logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?")
return v

Expand Down Expand Up @@ -260,13 +262,13 @@ class GenMetadataEntry(BaseModel):
@field_validator("type_")
def validate_type(cls, v: str | METADATA_TYPE) -> str | METADATA_TYPE:
"""Ensure that the type is in this list of supported types."""
if v not in METADATA_TYPE.__members__:
if (isinstance(v, str) and v not in METADATA_TYPE.__members__) or (not isinstance(v, METADATA_TYPE)):
logger.warning(f"Unknown metadata type {v}. Is your SDK out of date or did the API change?")
return v

@field_validator("value")
def validate_value(cls, v: str | METADATA_VALUE) -> str | METADATA_VALUE:
"""Ensure that the value is in this list of supported values."""
if v not in METADATA_VALUE.__members__:
if (isinstance(v, str) and v not in METADATA_VALUE.__members__) or (not isinstance(v, METADATA_VALUE)):
logger.warning(f"Unknown metadata value {v}. Is your SDK out of date or did the API change?")
return v
12 changes: 12 additions & 0 deletions horde_sdk/ai_horde_api/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ class KNOWN_MISC_POST_PROCESSORS(StrEnum):
strip_background = auto()


_all_valid_post_processors_names_and_values = (
list(KNOWN_UPSCALERS.__members__.keys())
+ list(KNOWN_UPSCALERS.__members__.values())
+ list(KNOWN_FACEFIXERS.__members__.keys())
+ list(KNOWN_FACEFIXERS.__members__.values())
+ list(KNOWN_MISC_POST_PROCESSORS.__members__.keys())
+ list(KNOWN_MISC_POST_PROCESSORS.__members__.values())
)
"""Used to validate post processor names and values. \
This is because some post processor names are not valid python variable names."""


class POST_PROCESSOR_ORDER_TYPE(StrEnum):
"""The post processor order types that are known to the API.
Expand Down
19 changes: 19 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,22 @@ def test_ImageGenerateJobPopResponse() -> None:
),
skipped=ImageGenerateJobPopSkippedStatus(),
)
test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=["4x_AnimeSharp"],
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=[KNOWN_UPSCALERS.four_4x_AnimeSharp],
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

0 comments on commit d4c7bde

Please sign in to comment.