From 0cd3498609e2233b05275675eaf8d298dd2ebc5d Mon Sep 17 00:00:00 2001 From: tazlin Date: Mon, 8 Jan 2024 11:27:52 -0500 Subject: [PATCH 1/2] fix: allow arbitrary sampler/upscaler names (enables tolerating API changes) --- horde_sdk/ai_horde_api/apimodels/alchemy/_status.py | 2 +- horde_sdk/ai_horde_api/apimodels/base.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py index 6ea30dc..f53fc6b 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py @@ -18,7 +18,7 @@ class AlchemyUpscaleResult(BaseModel): """Represents the result of an upscale job.""" - upscaler_used: KNOWN_UPSCALERS + upscaler_used: KNOWN_UPSCALERS | str url: str diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index 551ca8d..9536b31 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -114,7 +114,7 @@ class ImageGenerateParamMixin(BaseModel): model_config = ConfigDict(frozen=True) # , extra="forbid") - sampler_name: KNOWN_SAMPLERS = KNOWN_SAMPLERS.k_lms + sampler_name: KNOWN_SAMPLERS | str = KNOWN_SAMPLERS.k_lms """The sampler to use for this generation. Defaults to `KNOWN_SAMPLERS.k_lms`.""" cfg_scale: float = 7.5 """The cfg_scale to use for this generation. Defaults to 7.5.""" @@ -169,7 +169,7 @@ def width_divisible_by_64(cls, value: int) -> int: def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS: """Ensure that the sampler name is in this list of supported samplers.""" if v not in KNOWN_SAMPLERS.__members__: - raise ValueError(f"Unknown sampler name {v}") + logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?") return v # @model_validator(mode="after") From f7c34f2db0e7638f0ac85d7665758a5920494ed6 Mon Sep 17 00:00:00 2001 From: tazlin Date: Mon, 8 Jan 2024 11:34:23 -0500 Subject: [PATCH 2/2] tests: check unknown samplers are allowed --- .../ai_horde_api/test_ai_horde_api_models.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py index 6cd8aeb..b54e13f 100644 --- a/tests/ai_horde_api/test_ai_horde_api_models.py +++ b/tests/ai_horde_api/test_ai_horde_api_models.py @@ -105,6 +105,51 @@ def test_ImageGenerateAsyncRequest(ai_horde_api_key: str) -> None: assert test_async_request.dry_run is False +def test_ImageGenerateAsyncRequest_unknown_sampler(ai_horde_api_key: str) -> None: + test_async_request = ImageGenerateAsyncRequest( + apikey=ai_horde_api_key, + models=["Deliberate"], + prompt="test prompt", + params=ImageGenerationInputPayload( + sampler_name="unknown sampler", + cfg_scale=7.5, + denoising_strength=1, + seed="123456789", + height=512, + width=512, + seed_variation=None, + post_processing=[], + karras=True, + tiling=False, + hires_fix=False, + clip_skip=1, + control_type=None, + image_is_control=None, + return_control_map=None, + facefixer_strength=None, + loras=[], + special={}, + steps=25, + n=1, + use_nsfw_censor=False, + ), + nsfw=True, + trusted_workers=False, + slow_workers=False, + workers=[], + censor_nsfw=False, + source_image="test source image (usually base64)", + source_processing=KNOWN_SOURCE_PROCESSING.txt2img, + source_mask="test source mask (usually base64)", + r2=True, + shared=False, + replacement_filter=True, + dry_run=False, + ) + assert test_async_request.params is not None + assert test_async_request.params.sampler_name == "unknown sampler" + + def test_TeamDetailsLite() -> None: test_team_details_lite = TeamDetailsLite( name="test team name",