Skip to content

Commit

Permalink
add missing import, simplify code, use patches module for #13276
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Sep 30, 2023
1 parent e309583 commit 87b5039
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
import numpy as np
Expand Down Expand Up @@ -130,6 +130,8 @@ def calculate_shorthash(self):


def setup_model():
"""called once at startup to do various one-time tasks related to SD models"""

os.makedirs(model_path, exist_ok=True)

enable_midas_autodownload()
Expand Down Expand Up @@ -458,14 +460,17 @@ def load_model_wrapper(model_type):


def patch_given_betas():
original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
import ldm.models.diffusion.ddpm

def patched_register_schedule(*args, **kwargs):
if args[1] is not None and isinstance(args[1], ListConfig):
modified_args = list(args) # Convert args tuple to a list
modified_args[1] = np.array(args[1]) # Modify the desired element
args = tuple(modified_args) # Convert the list back to a tuple
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

if isinstance(args[1], ListConfig):
args = (args[0], np.array(args[1]), *args[2:])

original_register_schedule(*args, **kwargs)
ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule

original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)


def repair_config(sd_config):
Expand Down

0 comments on commit 87b5039

Please sign in to comment.