Skip to content

Commit

Permalink
Merge pull request #111 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: allow arbitrary sampler/upscaler names
  • Loading branch information
tazlin authored Jan 8, 2024
2 parents cad05af + f7c34f2 commit f7db5f3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
2 changes: 1 addition & 1 deletion horde_sdk/ai_horde_api/apimodels/alchemy/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class AlchemyUpscaleResult(BaseModel):
"""Represents the result of an upscale job."""

upscaler_used: KNOWN_UPSCALERS
upscaler_used: KNOWN_UPSCALERS | str
url: str


Expand Down
4 changes: 2 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ImageGenerateParamMixin(BaseModel):

model_config = ConfigDict(frozen=True) # , extra="forbid")

sampler_name: KNOWN_SAMPLERS = KNOWN_SAMPLERS.k_lms
sampler_name: KNOWN_SAMPLERS | str = KNOWN_SAMPLERS.k_lms
"""The sampler to use for this generation. Defaults to `KNOWN_SAMPLERS.k_lms`."""
cfg_scale: float = 7.5
"""The cfg_scale to use for this generation. Defaults to 7.5."""
Expand Down Expand Up @@ -169,7 +169,7 @@ def width_divisible_by_64(cls, value: int) -> int:
def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS:
"""Ensure that the sampler name is in this list of supported samplers."""
if v not in KNOWN_SAMPLERS.__members__:
raise ValueError(f"Unknown sampler name {v}")
logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?")
return v

# @model_validator(mode="after")
Expand Down
45 changes: 45 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,51 @@ def test_ImageGenerateAsyncRequest(ai_horde_api_key: str) -> None:
assert test_async_request.dry_run is False


def test_ImageGenerateAsyncRequest_unknown_sampler(ai_horde_api_key: str) -> None:
test_async_request = ImageGenerateAsyncRequest(
apikey=ai_horde_api_key,
models=["Deliberate"],
prompt="test prompt",
params=ImageGenerationInputPayload(
sampler_name="unknown sampler",
cfg_scale=7.5,
denoising_strength=1,
seed="123456789",
height=512,
width=512,
seed_variation=None,
post_processing=[],
karras=True,
tiling=False,
hires_fix=False,
clip_skip=1,
control_type=None,
image_is_control=None,
return_control_map=None,
facefixer_strength=None,
loras=[],
special={},
steps=25,
n=1,
use_nsfw_censor=False,
),
nsfw=True,
trusted_workers=False,
slow_workers=False,
workers=[],
censor_nsfw=False,
source_image="test source image (usually base64)",
source_processing=KNOWN_SOURCE_PROCESSING.txt2img,
source_mask="test source mask (usually base64)",
r2=True,
shared=False,
replacement_filter=True,
dry_run=False,
)
assert test_async_request.params is not None
assert test_async_request.params.sampler_name == "unknown sampler"


def test_TeamDetailsLite() -> None:
test_team_details_lite = TeamDetailsLite(
name="test team name",
Expand Down

0 comments on commit f7db5f3

Please sign in to comment.