Skip to content

Commit

Permalink
fix: show warnings on unknown control_type or post_processors
Browse files Browse the repository at this point in the history
- also prints the randomly generated seed if used
  • Loading branch information
tazlin committed Jan 23, 2024
1 parent 221ab43 commit 69926b7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
49 changes: 44 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, METADATA_TYPE, METADATA_VALUE, POST_PROCESSOR_ORDER_TYPE
from horde_sdk.ai_horde_api.consts import (
KNOWN_CONTROLNETS,
KNOWN_FACEFIXERS,
KNOWN_MISC_POST_PROCESSORS,
KNOWN_SAMPLERS,
KNOWN_UPSCALERS,
METADATA_TYPE,
METADATA_VALUE,
POST_PROCESSOR_ORDER_TYPE,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL
from horde_sdk.ai_horde_api.fields import JobID, WorkerID
from horde_sdk.generic_api.apimodels import HordeRequest, HordeResponseBaseModel
Expand Down Expand Up @@ -131,7 +140,9 @@ class ImageGenerateParamMixin(BaseModel):
"""The desired output image width."""
seed_variation: int | None = Field(default=None, ge=1, le=1000)
"""Deprecated."""
post_processing: list[str] = Field(default_factory=list)
post_processing: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS] = Field(
default_factory=list,
)
"""A list of post-processing models to use."""
post_processing_order: POST_PROCESSOR_ORDER_TYPE = POST_PROCESSOR_ORDER_TYPE.facefixers_first
"""The order in which to apply post-processing models.
Expand All @@ -144,7 +155,7 @@ class ImageGenerateParamMixin(BaseModel):
"""Set to True if you want to use the hires fix."""
clip_skip: int = Field(default=1, ge=1, le=12)
"""The number of clip layers to skip."""
control_type: str | None = None
control_type: str | KNOWN_CONTROLNETS | None = None
"""The type of control net type to use."""
image_is_control: bool | None = None
"""Set to True if the image is a control image."""
Expand Down Expand Up @@ -185,8 +196,36 @@ def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMP
def random_seed_if_none(cls, v: str | None) -> str | None:
"""If the seed is None, generate a random seed."""
if v is None:
logger.debug("Generating random seed")
return str(random.randint(1, 1000000000))
random_seed = str(random.randint(1, 1000000000))
logger.debug(f"Using random seed ({random_seed})")
return random_seed

return v

@field_validator("post_processing")
def post_processors_must_be_known(
cls,
v: list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS],
) -> list[str | KNOWN_UPSCALERS | KNOWN_FACEFIXERS | KNOWN_MISC_POST_PROCESSORS]:
"""Ensure that the post processors are in this list of supported post processors."""
for post_processor in v:
if (
post_processor not in KNOWN_UPSCALERS.__members__
and post_processor not in KNOWN_FACEFIXERS.__members__
and post_processor not in KNOWN_MISC_POST_PROCESSORS.__members__
):
logger.warning(
f"Unknown post processor {post_processor}. Is your SDK out of date or did the API change?",
)
return v

@field_validator("control_type")
def control_type_must_be_known(cls, v: str | KNOWN_CONTROLNETS | None) -> str | KNOWN_CONTROLNETS | None:
"""Ensure that the control type is in this list of supported control types."""
if v is None:
return None
if v not in KNOWN_CONTROLNETS.__members__:
logger.warning(f"Unknown control type '{v}'. Is your SDK out of date or did the API change?")
return v


Expand Down
12 changes: 12 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,15 @@ def test_ImageGenerateJobPopResponse() -> None:

assert test_response.has_upscaler is True
assert test_response.has_facefixer is True

test_response = ImageGenerateJobPopResponse(
id=None,
ids=[JobID(root=UUID("00000000-0000-0000-0000-000000000000"))],
payload=ImageGenerateJobPopPayload(
post_processing=["unknown post processor"],
control_type="unknown control type",
sampler_name="unknown sampler",
prompt="A cat in a hat",
),
skipped=ImageGenerateJobPopSkippedStatus(),
)

0 comments on commit 69926b7

Please sign in to comment.