Skip to content

Commit

Permalink
feat: support callbacks for inference/post-proc.
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Mar 12, 2024
1 parent ec32ca1 commit d13e782
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 17 deletions.
20 changes: 15 additions & 5 deletions hordelib/comfy_horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from loguru import logger

from hordelib.settings import UserSettings
from hordelib.utils.ioredirect import OutputCollector
from hordelib.utils.ioredirect import ComfyUIProgress, OutputCollector
from hordelib.config_path import get_hordelib_path

# Note It may not be abundantly clear with no context what is going on below, and I will attempt to clarify:
Expand Down Expand Up @@ -669,7 +669,12 @@ def send_sync(self, label: str, data: dict, _id: str) -> None:

# Execute the named pipeline and pass the pipeline the parameter provided.
# For the horde we assume the pipeline returns an array of images.
def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None:
def _run_pipeline(
self,
pipeline: dict,
params: dict,
comfyui_progress_callback: typing.Callable[[ComfyUIProgress, str], None] | None = None,
) -> list[dict] | None:
if _comfy_current_loaded_models is None:
raise RuntimeError("hordelib.initialise() must be called before using comfy_horde.")
# Wipe any previous images, if they exist.
Expand All @@ -692,7 +697,7 @@ def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None:

# The client_id parameter here is just so we receive comfy callbacks for debugging.
# We pretend we are a web client and want async callbacks.
stdio = OutputCollector()
stdio = OutputCollector(comfyui_progress_callback=comfyui_progress_callback)
with contextlib.redirect_stdout(stdio), contextlib.redirect_stderr(stdio):
# validate_prompt from comfy returns [bool, str, list]
# Which gives us these nice hardcoded list indexes, which valid[2] is the output node list
Expand Down Expand Up @@ -720,7 +725,12 @@ def _run_pipeline(self, pipeline: dict, params: dict) -> list[dict] | None:
return self.images

# Run a pipeline that returns an image in pixel space
def run_image_pipeline(self, pipeline, params: dict) -> list[dict[str, typing.Any]]:
def run_image_pipeline(
self,
pipeline,
params: dict,
comfyui_progress_callback: typing.Callable[[ComfyUIProgress, str], None] | None = None,
) -> list[dict[str, typing.Any]]:
# From the horde point of view, let us assume the output we are interested in
# is always in a HordeImageOutput node named "output_image". This is an array of
# dicts of the form:
Expand Down Expand Up @@ -748,7 +758,7 @@ def run_image_pipeline(self, pipeline, params: dict) -> list[dict[str, typing.An
if idle_time > 1 and UserSettings.enable_idle_time_warning.active:
logger.warning(f"No job ran for {round(idle_time, 3)} seconds")

result = self._run_pipeline(pipeline_data, params)
result = self._run_pipeline(pipeline_data, params, comfyui_progress_callback)

if result:
return result
Expand Down
95 changes: 92 additions & 3 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import typing
from collections.abc import Callable
from copy import deepcopy
from enum import Enum, auto

from horde_sdk.ai_horde_api.apimodels import ImageGenerateJobPopResponse
from horde_sdk.ai_horde_api.apimodels.base import (
Expand All @@ -18,12 +19,32 @@
from horde_sdk.ai_horde_api.consts import KNOWN_FACEFIXERS, KNOWN_UPSCALERS, METADATA_TYPE, METADATA_VALUE
from loguru import logger
from PIL import Image
from pydantic import BaseModel

from hordelib.comfy_horde import Comfy_Horde
from hordelib.consts import MODEL_CATEGORY_NAMES
from hordelib.shared_model_manager import SharedModelManager
from hordelib.utils.dynamicprompt import DynamicPromptParser
from hordelib.utils.image_utils import ImageUtils
from hordelib.utils.ioredirect import ComfyUIProgress


class ProgressState(Enum):
    """Lifecycle phase carried by a :class:`ProgressReport`.

    Reports move through: ``started`` -> ``progress`` (repeated, one per
    ComfyUI progress-bar update) -> optional ``post_processing`` -> ``finished``.
    Explicit values 1-4 match what ``auto()`` assigned in the original
    declaration, so serialized values are unchanged.
    """

    started = 1
    progress = 2
    post_processing = 3
    finished = 4


class ProgressReport(BaseModel):
    """A progress message sent to a callback registered with ``basic_inference``."""

    # Which lifecycle phase this report describes (started / progress /
    # post_processing / finished).
    hordelib_progress_state: ProgressState
    # Parsed ComfyUI progress-bar data; populated only for
    # ProgressState.progress reports (see _default_progress_callback),
    # None for the started/post_processing/finished reports.
    comfyui_progress: ComfyUIProgress | None = None
    # Overall percent complete: hordelib sets 0 on "started" and 100 on
    # "finished"; None when not supplied.
    progress: float | None = None
    # Human-readable status text supplied by hordelib
    # (e.g. "Initiating inference...", "Post Processing.").
    hordelib_message: str | None = None


class ResultingImageReturn:
Expand Down Expand Up @@ -869,6 +890,7 @@ def _inference(
payload: dict,
*,
single_image_expected: bool = True,
comfyui_progress_callback: Callable[[ComfyUIProgress, str], None] | None = None,
) -> list[ResultingImageReturn] | ResultingImageReturn:
payload, pipeline_data, faults = self._get_validated_payload_and_pipeline_data(payload)

Expand Down Expand Up @@ -901,7 +923,7 @@ def _inference(

# Call the inference pipeline
# logger.debug(payload)
images = self.generator.run_image_pipeline(pipeline_data, payload)
images = self.generator.run_image_pipeline(pipeline_data, payload, comfyui_progress_callback)

results = self._process_results(images)
ret_results = [
Expand All @@ -920,8 +942,15 @@ def _inference(

return ret_results

def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[ResultingImageReturn]:
def basic_inference(
self,
payload: dict | ImageGenerateJobPopResponse,
*,
progress_callback: Callable[[ProgressReport], None] | None = None,
) -> list[ResultingImageReturn]:
post_processing_requested: list[str] | None = None
if isinstance(payload, dict):
post_processing_requested = payload.get("post_processing")

faults = []
if isinstance(payload, ImageGenerateJobPopResponse): # TODO move this to _inference()
Expand Down Expand Up @@ -968,7 +997,37 @@ def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[R
sub_payload["model"] = payload.model
payload = sub_payload

result = self._inference(payload, single_image_expected=False)
if progress_callback is not None:
try:
progress_callback(
ProgressReport(
hordelib_progress_state=ProgressState.started,
hordelib_message="Initiating inference...",
progress=0,
),
)
except Exception as e:
logger.error(f"Progress callback failed ({type(e)}): {e}")

def _default_progress_callback(comfyui_progress: ComfyUIProgress, message: str) -> None:
    """Bridge a raw ComfyUI progress-bar update into the caller's ProgressReport callback.

    Passed down to the pipeline as `comfyui_progress_callback`; wraps each
    update in a ProgressState.progress report.
    """
    # nonlocal is only required for rebinding; kept for explicitness about
    # which enclosing name this closure reads.
    nonlocal progress_callback
    if progress_callback is not None:
        try:
            progress_callback(
                ProgressReport(
                    hordelib_progress_state=ProgressState.progress,
                    hordelib_message=message,
                    comfyui_progress=comfyui_progress,
                ),
            )
        except Exception as e:
            # A faulty user callback must never abort inference; log and continue.
            logger.error(f"Progress callback failed ({type(e)}): {e}")

result = self._inference(
payload,
single_image_expected=False,
comfyui_progress_callback=_default_progress_callback,
)

if not isinstance(result, list):
raise RuntimeError(f"Expected a list of PIL.Image.Image but got {type(result)}")
Expand All @@ -981,11 +1040,29 @@ def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[R

post_processed: list[ResultingImageReturn] | None = None
if post_processing_requested is not None:
if progress_callback is not None:
try:
progress_callback(
ProgressReport(
hordelib_progress_state=ProgressState.post_processing,
hordelib_message="Post Processing.",
),
)
except Exception as e:
logger.error(f"Progress callback failed ({type(e)}): {e}")

post_processed = []
for ret in return_list:
single_image_faults = []
final_image = ret.image
final_rawpng = ret.rawpng

# Ensure facefixers always happen first
post_processing_requested = sorted(
post_processing_requested,
key=lambda x: 1 if x in KNOWN_FACEFIXERS.__members__ else 0,
)

for post_processing in post_processing_requested:
if (
post_processing in KNOWN_UPSCALERS.__members__
Expand Down Expand Up @@ -1025,6 +1102,18 @@ def basic_inference(self, payload: dict | ImageGenerateJobPopResponse) -> list[R
ResultingImageReturn(image=final_image, rawpng=final_rawpng, faults=single_image_faults),
)

if progress_callback is not None:
try:
progress_callback(
ProgressReport(
hordelib_progress_state=ProgressState.finished,
hordelib_message="Inference complete.",
progress=100,
),
)
except Exception as e:
logger.error(f"Progress callback failed ({type(e)}): {e}")

if post_processed is not None:
logger.debug(f"Post-processing complete. Returning {len(post_processed)} images.")
return post_processed
Expand Down
2 changes: 2 additions & 0 deletions hordelib/initialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def initialise(
process_id: int | None = None,
force_normal_vram_mode: bool = True,
extra_comfyui_args: list[str] | None = None,
disable_smart_memory: bool = False,
):
"""Initialise hordelib. This is required before using any other hordelib functions.
Expand Down Expand Up @@ -75,6 +76,7 @@ def initialise(
hordelib.comfy_horde.do_comfy_import(
force_normal_vram_mode=force_normal_vram_mode,
extra_comfyui_args=extra_comfyui_args,
disable_smart_memory=disable_smart_memory,
)

vram_on_start_free = hordelib.comfy_horde.get_torch_free_vram_mb()
Expand Down
2 changes: 2 additions & 0 deletions hordelib/nodes/node_upscale_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def INPUT_TYPES(s):
def load_model(self, model_name):
    """Load the named upscale-model checkpoint and return it in eval mode.

    Returns a single-element tuple, per ComfyUI node output convention.
    """
    model_path = folder_paths.get_full_path("upscale_models", model_name)
    # safe_load=True restricts torch deserialization to plain tensors/weights.
    sd = comfy.utils.load_torch_file(model_path, safe_load=True)
    # Some checkpoints prefix every key with "module." (presumably saved from
    # a wrapped model, e.g. DataParallel — TODO confirm); strip the prefix so
    # model_loading can match the expected key layout.
    # NOTE(review): detection keys off one SwinIR-style key — verify this
    # probe covers all affected model families.
    if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
        sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})
    out = model_loading.load_state_dict(sd).eval()
    return (out,)

Expand Down
70 changes: 62 additions & 8 deletions hordelib/utils/ioredirect.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
import io
from collections import deque
from collections.abc import Callable
from enum import Enum
from time import perf_counter

import regex
from loguru import logger
from pydantic import BaseModel


class ComfyUIProgressUnit(Enum):
    """Unit of the rate figure parsed from a ComfyUI progress bar.

    Used to decide how a captured tqdm-style line is interpreted and logged.
    ``auto()`` assigns 1, 2, 3 — identical to the original explicit values.
    """

    ITERATIONS_PER_SECOND = auto()
    SECONDS_PER_ITERATION = auto()
    UNKNOWN = auto()


class ComfyUIProgress(BaseModel):
    """Parsed state of one ComfyUI progress-bar line (a pydantic model).

    Built by OutputCollector.write from a captured tqdm-style message and
    handed to the registered comfyui_progress_callback.
    """

    # Whole-number percent complete, 0-100, parsed from the bar.
    percent: int
    # Current / total sampler steps parsed from the "current/total" field.
    current_step: int
    total_steps: int
    # Iteration rate; -1.0 when ComfyUI reported "?" (rate not yet known).
    rate: float
    # Whether ``rate`` is iterations/second or seconds/iteration (or unknown).
    rate_unit: ComfyUIProgressUnit

    def __str__(self):
        # NOTE(review): rate_unit renders via the default Enum str
        # ("ComfyUIProgressUnit.ITERATIONS_PER_SECOND"); rate_unit.name would
        # read better — confirm nothing parses this string before changing it.
        return f"{self.percent}%: {self.current_step}/{self.total_steps} ({self.rate} {self.rate_unit})"


class OutputCollector(io.TextIOWrapper):
Expand All @@ -16,9 +46,17 @@ class OutputCollector(io.TextIOWrapper):
start_time: float
slow_message_count: int = 0

def __init__(self):
capture_deque: deque

comfyui_progress_callback: Callable[[ComfyUIProgress, str], None] | None = None
"""A callback function that is called when a progress bar is detected in the output. The callback function should \
accept two arguments: a ComfyUIProgress object and a string. The ComfyUIProgress object contains the parsed \
progress bar information, and the string contains the original message that was captured."""

def __init__(self, *, comfyui_progress_callback: Callable[[ComfyUIProgress, str], None] | None = None):
    """Create a collector, optionally wiring a callback for parsed progress bars.

    Args:
        comfyui_progress_callback: invoked with (ComfyUIProgress, raw message)
            each time a progress bar is detected in the captured output.
    """
    logger.disable("tqdm")  # tqdm's own logger output is never wanted here
    self.start_time = perf_counter()
    self.capture_deque = deque()
    self.comfyui_progress_callback = comfyui_progress_callback

def write(self, message: str):
Expand All @@ -44,7 +82,7 @@ def write(self, message: str):

if not matches:
logger.debug(f"Unknown progress bar format?: {message}")
self.deque.append(message)
self.capture_deque.append(message)
return

# Remove everything in between '|' and '|'
Expand Down Expand Up @@ -84,11 +122,27 @@ def write(self, message: str):
):
logger.info(message)

self.deque.append(message)
if self.comfyui_progress_callback:
self.comfyui_progress_callback(
ComfyUIProgress(
percent=int(matches.group(1)),
current_step=found_current_step,
total_steps=found_total_steps,
rate=float(iteration_rate) if iteration_rate != "?" else -1.0,
rate_unit=(
ComfyUIProgressUnit.ITERATIONS_PER_SECOND
if is_iterations_per_second
else ComfyUIProgressUnit.SECONDS_PER_ITERATION
),
),
message,
)

self.capture_deque.append(message)

def set_size(self, size):
    """Discard the oldest captured messages until at most ``size`` remain."""
    overflow = len(self.capture_deque) - size
    for _ in range(overflow):
        self.capture_deque.popleft()

def flush(self):
    # No-op: captured output lives in capture_deque (and the loguru sink);
    # there is no underlying stream to flush. Present to satisfy the
    # file-like interface expected by contextlib.redirect_stdout/stderr.
    pass
Expand All @@ -102,5 +156,5 @@ def close(self):

def replay(self):
    """Drain the capture buffer, emitting every stored message to the debug log."""
    logger.debug("Replaying output. Seconds in parentheses is the elapsed time spent in ComfyUI. ")
    while self.capture_deque:
        logger.debug(self.capture_deque.popleft())
Binary file added images_expected/text_to_image_callback_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images_expected/text_to_image_callback_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def init_horde():

import hordelib

hordelib.initialise(setup_logging=True, logging_verbosity=5)
hordelib.initialise(setup_logging=True, logging_verbosity=5, disable_smart_memory=True)
from hordelib.settings import UserSettings

UserSettings.set_ram_to_leave_free_mb("100%")
Expand Down
Loading

0 comments on commit d13e782

Please sign in to comment.