From 85bed9679a5a5c39a02fd4554abe20328f4eeb3f Mon Sep 17 00:00:00 2001 From: Divided by Zer0 Date: Sun, 12 May 2024 17:33:37 +0200 Subject: [PATCH] feat: stable cascade 2pass (#407) * feat: stable cascade 2pass * style: linted --- horde/apis/models/stable_v2.py | 5 ++++- horde/apis/v2/base.py | 2 ++ horde/apis/v2/stable.py | 2 -- horde/bridge_reference.py | 1 + horde/classes/base/news.py | 9 +++++++++ horde/classes/stable/processing_generation.py | 3 +++ horde/classes/stable/worker.py | 6 ++++++ horde/model_reference.py | 7 +++++++ 8 files changed, 32 insertions(+), 3 deletions(-) diff --git a/horde/apis/models/stable_v2.py b/horde/apis/models/stable_v2.py index 4e1232d7..f84cf67e 100644 --- a/horde/apis/models/stable_v2.py +++ b/horde/apis/models/stable_v2.py @@ -352,7 +352,10 @@ def __init__(self, api): ), "hires_fix": fields.Boolean( default=False, - description="Set to True to process the image at base resolution before upscaling and re-processing.", + description=( + "Set to True to process the image at base resolution " + "before upscaling and re-processing or to use Stable Cascade 2-pass." + ), ), "clip_skip": fields.Integer( required=False, diff --git a/horde/apis/v2/base.py b/horde/apis/v2/base.py index 207d1cfd..b005773d 100644 --- a/horde/apis/v2/base.py +++ b/horde/apis/v2/base.py @@ -450,6 +450,7 @@ def post(self): if skipped_reason != "secret": self.skipped[skipped_reason] = self.skipped.get(skipped_reason, 0) + 1 # logger.warning(datetime.utcnow()) + continue # There is a chance that by the time we finished all the checks, another worker picked up the WP. # So we do another final check here before picking it up to avoid sending the same WP to two workers by mistake. @@ -462,6 +463,7 @@ def post(self): continue # logger.debug(worker_ret) return worker_ret, 200 + db.session.commit() # Unlock all locked wp rows before picking up new ones self.wp_page += 1 self.prioritized_wp = self.get_sorted_wp() logger.debug(f"Couldn't find WP. Checking next page: {self.wp_page}") diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index d82d98c7..6455fddc 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -170,8 +170,6 @@ def validate(self): if "control_type" in self.params: raise e.BadRequest("ControlNet does not work with SDXL currently.", rc="ControlNetMismatch") if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models): - if self.params.get("hires_fix", False) is True: - raise e.BadRequest("hires fix does not work with Stable Cascade currently.", rc="HiResFixMismatch") if "control_type" in self.params: raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch") if "loras" in self.params: diff --git a/horde/bridge_reference.py b/horde/bridge_reference.py index 7344894a..4cdb3ee0 100644 --- a/horde/bridge_reference.py +++ b/horde/bridge_reference.py @@ -5,6 +5,7 @@ BRIDGE_CAPABILITIES = { "AI Horde Worker reGen": { + 6: {"stable_cascade_2pass"}, 5: {"extra_source_images"}, 3: {"lora_versions"}, 2: {"textual_inversion", "lora"}, diff --git a/horde/classes/base/news.py b/horde/classes/base/news.py index 8feb8a88..b4645427 100644 --- a/horde/classes/base/news.py +++ b/horde/classes/base/news.py @@ -3,6 +3,15 @@ class News: HORDE_NEWS = [ + { + "date_published": "2024-05-12", + "newspiece": ( + "The AI Horde now supports Stable Cascade 2pass" + "Simply switch hires_fix to True to use. Note that this has double the cost of a normal Stable Cascade." + ), + "tags": ["Stable Cascade", "db0", "nlnet"], + "importance": "Information", + }, { "date_published": "2024-03-24", "newspiece": ( diff --git a/horde/classes/stable/processing_generation.py b/horde/classes/stable/processing_generation.py index 28864e16..331fcfbc 100644 --- a/horde/classes/stable/processing_generation.py +++ b/horde/classes/stable/processing_generation.py @@ -54,6 +54,9 @@ def get_gen_kudos(self): if model_reference.get_model_baseline(self.model) in ["stable_diffusion_xl"]: return self.wp.kudos * 2 if model_reference.get_model_baseline(self.model) in ["stable_cascade"]: + # Stable Cascade 2pass has almost a double cost as it generates extra at a low generation + if self.wp.params.get("hires_fix", False): + return self.wp.kudos * 7 return self.wp.kudos * 4 return self.wp.kudos diff --git a/horde/classes/stable/worker.py b/horde/classes/stable/worker.py index d631da67..193225e1 100644 --- a/horde/classes/stable/worker.py +++ b/horde/classes/stable/worker.py @@ -119,6 +119,12 @@ def can_generate(self, waiting_prompt): return [False, "bridge_version"] if waiting_prompt.params.get("hires_fix") and not check_bridge_capability("hires_fix", self.bridge_agent): return [False, "bridge_version"] + if ( + waiting_prompt.params.get("hires_fix") + and "stable_cascade" in model_reference.get_all_model_baselines(self.get_model_names()) + and not check_bridge_capability("stable_cascade_2pass", self.bridge_agent) + ): + return [False, "bridge_version"] if waiting_prompt.params.get("clip_skip", 1) > 1 and not check_bridge_capability( "clip_skip", self.bridge_agent, diff --git a/horde/model_reference.py b/horde/model_reference.py index 51404ece..a6459469 100644 --- a/horde/model_reference.py +++ b/horde/model_reference.py @@ -87,6 +87,13 @@ def get_model_baseline(self, model_name): model_details = self.reference.get(model_name, {}) return model_details.get("baseline", "stable diffusion 1") + def get_all_model_baselines(self, model_names): + baselines = set() + for model_name in model_names: + model_details = self.reference.get(model_name, {}) + baselines.add(model_details.get("baseline", "stable diffusion 1")) + return baselines + def get_model_requirements(self, model_name): model_details = self.reference.get(model_name, {}) return model_details.get("requirements", {})