Skip to content

Commit

Permalink
Merge pull request #156 from Haidra-Org/main
Browse files Browse the repository at this point in the history
feat: more model meta load instructions
  • Loading branch information
tazlin authored Mar 5, 2024
2 parents 8d20902 + 39653a7 commit 15900fb
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 9 deletions.
3 changes: 3 additions & 0 deletions horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,7 @@ async def image_generate_request(
image_gen_request: ImageGenerateAsyncRequest,
timeout: int = GENERATION_MAX_LIFE,
check_callback: Callable[[ImageGenerateCheckResponse], None] | None = None,
delay: float = 0.0,
) -> tuple[ImageGenerateStatusResponse, JobID]:
"""Submit an image generation request to the AI-Horde API, and wait for it to complete.
Expand All @@ -1044,6 +1045,8 @@ async def image_generate_request(
AIHordeRequestError: If the request failed. The error response is included in the exception.
"""

await asyncio.sleep(delay)

timeout = self.validate_timeout(timeout, log_message=True)

n = image_gen_request.params.n if image_gen_request.params and image_gen_request.params.n else 1
Expand Down
6 changes: 3 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/alchemy/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class AlchemyAsyncRequestFormItem(HordeAPIDataObject):
def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, KNOWN_ALCHEMY_TYPES):
return v
if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__:
if str(v) not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?")
return v

Expand All @@ -83,8 +83,8 @@ class AlchemyAsyncRequest(
forms: list[AlchemyAsyncRequestFormItem]
source_image: str
"""The public URL of the source image or a base64 string to use."""
slow_workers: bool = False
"""Whether to use the slower workers. Costs additional kudos if `True`."""
slow_workers: bool = True
"""Whether to use the slower workers. Costs additional kudos if `False`."""

@field_validator("forms")
def check_at_least_one_form(cls, v: list[AlchemyAsyncRequestFormItem]) -> list[AlchemyAsyncRequestFormItem]:
Expand Down
4 changes: 1 addition & 3 deletions horde_sdk/ai_horde_api/apimodels/alchemy/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ class AlchemyFormStatus(HordeAPIDataObject):
def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str:
if isinstance(v, KNOWN_ALCHEMY_TYPES):
return v
if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or (
not isinstance(v, KNOWN_ALCHEMY_TYPES)
):
if str(v) not in KNOWN_ALCHEMY_TYPES.__members__:
logger.warning(f"Unknown form type {v}. Is your SDK out of date or did the API change?")
return v

Expand Down
30 changes: 29 additions & 1 deletion horde_sdk/ai_horde_worker/bridge_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@


class MetaInstruction(StrEnum):
ALL_REGEX = r"all$|all models+$"
ALL_REGEX = r"all$|all models?$"

ALL_SDXL_REGEX = r"all sdxl$|all sdxl models?$"
ALL_SD15_REGEX = r"all sd15$|all sd15 models?$"
ALL_SD21_REGEX = r"all sd21$|all sd21 models?$"

ALL_SFW_REGEX = r"all sfw$|all sfw models?$"
ALL_NSFW_REGEX = r"all nsfw$|all nsfw models?$"

ALL_INPAINTING_REGEX = r"all inpainting$|all inpainting models?$"

TOP_N_REGEX = r"TOP (\d+)"
"""The regex to use to match the top N models. The number is in a capture group on its own."""
Expand Down Expand Up @@ -296,6 +305,7 @@ def validate_model(self) -> ImageWorkerBridgeData:
return self

_meta_load_instructions: list[str] | None = None
_meta_skip_instructions: list[str] | None = None

@property
def meta_load_instructions(self) -> list[str] | None:
Expand All @@ -315,6 +325,24 @@ def handle_meta_instructions(self) -> ImageWorkerBridgeData:

return self

@property
def meta_skip_instructions(self) -> list[str] | None:
"""The meta skip instructions."""
return self._meta_skip_instructions

@model_validator(mode="after")
def handle_meta_skip_instructions(self) -> ImageWorkerBridgeData:
# See if any entries are meta instructions, and if so, remove them and place them in _meta_skip_instructions
for instruction_regex in MetaInstruction.__members__.values():
for i, model in enumerate(self.image_models_to_skip):
if re.match(instruction_regex, model, re.IGNORECASE):
if self._meta_skip_instructions is None:
self._meta_skip_instructions = []
self._meta_skip_instructions.append(model)
self.image_models_to_skip.pop(i)

return self

@field_validator("image_models_to_load")
def validate_models_to_load(cls, v: list) -> list:
"""Validate and parse the models to load."""
Expand Down
138 changes: 138 additions & 0 deletions horde_sdk/ai_horde_worker/model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY
from horde_model_reference.model_reference_manager import ModelReferenceManager
from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord
from loguru import logger

from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient
Expand Down Expand Up @@ -80,6 +81,43 @@ def resolve_meta_instructions(
found_bottom_n = True
continue

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SDXL_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_xl"))

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SD15_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_1"))

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SD21_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_512"))
return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_768"))

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_INPAINTING_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_inpainting_models())

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_SFW_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_sfw_model_names())

if ImageModelLoadResolver.meta_instruction_regex_match(
MetaInstruction.ALL_NSFW_REGEX,
possible_instruction,
):
return_list.extend(self.resolve_all_nsfw_model_names())

# If no valid meta instruction were found, return None
return set(return_list)

Expand Down Expand Up @@ -112,6 +150,106 @@ def resolve_all_model_names(self) -> set[str]:
logger.error("No stable diffusion models found in model reference.")
return set()

def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]:
"""Get the names of all SFW or NSFW models defined in the model reference.
Args:
nsfw: A boolean representing whether to get SFW or NSFW models.
Returns:
A set of strings representing the names of all SFW or NSFW models.
"""
all_model_references = self._model_reference_manager.get_all_model_references()

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

found_models: set[str] = set()

if sd_model_references is None:
logger.error("No stable diffusion models found in model reference.")
return found_models

for model in sd_model_references.root.values():
if not isinstance(model, StableDiffusion_ModelRecord):
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
continue

if model.nsfw == nsfw:
found_models.add(model.name)

return found_models

def resolve_all_sfw_model_names(self) -> set[str]:
"""Get the names of all SFW models defined in the model reference.
Returns:
A set of strings representing the names of all SFW models.
"""
return self._resolve_sfw_nsfw_model_names(nsfw=False)

def resolve_all_nsfw_model_names(self) -> set[str]:
"""Get the names of all NSFW models defined in the model reference.
Returns:
A set of strings representing the names of all NSFW models.
"""
return self._resolve_sfw_nsfw_model_names(nsfw=True)

def resolve_all_inpainting_models(self) -> set[str]:
"""Get the names of all inpainting models defined in the model reference.
Returns:
A set of strings representing the names of all inpainting models.
"""
all_model_references = self._model_reference_manager.get_all_model_references()

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

found_models: set[str] = set()

if sd_model_references is None:
logger.error("No stable diffusion models found in model reference.")
return found_models

for model in sd_model_references.root.values():
if not isinstance(model, StableDiffusion_ModelRecord):
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
continue

if model.inpainting:
found_models.add(model.name)

return found_models

def resolve_all_models_of_baseline(self, baseline: str) -> set[str]:
"""Get the names of all models of a given baseline defined in the model reference.
Args:
baseline: A string representing the baseline to get models for.
Returns:
A set of strings representing the names of all models of the given baseline.
"""
all_model_references = self._model_reference_manager.get_all_model_references()

sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

found_models: set[str] = set()

if sd_model_references is None:
logger.error("No stable diffusion models found in model reference.")
return found_models

for model in sd_model_references.root.values():
if not isinstance(model, StableDiffusion_ModelRecord):
logger.error(f"Model {model} is not a StableDiffusion_ModelRecord")
continue

if model.baseline == baseline:
found_models.add(model.name)

return found_models

@staticmethod
def resolve_top_n_model_names(
number_of_top_models: int,
Expand Down
2 changes: 1 addition & 1 deletion horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ class RequestUsesImageWorkerMixin(BaseModel):
"""Mix-in class to describe an endpoint for which you can specify workers."""

trusted_workers: bool = False
slow_workers: bool = False
slow_workers: bool = True
workers: list[str] = Field(default_factory=list)
worker_blacklist: list[str] = Field(default_factory=list)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
horde_model_reference~=0.6.1
horde_model_reference~=0.6.3

pydantic
requests
Expand Down
83 changes: 83 additions & 0 deletions tests/ai_horde_worker/test_model_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,81 @@ def test_image_model_load_resolver_multiple_instructions(
assert len(resolved_model_names) == 2


def test_image_model_load_resolved_all_sd15(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sd15"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0

for model_name in resolved_model_names:
assert "SDXL" not in model_name

assert "Deliberate" in resolved_model_names


def test_image_model_load_resolved_all_sd21(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sd21"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0

for model_name in resolved_model_names:
assert "SDXL" not in model_name
assert model_name != "Deliberate"


def test_image_model_load_resolved_all_sdxl(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sdxl"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert "AlbedoBase XL (SDXL)" in resolved_model_names


def test_image_model_load_resolved_all_inpainting(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all inpainting"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert any("inpainting" in model_name.lower() for model_name in resolved_model_names)


def test_image_model_load_resolved_sfw_nsfw(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all sfw"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert not any("urpm" in model_name.lower() for model_name in resolved_model_names)

resolved_model_names = image_model_load_resolver.resolve_meta_instructions(
["all nsfw"],
AIHordeAPIManualClient(),
)

assert len(resolved_model_names) > 0
assert any("urpm" in model_name.lower() for model_name in resolved_model_names)


def test_image_models_unique_results_only(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
Expand All @@ -103,3 +178,11 @@ def test_image_models_unique_results_only(
all_model_names = image_model_load_resolver.resolve_all_model_names()

assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta


def test_resolve_all_models_of_baseline(
image_model_load_resolver: ImageModelLoadResolver,
) -> None:
resolved_model_names = image_model_load_resolver.resolve_all_models_of_baseline("stable_diffusion_xl")

assert len(resolved_model_names) > 0

0 comments on commit 15900fb

Please sign in to comment.