Merge pull request #318 from Haidra-Org/main
feat: Memory management improvements; force load SD15 models fully
tazlin authored Aug 26, 2024
2 parents 2c505c4 + 2d58136 commit 7481314
Showing 6 changed files with 177 additions and 20 deletions.
145 changes: 126 additions & 19 deletions hordelib/comfy_horde.py
@@ -19,6 +19,7 @@
from pprint import pformat
import requests
import psutil
from collections.abc import Callable

import torch
from loguru import logger
@@ -40,9 +41,10 @@
# There may be other ways to skin this cat, but this strategy minimizes certain kinds of hassle.
#
# If you tamper with the code in this module to bring the imports out of the function, you may find that you have
# broken, among other things, the ability of pytest to do its test discovery because you will have lost the ability for
# modules which, directly or otherwise, import this module without having called `hordelib.initialise()`. Pytest
# discovery will come across those imports, valiantly attempt to import them and fail with a cryptic error message.
# broken, among myriad other things, the ability of pytest to do its test discovery, because you will have lost the
# ability for modules to import this module, directly or otherwise, without having called `hordelib.initialise()`.
# Pytest discovery will come across those imports, valiantly attempt to import them, and fail with a cryptic error
# message.
#
# Correspondingly, you will find that to be an enormous hassle if you are trying to leverage pytest in any
# reasonably sophisticated way (outside of tox), and you will be forced to adopt the solution below or something
@@ -51,17 +53,18 @@
#
# Keen readers may have noticed that the aforementioned issues could be side stepped by simply calling
# `hordelib.initialise()` automatically, such as in `test/__init__.py` or in a `conftest.py`. You would be correct,
# but that would be a terrible idea if you ever intended to make alterations to the patch file, as each time you
# triggered pytest discovery which could be as frequently as *every time you save a file* (such as with VSCode), and
# you would enter a situation where the patch was automatically being applied at times you may not intend.
# but that would be a terrible idea as a general practice. It would mean that every time you saved a file in your
# editor, a number of heavyweight operations would be triggered, such as loading comfyui, while pytest discovery runs,
# and that would cause slow and unpredictable behavior in your editor.
#
# This would be a nightmare to debug, as this author is able to attest to.
# This would be a nightmare to debug, as this author can attest, and it is the reason this wall of text exists.
#
# Further, if you are like myself, and enjoy type hints, you will find that any modules that have this file in their
# import chain will be un-importable in certain contexts, and you will be unable to provide the relevant type hints.
#
# Having read this, I suggest you glance at the code in `hordelib.initialise()` to get a sense of what is going on
# there, and if you're still confused, ask a hordelib dev who would be happy to share the burden of understanding.
# Having exercised a herculean amount of focus to read this far, I suggest you also glance at the code in
# `hordelib.initialise()` to get a sense of what is going on there, and if you're still confused, ask a hordelib dev
# who would be happy to share the burden of understanding.
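
To make the calling pattern described above concrete, here is a minimal consumer-side sketch. `hordelib.initialise()` and `hordelib.horde.HordeLib` appear elsewhere in this diff; the ordering (initialise first, import afterwards) is the point being defended, and the rest of the consumer code is illustrative only:

    # illustrative consumer-side ordering; nothing here is part of comfy_horde.py itself
    import hordelib

    hordelib.initialise()  # apply the patch and set up comfyui before anything imports this module

    from hordelib.horde import HordeLib  # safe only after initialise() has run

    horde = HordeLib()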

_comfy_load_models_gpu: types.FunctionType
_comfy_current_loaded_models: list = None # type: ignore
@@ -76,15 +79,24 @@
_comfy_load_checkpoint_guess_config: types.FunctionType

_comfy_get_torch_device: types.FunctionType
"""Will return the current torch device, typically the GPU."""
_comfy_get_free_memory: types.FunctionType
"""Will return the amount of free memory on the current torch device. This value can be misleading."""
_comfy_get_total_memory: types.FunctionType
"""Will return the total amount of memory on the current torch device."""
_comfy_load_torch_file: types.FunctionType
_comfy_model_loading: types.ModuleType
_comfy_free_memory: types.FunctionType
_comfy_cleanup_models: types.FunctionType
_comfy_soft_empty_cache: types.FunctionType
_comfy_free_memory: Callable[[float, torch.device, list], None]
"""Will aggressively unload models from memory"""
_comfy_cleanup_models: Callable[[bool], None]
"""Will unload unused models from memory"""
_comfy_soft_empty_cache: Callable[[bool], None]
"""Triggers comfyui and torch to empty their caches"""

_comfy_is_changed_cache_get: types.FunctionType
_comfy_is_changed_cache_get: Callable
_comfy_model_patcher_load: Callable

_comfy_interrupt_current_processing: types.FunctionType

_canny: types.ModuleType
_hed: types.ModuleType
@@ -111,6 +123,9 @@ class InterceptHandler(logging.Handler):

@logger.catch(default=True, reraise=True)
def emit(self, record):
message = record.getMessage()
if "lowvram: loaded module regularly" in message:
return
# Get corresponding Loguru level if it exists.
try:
level = logger.level(record.levelname).name
@@ -123,7 +138,7 @@ def emit(self, record):
frame = frame.f_back
depth += 1

logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
logger.opt(depth=depth, exception=record.exc_info).log(level, message)


# ComfyUI uses stdlib logging, so we need to intercept it.
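
The line that actually installs the handler is folded out of this hunk; the standard loguru recipe for routing stdlib logging (and therefore ComfyUI's log output) through an `InterceptHandler` looks roughly like this, though the exact call hordelib makes may differ:

    import logging

    # route every stdlib logging record, including ComfyUI's, through loguru
    logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)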
@@ -145,6 +160,8 @@ def do_comfy_import(
global _comfy_free_memory, _comfy_cleanup_models, _comfy_soft_empty_cache
global _canny, _hed, _leres, _midas, _mlsd, _openpose, _pidinet, _uniformer

global _comfy_interrupt_current_processing

if disable_smart_memory:
logger.info("Disabling smart memory")
sys.argv.append("--disable-smart-memory")
@@ -173,23 +190,36 @@
from execution import IsChangedCache

global _comfy_is_changed_cache_get
_comfy_is_changed_cache_get = IsChangedCache.get # type: ignore
_comfy_is_changed_cache_get = IsChangedCache.get

IsChangedCache.get = IsChangedCache_get_hijack # type: ignore

from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths # type: ignore
from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions # type: ignore
from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths
from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions
from comfy.sd import load_checkpoint_guess_config as _comfy_load_checkpoint_guess_config
from comfy.model_management import current_loaded_models as _comfy_current_loaded_models
from comfy.model_management import load_models_gpu as _comfy_load_models_gpu
from comfy.model_management import load_models_gpu

_comfy_load_models_gpu = load_models_gpu # type: ignore
import comfy.model_management

comfy.model_management.load_models_gpu = _load_models_gpu_hijack
from comfy.model_management import get_torch_device as _comfy_get_torch_device
from comfy.model_management import get_free_memory as _comfy_get_free_memory
from comfy.model_management import get_total_memory as _comfy_get_total_memory
from comfy.model_management import free_memory as _comfy_free_memory
from comfy.model_management import cleanup_models as _comfy_cleanup_models
from comfy.model_management import soft_empty_cache as _comfy_soft_empty_cache
from comfy.model_management import interrupt_current_processing as _comfy_interrupt_current_processing
from comfy.utils import load_torch_file as _comfy_load_torch_file
from comfy_extras.chainner_models import model_loading as _comfy_model_loading # type: ignore
from comfy_extras.chainner_models import model_loading as _comfy_model_loading

from comfy.model_patcher import ModelPatcher

global _comfy_model_patcher_load
_comfy_model_patcher_load = ModelPatcher.load
ModelPatcher.load = _model_patcher_load_hijack # type: ignore

from hordelib.nodes.comfy_controlnet_preprocessors import (
canny as _canny,
hed as _hed,
@@ -208,6 +238,70 @@ def do_comfy_import(


# isort: on
models_not_to_force_load: list = ["cascade", "sdxl"] # other possible values could be `basemodel` or `sd1`
"""Models which should not be forced to load in the comfy model loading hijack.
Possible values include `cascade`, `sdxl`, `basemodel`, `sd1` or any other comfyui classname
which can be passed to comfyui's `load_models_gpu` function (as a `ModelPatcher.model`).
"""

disable_force_loading: bool = False


def _do_not_force_load_model_in_patcher(model_patcher):
for model in models_not_to_force_load:
if model in str(type(model_patcher.model)).lower():
return True

return False


def _load_models_gpu_hijack(*args, **kwargs):
"""Intercepts the comfy load_models_gpu function to force full load.
ComfyUI is too conservative in its loading to GPU for the worker/horde use case where we can have
multiple ComfyUI instances running on the same GPU. This function forces a full load of the model,
and the worker/horde-engine takes responsibility for managing the memory and any problems this may
cause.
"""
found_model_to_skip = False
for model_patcher in args[0]:
found_model_to_skip = _do_not_force_load_model_in_patcher(model_patcher)
if found_model_to_skip:
break

global _comfy_current_loaded_models
if found_model_to_skip:
logger.debug("Not overriding model load")
_comfy_load_models_gpu(*args, **kwargs)
return

if "force_full_load" in kwargs:
kwargs.pop("force_full_load")

kwargs["force_full_load"] = True
_comfy_load_models_gpu(*args, **kwargs)


def _model_patcher_load_hijack(*args, **kwargs):
"""Intercepts the comfy ModelPatcher.load function to force full load.
See _load_models_gpu_hijack for more information
"""
global _comfy_model_patcher_load

model_patcher = args[0]
if _do_not_force_load_model_in_patcher(model_patcher):
logger.debug("Not overriding model load")
_comfy_model_patcher_load(*args, **kwargs)
return

if "full_load" in kwargs:
kwargs.pop("full_load")

kwargs["full_load"] = True
_comfy_model_patcher_load(*args, **kwargs)
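
To illustrate the substring check in `_do_not_force_load_model_in_patcher`, a small sketch with stand-in classes (the class names below only mimic the `ModelPatcher.model` attribute the check reads; they are not comfyui's real classes):

    class SDXL:  # stand-in whose lowercased class name contains "sdxl"
        pass


    class _FakePatcher:  # mimics the only attribute the check touches
        def __init__(self, model):
            self.model = model


    # "sdxl" is in models_not_to_force_load by default, so this patcher is left to comfyui's own loading
    assert _do_not_force_load_model_in_patcher(_FakePatcher(SDXL())) is True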


_last_pipeline_settings_hash = ""

Expand Down Expand Up @@ -324,6 +418,11 @@ def log_free_ram():
)


def interrupt_comfyui_processing():
logger.warning("Interrupting comfyui processing")
_comfy_interrupt_current_processing()


class Comfy_Horde:
"""Handles horde-specific behavior against ComfyUI."""

@@ -362,6 +461,7 @@ def __init__(
self,
*,
comfyui_callback: typing.Callable[[str, dict, str], None] | None = None,
aggressive_unloading: bool = True,
) -> None:
"""Initialise the Comfy_Horde object.
@@ -388,6 +488,7 @@ def __init__(
self._load_custom_nodes()

self._comfyui_callback = comfyui_callback
self.aggressive_unloading = aggressive_unloading

def _set_comfyui_paths(self) -> None:
# These set the default paths for comfyui to look for models and embeddings. From within hordelib,
@@ -764,6 +865,12 @@ def _run_pipeline(
inference.execute(pipeline, self.client_id, {"client_id": self.client_id}, valid[2])
except Exception as e:
logger.exception(f"Exception during comfy execute: {e}")
finally:
if self.aggressive_unloading:
global _comfy_cleanup_models
logger.debug("Cleaning up models")
_comfy_cleanup_models(False)
_comfy_soft_empty_cache(True)

stdio.replay()

2 changes: 2 additions & 0 deletions hordelib/horde.py
@@ -352,10 +352,12 @@ def __init__(
self,
*,
comfyui_callback: Callable[[str, dict, str], None] | None = None,
aggressive_unloading: bool = True,
):
if not self._initialised:
self.generator = Comfy_Horde(
comfyui_callback=comfyui_callback if comfyui_callback else self._comfyui_callback,
aggressive_unloading=aggressive_unloading,
)
self.__class__._initialised = True
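
A brief usage sketch for the new flag at this layer, assuming the public `HordeLib` entry point used in the tests below (the default of `True` keeps the new post-pipeline cleanup enabled):

    from hordelib.horde import HordeLib

    # assumes hordelib.initialise() has already been called;
    # keep models resident between runs instead of cleaning up after every pipeline execution
    horde = HordeLib(aggressive_unloading=False)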

6 changes: 6 additions & 0 deletions hordelib/initialisation.py
@@ -28,6 +28,7 @@ def initialise(
extra_comfyui_args: list[str] | None = None,
disable_smart_memory: bool = False,
do_not_load_model_mangers: bool = False,
models_not_to_force_load: list[str] | None = None,
):
"""Initialise hordelib. This is required before using any other hordelib functions.
@@ -40,6 +41,9 @@
force_low_vram (bool, optional): Whether to forcibly disable ComfyUI's high/med vram modes. Defaults to False.
extra_comfyui_args (list[str] | None, optional): Any additional CLI args for comfyui that should be used. \
Defaults to None.
models_not_to_force_load (list[str] | None, optional): A list of baselines that should not be force loaded.\
**If this is `None`, the defaults are used.** If you wish to override the defaults, pass an empty list. \
Defaults to None.
"""
global _is_initialised

@@ -79,6 +83,8 @@
extra_comfyui_args=extra_comfyui_args,
disable_smart_memory=disable_smart_memory,
)
if models_not_to_force_load is not None:
hordelib.comfy_horde.models_not_to_force_load = models_not_to_force_load.copy()

vram_on_start_free = hordelib.comfy_horde.get_torch_free_vram_mb()
vram_total = hordelib.comfy_horde.get_torch_total_vram_mb()
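
To spell out the override semantics documented above, a short sketch of the three modes (only one call would be made in practice; the list contents are illustrative):

    import hordelib

    hordelib.initialise()                                      # None: keep the defaults (["cascade", "sdxl"])
    # hordelib.initialise(models_not_to_force_load=[])         # empty list: force-load every model family
    # hordelib.initialise(models_not_to_force_load=["sdxl"])   # skip force-loading only for SDXL-class models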
Binary file added images_expected/text_to_image_max_resolution.png
7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -92,8 +92,13 @@ def init_horde(
hordelib.initialise(
setup_logging=True,
logging_verbosity=5,
disable_smart_memory=False,
disable_smart_memory=True,
force_normal_vram_mode=True,
do_not_load_model_mangers=True,
models_not_to_force_load=[
"sdxl",
"cascade",
],
)
from hordelib.settings import UserSettings

37 changes: 37 additions & 0 deletions tests/test_horde_inference.py
@@ -47,6 +47,43 @@ def test_text_to_image(
pil_image,
)

@pytest.mark.default_sd15_model
def test_text_to_image_max_resolution(
self,
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
):
data = {
"sampler_name": "k_dpmpp_2m",
"cfg_scale": 7.5,
"denoising_strength": 1.0,
"seed": 123456789,
"height": 2048,
"width": 2048,
"karras": False,
"tiling": False,
"hires_fix": False,
"clip_skip": 1,
"control_type": None,
"image_is_control": False,
"return_control_map": False,
"prompt": "an ancient llamia monster",
"ddim_steps": 5,
"n_iter": 1,
"model": stable_diffusion_model_name_for_testing,
}
pil_image = hordelib_instance.basic_inference_single_image(data).image
assert pil_image is not None
assert isinstance(pil_image, Image.Image)

img_filename = "text_to_image_max_resolution.png"
pil_image.save(f"images/{img_filename}", quality=100)

assert check_single_inference_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)

@pytest.mark.default_sd15_model
def test_text_to_image_n_iter(
self,
Expand Down
