feat: Memory management improvements; force load SD15 models fully #318

Merged
11 commits merged on Aug 26, 2024
145 changes: 126 additions & 19 deletions hordelib/comfy_horde.py
@@ -19,6 +19,7 @@
from pprint import pformat
import requests
import psutil
from collections.abc import Callable

import torch
from loguru import logger
@@ -40,9 +41,10 @@
# There may be other ways to skin this cat, but this strategy minimizes certain kinds of hassle.
#
# If you tamper with the code in this module to bring the imports out of the function, you may find that you have
# broken, among other things, the ability of pytest to do its test discovery because you will have lost the ability for
# modules which, directly or otherwise, import this module without having called `hordelib.initialise()`. Pytest
# discovery will come across those imports, valiantly attempt to import them and fail with a cryptic error message.
# broken, among myriad other things, pytest's test discovery, because any module which, directly or otherwise,
# imports this module can no longer be imported unless `hordelib.initialise()` has been called first. Pytest
# discovery will come across those imports, valiantly attempt to import them and fail with a cryptic error
# message.
#
# Correspondingly, you will find that to be an enormous hassle if you are trying to leverage pytest in any
# reasonably sophisticated way (outside of tox), and you will be forced to adopt the solution below or something
@@ -51,17 +53,18 @@
#
# Keen readers may have noticed that the aforementioned issues could be side stepped by simply calling
# `hordelib.initialise()` automatically, such as in `test/__init__.py` or in a `conftest.py`. You would be correct,
# but that would be a terrible idea if you ever intended to make alterations to the patch file, as each time you
# triggered pytest discovery which could be as frequently as *every time you save a file* (such as with VSCode), and
# you would enter a situation where the patch was automatically being applied at times you may not intend.
# but that would be a terrible idea as a general practice. It would mean that every time you saved a file in your
# editor, a number of heavyweight operations, such as loading comfyui, would be triggered while pytest discovery
# runs, and that would cause slow and unpredictable behavior in your editor.
#
# This would be a nightmare to debug, as this author is able to attest to.
# This would be a nightmare to debug, as this author is able to attest to and is the reason this wall of text exists.
#
# Further, if you are like myself, and enjoy type hints, you will find that any modules that have this file in their
# import chain will be un-importable in certain contexts, and you will be unable to provide the relevant type hints.
#
# Having read this, I suggest you glance at the code in `hordelib.initialise()` to get a sense of what is going on
# there, and if you're still confused, ask a hordelib dev who would be happy to share the burden of understanding.
# Having exercised a herculean amount of focus to read this far, I suggest you also glance at the code in
# `hordelib.initialise()` to get a sense of what is going on there, and if you're still confused, ask a hordelib dev
# who would be happy to share the burden of understanding.
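
A minimal sketch of the call order the comment above describes, assuming a normal worker/script entry point rather than pytest discovery; `HordeLib` is the class used elsewhere in this PR's tests:

```python
# Sketch only: hordelib.initialise() must run before anything imports
# hordelib.comfy_horde (directly or transitively).
import hordelib

hordelib.initialise()

from hordelib.horde import HordeLib  # safe only after initialise() has run

horde = HordeLib()
```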

_comfy_load_models_gpu: types.FunctionType
_comfy_current_loaded_models: list = None # type: ignore
@@ -76,15 +79,24 @@
_comfy_load_checkpoint_guess_config: types.FunctionType

_comfy_get_torch_device: types.FunctionType
"""Will return the current torch device, typically the GPU."""
_comfy_get_free_memory: types.FunctionType
"""Will return the amount of free memory on the current torch device. This value can be misleading."""
_comfy_get_total_memory: types.FunctionType
"""Will return the total amount of memory on the current torch device."""
_comfy_load_torch_file: types.FunctionType
_comfy_model_loading: types.ModuleType
_comfy_free_memory: types.FunctionType
_comfy_cleanup_models: types.FunctionType
_comfy_soft_empty_cache: types.FunctionType
_comfy_free_memory: Callable[[float, torch.device, list], None]
"""Will aggressively unload models from memory"""
_comfy_cleanup_models: Callable[[bool], None]
"""Will unload unused models from memory"""
_comfy_soft_empty_cache: Callable[[bool], None]
"""Triggers comfyui and torch to empty their caches"""

_comfy_is_changed_cache_get: types.FunctionType
_comfy_is_changed_cache_get: Callable
_comfy_model_patcher_load: Callable

_comfy_interrupt_current_processing: types.FunctionType

_canny: types.ModuleType
_hed: types.ModuleType
@@ -111,6 +123,9 @@ class InterceptHandler(logging.Handler):

@logger.catch(default=True, reraise=True)
def emit(self, record):
message = record.getMessage()
if "lowvram: loaded module regularly" in message:
return
# Get corresponding Loguru level if it exists.
try:
level = logger.level(record.levelname).name
@@ -123,7 +138,7 @@ def emit(self, record):
frame = frame.f_back
depth += 1

logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
logger.opt(depth=depth, exception=record.exc_info).log(level, message)


# ComfyUI uses stdlib logging, so we need to intercept it.
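
A rough sketch of how an intercept handler like the one above is typically wired into stdlib logging; the exact attachment point used by hordelib may differ:

```python
# Assumed wiring, following the common loguru interception pattern.
import logging

logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
```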
@@ -145,6 +160,8 @@ def do_comfy_import(
global _comfy_free_memory, _comfy_cleanup_models, _comfy_soft_empty_cache
global _canny, _hed, _leres, _midas, _mlsd, _openpose, _pidinet, _uniformer

global _comfy_interrupt_current_processing

if disable_smart_memory:
logger.info("Disabling smart memory")
sys.argv.append("--disable-smart-memory")
@@ -173,23 +190,36 @@
from execution import IsChangedCache

global _comfy_is_changed_cache_get
_comfy_is_changed_cache_get = IsChangedCache.get # type: ignore
_comfy_is_changed_cache_get = IsChangedCache.get

IsChangedCache.get = IsChangedCache_get_hijack # type: ignore

from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths # type: ignore
from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions # type: ignore
from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths
from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions
from comfy.sd import load_checkpoint_guess_config as _comfy_load_checkpoint_guess_config
from comfy.model_management import current_loaded_models as _comfy_current_loaded_models
from comfy.model_management import load_models_gpu as _comfy_load_models_gpu
from comfy.model_management import load_models_gpu

_comfy_load_models_gpu = load_models_gpu # type: ignore
import comfy.model_management

comfy.model_management.load_models_gpu = _load_models_gpu_hijack
from comfy.model_management import get_torch_device as _comfy_get_torch_device
from comfy.model_management import get_free_memory as _comfy_get_free_memory
from comfy.model_management import get_total_memory as _comfy_get_total_memory
from comfy.model_management import free_memory as _comfy_free_memory
from comfy.model_management import cleanup_models as _comfy_cleanup_models
from comfy.model_management import soft_empty_cache as _comfy_soft_empty_cache
from comfy.model_management import interrupt_current_processing as _comfy_interrupt_current_processing
from comfy.utils import load_torch_file as _comfy_load_torch_file
from comfy_extras.chainner_models import model_loading as _comfy_model_loading # type: ignore
from comfy_extras.chainner_models import model_loading as _comfy_model_loading

from comfy.model_patcher import ModelPatcher

global _comfy_model_patcher_load
_comfy_model_patcher_load = ModelPatcher.load
ModelPatcher.load = _model_patcher_load_hijack # type: ignore

from hordelib.nodes.comfy_controlnet_preprocessors import (
canny as _canny,
hed as _hed,
@@ -208,6 +238,70 @@ def do_comfy_import(


# isort: on
models_not_to_force_load: list = ["cascade", "sdxl"] # other possible values could be `basemodel` or `sd1`
"""Models which should not be forced to load in the comfy model loading hijack.

Possible values include `cascade`, `sdxl`, `basemodel`, `sd1` or any other comfyui classname
which can be passed to comfyui's `load_models_gpu` function (as a `ModelPatcher.model`).
"""

disable_force_loading: bool = False


def _do_not_force_load_model_in_patcher(model_patcher):
for model in models_not_to_force_load:
if model in str(type(model_patcher.model)).lower():
return True

return False
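
A purely illustrative check of how the substring match above behaves; the stand-in classes below are hypothetical and only mimic the `model_patcher.model` attribute:

```python
# Hypothetical stand-ins for ComfyUI objects, used only to illustrate the
# str(type(...)).lower() containment check above.
class SDXL:  # pretend ComfyUI model class
    pass


class FakePatcher:
    def __init__(self, model):
        self.model = model


assert _do_not_force_load_model_in_patcher(FakePatcher(SDXL())) is True    # "sdxl" matches
assert _do_not_force_load_model_in_patcher(FakePatcher(object())) is False  # no entry matches
```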


def _load_models_gpu_hijack(*args, **kwargs):
"""Intercepts the comfy load_models_gpu function to force full load.

ComfyUI is too conservative in its loading to GPU for the worker/horde use case where we can have
multiple ComfyUI instances running on the same GPU. This function forces a full load of the model
and the worker/horde-engine takes responsibility for managing the memory and any problems this may
cause.
"""
found_model_to_skip = False
for model_patcher in args[0]:
found_model_to_skip = _do_not_force_load_model_in_patcher(model_patcher)
if found_model_to_skip:
break

global _comfy_current_loaded_models
if found_model_to_skip:
logger.debug("Not overriding model load")
_comfy_load_models_gpu(*args, **kwargs)
return

if "force_full_load" in kwargs:
kwargs.pop("force_full_load")

kwargs["force_full_load"] = True
_comfy_load_models_gpu(*args, **kwargs)


def _model_patcher_load_hijack(*args, **kwargs):
"""Intercepts the comfy ModelPatcher.load function to force full load.

See _load_models_gpu_hijack for more information
"""
global _comfy_model_patcher_load

model_patcher = args[0]
if _do_not_force_load_model_in_patcher(model_patcher):
logger.debug("Not overriding model load")
_comfy_model_patcher_load(*args, **kwargs)
return

if "full_load" in kwargs:
kwargs.pop("full_load")

kwargs["full_load"] = True
_comfy_model_patcher_load(*args, **kwargs)


_last_pipeline_settings_hash = ""

Expand Down Expand Up @@ -324,6 +418,11 @@ def log_free_ram():
)


def interrupt_comfyui_processing():
logger.warning("Interrupting comfyui processing")
_comfy_interrupt_current_processing()


class Comfy_Horde:
"""Handles horde-specific behavior against ComfyUI."""

@@ -362,6 +461,7 @@ def __init__(
self,
*,
comfyui_callback: typing.Callable[[str, dict, str], None] | None = None,
aggressive_unloading: bool = True,
) -> None:
"""Initialise the Comfy_Horde object.

@@ -388,6 +488,7 @@ def __init__(
self._load_custom_nodes()

self._comfyui_callback = comfyui_callback
self.aggressive_unloading = aggressive_unloading

def _set_comfyui_paths(self) -> None:
# These set the default paths for comfyui to look for models and embeddings. From within hordelib,
@@ -764,6 +865,12 @@ def _run_pipeline(
inference.execute(pipeline, self.client_id, {"client_id": self.client_id}, valid[2])
except Exception as e:
logger.exception(f"Exception during comfy execute: {e}")
finally:
if self.aggressive_unloading:
global _comfy_cleanup_models
logger.debug("Cleaning up models")
_comfy_cleanup_models(False)
_comfy_soft_empty_cache(True)

stdio.replay()

2 changes: 2 additions & 0 deletions hordelib/horde.py
@@ -352,10 +352,12 @@ def __init__(
self,
*,
comfyui_callback: Callable[[str, dict, str], None] | None = None,
aggressive_unloading: bool = True,
):
if not self._initialised:
self.generator = Comfy_Horde(
comfyui_callback=comfyui_callback if comfyui_callback else self._comfyui_callback,
aggressive_unloading=aggressive_unloading,
)
self.__class__._initialised = True

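With the change above, callers that prefer ComfyUI's own memory management can opt out. A sketch, assuming `HordeLib` forwards the keyword to `Comfy_Horde` exactly as in the hunk above:

```python
# hordelib.initialise() must have been called first.
from hordelib.horde import HordeLib

horde = HordeLib(aggressive_unloading=False)  # keep models resident between pipeline runs
```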
6 changes: 6 additions & 0 deletions hordelib/initialisation.py
@@ -28,6 +28,7 @@ def initialise(
extra_comfyui_args: list[str] | None = None,
disable_smart_memory: bool = False,
do_not_load_model_mangers: bool = False,
models_not_to_force_load: list[str] | None = None,
):
"""Initialise hordelib. This is required before using any other hordelib functions.

@@ -40,6 +41,9 @@
force_low_vram (bool, optional): Whether to forcibly disable ComfyUI's high/med vram modes. Defaults to False.
extra_comfyui_args (list[str] | None, optional): Any additional CLI args for comfyui that should be used. \
Defaults to None.
models_not_to_force_load (list[str] | None, optional): A list of baselines that should not be force loaded.\
**If this is `None`, the defaults are used.** If you wish to override the defaults, pass an empty list. \
Defaults to None.
"""
global _is_initialised

@@ -79,6 +83,8 @@
extra_comfyui_args=extra_comfyui_args,
disable_smart_memory=disable_smart_memory,
)
if models_not_to_force_load is not None:
hordelib.comfy_horde.models_not_to_force_load = models_not_to_force_load.copy()

vram_on_start_free = hordelib.comfy_horde.get_torch_free_vram_mb()
vram_total = hordelib.comfy_horde.get_torch_total_vram_mb()
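A sketch of using the new `models_not_to_force_load` argument; per the docstring above, `None` keeps the library defaults and an empty list overrides them:

```python
import hordelib

# Mirror the library defaults explicitly, or pass [] to force-load every model type.
hordelib.initialise(models_not_to_force_load=["sdxl", "cascade"])
```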
Binary file added images_expected/text_to_image_max_resolution.png
7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -92,8 +92,13 @@ def init_horde(
hordelib.initialise(
setup_logging=True,
logging_verbosity=5,
disable_smart_memory=False,
disable_smart_memory=True,
force_normal_vram_mode=True,
do_not_load_model_mangers=True,
models_not_to_force_load=[
"sdxl",
"cascade",
],
)
from hordelib.settings import UserSettings

37 changes: 37 additions & 0 deletions tests/test_horde_inference.py
@@ -47,6 +47,43 @@ def test_text_to_image(
pil_image,
)

@pytest.mark.default_sd15_model
def test_text_to_image_max_resolution(
self,
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
):
data = {
"sampler_name": "k_dpmpp_2m",
"cfg_scale": 7.5,
"denoising_strength": 1.0,
"seed": 123456789,
"height": 2048,
"width": 2048,
"karras": False,
"tiling": False,
"hires_fix": False,
"clip_skip": 1,
"control_type": None,
"image_is_control": False,
"return_control_map": False,
"prompt": "an ancient llamia monster",
"ddim_steps": 5,
"n_iter": 1,
"model": stable_diffusion_model_name_for_testing,
}
pil_image = hordelib_instance.basic_inference_single_image(data).image
assert pil_image is not None
assert isinstance(pil_image, Image.Image)

img_filename = "text_to_image_max_resolution.png"
pil_image.save(f"images/{img_filename}", quality=100)

assert check_single_inference_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)

@pytest.mark.default_sd15_model
def test_text_to_image_n_iter(
self,