From 22d0f09293ed36f32f4781322be80569169d2488 Mon Sep 17 00:00:00 2001 From: tazlin Date: Wed, 21 Feb 2024 10:51:49 -0500 Subject: [PATCH 1/3] feat: initial baseline model meta load command support --- horde_sdk/ai_horde_worker/model_meta.py | 30 ++++++++++++++++++++++++ tests/ai_horde_worker/test_model_meta.py | 8 +++++++ 2 files changed, 38 insertions(+) diff --git a/horde_sdk/ai_horde_worker/model_meta.py b/horde_sdk/ai_horde_worker/model_meta.py index e636f37..737deab 100644 --- a/horde_sdk/ai_horde_worker/model_meta.py +++ b/horde_sdk/ai_horde_worker/model_meta.py @@ -2,6 +2,7 @@ from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY from horde_model_reference.model_reference_manager import ModelReferenceManager +from horde_model_reference.model_reference_records import StableDiffusion_ModelRecord from loguru import logger from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient @@ -112,6 +113,35 @@ def resolve_all_model_names(self) -> set[str]: logger.error("No stable diffusion models found in model reference.") return set() + def resolve_all_models_of_baseline(self, baseline: str) -> set[str]: + """Get the names of all models of a given baseline defined in the model reference. + + Args: + baseline: A string representing the baseline to get models for. + + Returns: + A set of strings representing the names of all models of the given baseline. + """ + all_model_references = self._model_reference_manager.get_all_model_references() + + sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + found_models: set[str] = set() + + if sd_model_references is None: + logger.error("No stable diffusion models found in model reference.") + return found_models + + for model in sd_model_references.root.values(): + if not isinstance(model, StableDiffusion_ModelRecord): + logger.error(f"Model {model} is not a StableDiffusion_ModelRecord") + continue + + if model.baseline == baseline: + found_models.add(model.name) + + return found_models + @staticmethod def resolve_top_n_model_names( number_of_top_models: int, diff --git a/tests/ai_horde_worker/test_model_meta.py b/tests/ai_horde_worker/test_model_meta.py index 9bff8ab..11e0a14 100644 --- a/tests/ai_horde_worker/test_model_meta.py +++ b/tests/ai_horde_worker/test_model_meta.py @@ -103,3 +103,11 @@ def test_image_models_unique_results_only( all_model_names = image_model_load_resolver.resolve_all_model_names() assert len(resolved_model_names) >= (len(all_model_names) - 1) # FIXME: -1 is to account for SDXL beta + + +def test_resolve_all_models_of_baseline( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_all_models_of_baseline("stable_diffusion_xl") + + assert len(resolved_model_names) > 0 From bb8c072f754a732afc5fae2bd32c2ae9a12449c3 Mon Sep 17 00:00:00 2001 From: tazlin Date: Tue, 5 Mar 2024 08:06:41 -0500 Subject: [PATCH 2/3] fix: allow unknown alchemy types; fix `slow_workers` default --- horde_sdk/ai_horde_api/ai_horde_clients.py | 3 +++ horde_sdk/ai_horde_api/apimodels/alchemy/_async.py | 6 +++--- horde_sdk/ai_horde_api/apimodels/alchemy/_status.py | 4 +--- horde_sdk/generic_api/apimodels.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/horde_sdk/ai_horde_api/ai_horde_clients.py b/horde_sdk/ai_horde_api/ai_horde_clients.py index b832794..b549cf8 100644 --- a/horde_sdk/ai_horde_api/ai_horde_clients.py +++ b/horde_sdk/ai_horde_api/ai_horde_clients.py @@ -1025,6 +1025,7 @@ async def image_generate_request( image_gen_request: ImageGenerateAsyncRequest, timeout: int = GENERATION_MAX_LIFE, check_callback: Callable[[ImageGenerateCheckResponse], None] | None = None, + delay: float = 0.0, ) -> tuple[ImageGenerateStatusResponse, JobID]: """Submit an image generation request to the AI-Horde API, and wait for it to complete. @@ -1044,6 +1045,8 @@ async def image_generate_request( AIHordeRequestError: If the request failed. The error response is included in the exception. """ + await asyncio.sleep(delay) + timeout = self.validate_timeout(timeout, log_message=True) n = image_gen_request.params.n if image_gen_request.params and image_gen_request.params.n else 1 diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py index 96cec3c..86d3c67 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_async.py @@ -71,7 +71,7 @@ class AlchemyAsyncRequestFormItem(HordeAPIDataObject): def check_name(cls, v: KNOWN_ALCHEMY_TYPES | str) -> KNOWN_ALCHEMY_TYPES | str: if isinstance(v, KNOWN_ALCHEMY_TYPES): return v - if isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__: + if str(v) not in KNOWN_ALCHEMY_TYPES.__members__: logger.warning(f"Unknown alchemy form name {v}. Is your SDK out of date or did the API change?") return v @@ -83,8 +83,8 @@ class AlchemyAsyncRequest( forms: list[AlchemyAsyncRequestFormItem] source_image: str """The public URL of the source image or a base64 string to use.""" - slow_workers: bool = False - """Whether to use the slower workers. Costs additional kudos if `True`.""" + slow_workers: bool = True + """Whether to use the slower workers. Costs additional kudos if `False`.""" @field_validator("forms") def check_at_least_one_form(cls, v: list[AlchemyAsyncRequestFormItem]) -> list[AlchemyAsyncRequestFormItem]: diff --git a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py index d31d385..71d433c 100644 --- a/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/alchemy/_status.py @@ -71,9 +71,7 @@ class AlchemyFormStatus(HordeAPIDataObject): def validate_form(cls, v: str | KNOWN_ALCHEMY_TYPES) -> KNOWN_ALCHEMY_TYPES | str: if isinstance(v, KNOWN_ALCHEMY_TYPES): return v - if (isinstance(v, str) and v not in KNOWN_ALCHEMY_TYPES.__members__) or ( - not isinstance(v, KNOWN_ALCHEMY_TYPES) - ): + if str(v) not in KNOWN_ALCHEMY_TYPES.__members__: logger.warning(f"Unknown form type {v}. Is your SDK out of date or did the API change?") return v diff --git a/horde_sdk/generic_api/apimodels.py b/horde_sdk/generic_api/apimodels.py index 61c6889..770acaf 100644 --- a/horde_sdk/generic_api/apimodels.py +++ b/horde_sdk/generic_api/apimodels.py @@ -375,7 +375,7 @@ class RequestUsesImageWorkerMixin(BaseModel): """Mix-in class to describe an endpoint for which you can specify workers.""" trusted_workers: bool = False - slow_workers: bool = False + slow_workers: bool = True workers: list[str] = Field(default_factory=list) worker_blacklist: list[str] = Field(default_factory=list) From 39653a762ce05202aea3c2dde0897bb054ed597c Mon Sep 17 00:00:00 2001 From: tazlin Date: Tue, 5 Mar 2024 09:47:29 -0500 Subject: [PATCH 3/3] feat: more model meta instructions; allow meta in skip models `ALL SDXL`, `ALL SD15`, `ALL SD21`, `ALL SFW`, `ALL NSFW` --- horde_sdk/ai_horde_worker/bridge_data.py | 30 ++++++- horde_sdk/ai_horde_worker/model_meta.py | 108 +++++++++++++++++++++++ requirements.txt | 2 +- tests/ai_horde_worker/test_model_meta.py | 75 ++++++++++++++++ 4 files changed, 213 insertions(+), 2 deletions(-) diff --git a/horde_sdk/ai_horde_worker/bridge_data.py b/horde_sdk/ai_horde_worker/bridge_data.py index d4be0ee..15a1abf 100644 --- a/horde_sdk/ai_horde_worker/bridge_data.py +++ b/horde_sdk/ai_horde_worker/bridge_data.py @@ -19,7 +19,16 @@ class MetaInstruction(StrEnum): - ALL_REGEX = r"all$|all models+$" + ALL_REGEX = r"all$|all models?$" + + ALL_SDXL_REGEX = r"all sdxl$|all sdxl models?$" + ALL_SD15_REGEX = r"all sd15$|all sd15 models?$" + ALL_SD21_REGEX = r"all sd21$|all sd21 models?$" + + ALL_SFW_REGEX = r"all sfw$|all sfw models?$" + ALL_NSFW_REGEX = r"all nsfw$|all nsfw models?$" + + ALL_INPAINTING_REGEX = r"all inpainting$|all inpainting models?$" TOP_N_REGEX = r"TOP (\d+)" """The regex to use to match the top N models. The number is in a capture group on its own.""" @@ -296,6 +305,7 @@ def validate_model(self) -> ImageWorkerBridgeData: return self _meta_load_instructions: list[str] | None = None + _meta_skip_instructions: list[str] | None = None @property def meta_load_instructions(self) -> list[str] | None: @@ -315,6 +325,24 @@ def handle_meta_instructions(self) -> ImageWorkerBridgeData: return self + @property + def meta_skip_instructions(self) -> list[str] | None: + """The meta skip instructions.""" + return self._meta_skip_instructions + + @model_validator(mode="after") + def handle_meta_skip_instructions(self) -> ImageWorkerBridgeData: + # See if any entries are meta instructions, and if so, remove them and place them in _meta_skip_instructions + for instruction_regex in MetaInstruction.__members__.values(): + for i, model in enumerate(self.image_models_to_skip): + if re.match(instruction_regex, model, re.IGNORECASE): + if self._meta_skip_instructions is None: + self._meta_skip_instructions = [] + self._meta_skip_instructions.append(model) + self.image_models_to_skip.pop(i) + + return self + @field_validator("image_models_to_load") def validate_models_to_load(cls, v: list) -> list: """Validate and parse the models to load.""" diff --git a/horde_sdk/ai_horde_worker/model_meta.py b/horde_sdk/ai_horde_worker/model_meta.py index 737deab..a2757fe 100644 --- a/horde_sdk/ai_horde_worker/model_meta.py +++ b/horde_sdk/ai_horde_worker/model_meta.py @@ -81,6 +81,43 @@ def resolve_meta_instructions( found_bottom_n = True continue + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SDXL_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_xl")) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SD15_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_1")) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SD21_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_512")) + return_list.extend(self.resolve_all_models_of_baseline("stable_diffusion_2_768")) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_INPAINTING_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_inpainting_models()) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_SFW_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_sfw_model_names()) + + if ImageModelLoadResolver.meta_instruction_regex_match( + MetaInstruction.ALL_NSFW_REGEX, + possible_instruction, + ): + return_list.extend(self.resolve_all_nsfw_model_names()) + # If no valid meta instruction were found, return None return set(return_list) @@ -113,6 +150,77 @@ def resolve_all_model_names(self) -> set[str]: logger.error("No stable diffusion models found in model reference.") return set() + def _resolve_sfw_nsfw_model_names(self, nsfw: bool) -> set[str]: + """Get the names of all SFW or NSFW models defined in the model reference. + + Args: + nsfw: A boolean representing whether to get SFW or NSFW models. + + Returns: + A set of strings representing the names of all SFW or NSFW models. + """ + all_model_references = self._model_reference_manager.get_all_model_references() + + sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + found_models: set[str] = set() + + if sd_model_references is None: + logger.error("No stable diffusion models found in model reference.") + return found_models + + for model in sd_model_references.root.values(): + if not isinstance(model, StableDiffusion_ModelRecord): + logger.error(f"Model {model} is not a StableDiffusion_ModelRecord") + continue + + if model.nsfw == nsfw: + found_models.add(model.name) + + return found_models + + def resolve_all_sfw_model_names(self) -> set[str]: + """Get the names of all SFW models defined in the model reference. + + Returns: + A set of strings representing the names of all SFW models. + """ + return self._resolve_sfw_nsfw_model_names(nsfw=False) + + def resolve_all_nsfw_model_names(self) -> set[str]: + """Get the names of all NSFW models defined in the model reference. + + Returns: + A set of strings representing the names of all NSFW models. + """ + return self._resolve_sfw_nsfw_model_names(nsfw=True) + + def resolve_all_inpainting_models(self) -> set[str]: + """Get the names of all inpainting models defined in the model reference. + + Returns: + A set of strings representing the names of all inpainting models. + """ + all_model_references = self._model_reference_manager.get_all_model_references() + + sd_model_references = all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + found_models: set[str] = set() + + if sd_model_references is None: + logger.error("No stable diffusion models found in model reference.") + return found_models + + for model in sd_model_references.root.values(): + if not isinstance(model, StableDiffusion_ModelRecord): + logger.error(f"Model {model} is not a StableDiffusion_ModelRecord") + continue + + if model.inpainting: + found_models.add(model.name) + + return found_models + def resolve_all_models_of_baseline(self, baseline: str) -> set[str]: """Get the names of all models of a given baseline defined in the model reference. diff --git a/requirements.txt b/requirements.txt index da57b0c..20fd6bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -horde_model_reference~=0.6.1 +horde_model_reference~=0.6.3 pydantic requests diff --git a/tests/ai_horde_worker/test_model_meta.py b/tests/ai_horde_worker/test_model_meta.py index 11e0a14..d3deee3 100644 --- a/tests/ai_horde_worker/test_model_meta.py +++ b/tests/ai_horde_worker/test_model_meta.py @@ -93,6 +93,81 @@ def test_image_model_load_resolver_multiple_instructions( assert len(resolved_model_names) == 2 +def test_image_model_load_resolved_all_sd15( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sd15"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + + for model_name in resolved_model_names: + assert "SDXL" not in model_name + + assert "Deliberate" in resolved_model_names + + +def test_image_model_load_resolved_all_sd21( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sd21"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + + for model_name in resolved_model_names: + assert "SDXL" not in model_name + assert model_name != "Deliberate" + + +def test_image_model_load_resolved_all_sdxl( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sdxl"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert "AlbedoBase XL (SDXL)" in resolved_model_names + + +def test_image_model_load_resolved_all_inpainting( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all inpainting"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert any("inpainting" in model_name.lower() for model_name in resolved_model_names) + + +def test_image_model_load_resolved_sfw_nsfw( + image_model_load_resolver: ImageModelLoadResolver, +) -> None: + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all sfw"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert not any("urpm" in model_name.lower() for model_name in resolved_model_names) + + resolved_model_names = image_model_load_resolver.resolve_meta_instructions( + ["all nsfw"], + AIHordeAPIManualClient(), + ) + + assert len(resolved_model_names) > 0 + assert any("urpm" in model_name.lower() for model_name in resolved_model_names) + + def test_image_models_unique_results_only( image_model_load_resolver: ImageModelLoadResolver, ) -> None: