Skip to content

Commit

Permalink
feat: stable cascade 2pass (#407)
Browse files Browse the repository at this point in the history
* feat: stable cascade 2pass

* style: linted
  • Loading branch information
db0 authored May 12, 2024
1 parent 3071402 commit 85bed96
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 3 deletions.
5 changes: 4 additions & 1 deletion horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,10 @@ def __init__(self, api):
),
"hires_fix": fields.Boolean(
default=False,
description="Set to True to process the image at base resolution before upscaling and re-processing.",
description=(
"Set to True to process the image at base resolution "
"before upscaling and re-processing or to use Stable Cascade 2-pass."
),
),
"clip_skip": fields.Integer(
required=False,
Expand Down
2 changes: 2 additions & 0 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def post(self):
if skipped_reason != "secret":
self.skipped[skipped_reason] = self.skipped.get(skipped_reason, 0) + 1
# logger.warning(datetime.utcnow())

continue
# There is a chance that by the time we finished all the checks, another worker picked up the WP.
# So we do another final check here before picking it up to avoid sending the same WP to two workers by mistake.
Expand All @@ -462,6 +463,7 @@ def post(self):
continue
# logger.debug(worker_ret)
return worker_ret, 200
db.session.commit() # Unlock all locked wp rows before picking up new ones
self.wp_page += 1
self.prioritized_wp = self.get_sorted_wp()
logger.debug(f"Couldn't find WP. Checking next page: {self.wp_page}")
Expand Down
2 changes: 0 additions & 2 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ def validate(self):
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with SDXL currently.", rc="ControlNetMismatch")
if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models):
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("hires fix does not work with Stable Cascade currently.", rc="HiResFixMismatch")
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch")
if "loras" in self.params:
Expand Down
1 change: 1 addition & 0 deletions horde/bridge_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

BRIDGE_CAPABILITIES = {
"AI Horde Worker reGen": {
6: {"stable_cascade_2pass"},
5: {"extra_source_images"},
3: {"lora_versions"},
2: {"textual_inversion", "lora"},
Expand Down
9 changes: 9 additions & 0 deletions horde/classes/base/news.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@

class News:
HORDE_NEWS = [
{
"date_published": "2024-05-12",
"newspiece": (
"The AI Horde now supports Stable Cascade 2pass"
"Simply switch hires_fix to True to use. Note that this has double the cost of a normal Stable Cascade."
),
"tags": ["Stable Cascade", "db0", "nlnet"],
"importance": "Information",
},
{
"date_published": "2024-03-24",
"newspiece": (
Expand Down
3 changes: 3 additions & 0 deletions horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def get_gen_kudos(self):
if model_reference.get_model_baseline(self.model) in ["stable_diffusion_xl"]:
return self.wp.kudos * 2
if model_reference.get_model_baseline(self.model) in ["stable_cascade"]:
# Stable Cascade 2pass has almost a double cost as it generates extra at a low generation
if self.wp.params.get("hires_fix", False):
return self.wp.kudos * 7
return self.wp.kudos * 4
return self.wp.kudos

Expand Down
6 changes: 6 additions & 0 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def can_generate(self, waiting_prompt):
return [False, "bridge_version"]
if waiting_prompt.params.get("hires_fix") and not check_bridge_capability("hires_fix", self.bridge_agent):
return [False, "bridge_version"]
if (
waiting_prompt.params.get("hires_fix")
and "stable_cascade" in model_reference.get_all_model_baselines(self.get_model_names())
and not check_bridge_capability("stable_cascade_2pass", self.bridge_agent)
):
return [False, "bridge_version"]
if waiting_prompt.params.get("clip_skip", 1) > 1 and not check_bridge_capability(
"clip_skip",
self.bridge_agent,
Expand Down
7 changes: 7 additions & 0 deletions horde/model_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def get_model_baseline(self, model_name):
model_details = self.reference.get(model_name, {})
return model_details.get("baseline", "stable diffusion 1")

def get_all_model_baselines(self, model_names):
baselines = set()
for model_name in model_names:
model_details = self.reference.get(model_name, {})
baselines.add(model_details.get("baseline", "stable diffusion 1"))
return baselines

def get_model_requirements(self, model_name):
model_details = self.reference.get(model_name, {})
return model_details.get("requirements", {})
Expand Down

0 comments on commit 85bed96

Please sign in to comment.