Skip to content

Commit

Permalink
Refactored callbacks+wrappers to allow storing lists by id
Browse files Browse the repository at this point in the history
  • Loading branch information
Kosinkadink committed Sep 22, 2024
1 parent a154d0d commit 7c86407
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 39 deletions.
98 changes: 61 additions & 37 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ class CallbacksMP:
@classmethod
def init_callbacks(cls):
return {
cls.ON_CLONE: [],
cls.ON_LOAD: [],
cls.ON_CLEANUP: [],
cls.ON_PRE_RUN: [],
cls.ON_PREPARE_STATE: [],
cls.ON_APPLY_HOOKS: [],
cls.ON_REGISTER_ALL_HOOK_PATCHES: [],
cls.ON_INJECT_MODEL: [],
cls.ON_EJECT_MODEL: [],
cls.ON_CLONE: {None: []},
cls.ON_LOAD: {None: []},
cls.ON_CLEANUP: {None: []},
cls.ON_PRE_RUN: {None: []},
cls.ON_PREPARE_STATE: {None: []},
cls.ON_APPLY_HOOKS: {None: []},
cls.ON_REGISTER_ALL_HOOK_PATCHES: {None: []},
cls.ON_INJECT_MODEL: {None: []},
cls.ON_EJECT_MODEL: {None: []},
}

class WrappersMP:
Expand All @@ -132,8 +132,8 @@ class WrappersMP:
@classmethod
def init_wrappers(cls):
return {
cls.OUTER_SAMPLE: [],
cls.CALC_COND_BATCH: [],
cls.OUTER_SAMPLE: {None: []},
cls.CALC_COND_BATCH: {None: []},
}

class WrapperExecutor:
Expand Down Expand Up @@ -244,8 +244,8 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up

self.attachments: Dict[str] = {}
self.additional_models: Dict[str, List[ModelPatcher]] = {}
self.callbacks: Dict[str, List[Callable]] = CallbacksMP.init_callbacks()
self.wrappers: Dict[str, List[Callable]] = WrappersMP.init_wrappers()
self.callbacks: Dict[str, Dict[str, List[Callable]]] = CallbacksMP.init_callbacks()
self.wrappers: Dict[str, Dict[str, List[Callable]]] = WrappersMP.init_wrappers()

self.is_injected = False
self.skip_injection = False
Expand Down Expand Up @@ -305,10 +305,14 @@ def clone(self):
n.additional_models[k] = [x.clone() for x in c]
# callbacks
for k, c in self.callbacks.items():
n.callbacks[k] = c.copy()
n.callbacks[k] = {}
for k1, c1 in c.items():
n.callbacks[k][k1] = c1.copy()
# sample wrappers
for k, w in self.wrappers.items():
n.wrappers[k] = w.copy()
n.wrappers[k] = {}
for k1, w1 in w.items():
n.wrappers[k][k1] = w1.copy()
# injection
n.is_injected = self.is_injected
n.skip_injection = self.skip_injection
Expand All @@ -327,7 +331,7 @@ def clone(self):
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
n.hook_mode = self.hook_mode

for callback in self.get_callbacks(CallbacksMP.ON_CLONE):
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n

Expand Down Expand Up @@ -630,7 +634,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter

for callback in self.get_callbacks(CallbacksMP.ON_LOAD):
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)

self.apply_hooks(self.forced_hooks)
Expand Down Expand Up @@ -762,7 +766,7 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3

def cleanup(self):
self.clean_hooks()
for callback in self.get_callbacks(CallbacksMP.ON_CLEANUP):
for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
callback(self)

def get_all_additional_models(self):
Expand All @@ -771,21 +775,41 @@ def get_all_additional_models(self):
all_models.extend(models)
return all_models

def add_callback(self, key: str, callback: Callable):
if key not in self.callbacks:
raise Exception(f"Callback '{key}' is not recognized.")
self.callbacks[key].append(callback)

def get_callbacks(self, key: str):
return self.callbacks.get(key, [])
def add_callback(self, call_type: str, callback: Callable):
self.add_callback_with_key(call_type, None, callback)

def add_wrapper(self, key: str, wrapper: Callable):
if key not in self.wrappers:
raise Exception(f"Wrapper '{key}' is not recognized.")
self.wrappers[key].append(wrapper)
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
if call_type not in self.callbacks:
raise Exception(f"Callback '{call_type}' is not recognized.")
c = self.callbacks[call_type].setdefault(key, [])
c.append(callback)

def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])

def get_wrappers(self, key: str):
return self.wrappers.get(key, [])
def get_all_callbacks(self, call_type: str):
c_list = []
for c in self.callbacks.get(call_type, {}).values():
c_list.extend(c)
return c_list

def add_wrapper(self, wrapper_type: str, wrapper: Callable):
self.add_wrapper_with_key(wrapper_type, None, wrapper)

def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
if wrapper_type not in self.wrappers:
raise Exception(f"Wrapper '{wrapper_type}' is not recognized.")
w = self.wrappers[wrapper_type].setdefault(key, [])
w.append(wrapper)

def get_wrappers(self, wrapper_type: str, key: str):
return self.wrappers.get(wrapper_type, {}).get(key, [])

def get_all_wrappers(self, wrapper_type: str):
w_list = []
for w in self.wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list

def set_attachments(self, key: str, attachment):
self.attachments[key] = attachment
Expand All @@ -807,7 +831,7 @@ def inject_model(self):
inj.inject(self)
self.is_injected = True
if self.is_injected:
for callback in self.get_callbacks(CallbacksMP.ON_INJECT_MODEL):
for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
callback(self)

def eject_model(self):
Expand All @@ -817,15 +841,15 @@ def eject_model(self):
for inj in injections:
inj.eject(self)
self.is_injected = False
for callback in self.get_callbacks(CallbacksMP.ON_EJECT_MODEL):
for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
callback(self)

def pre_run(self):
for callback in self.get_callbacks(CallbacksMP.ON_PRE_RUN):
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)

def prepare_state(self, timestep):
for callback in self.get_callbacks(CallbacksMP.ON_PREPARE_STATE):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep)

def restore_hook_patches(self):
Expand Down Expand Up @@ -863,7 +887,7 @@ def register_all_hook_patches(self, hooks_dict: Dict[comfy.hooks.EnumHookType, D
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, target)
for callback in self.get_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)

def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0, is_diff=False):
Expand Down Expand Up @@ -945,7 +969,7 @@ def apply_hooks(self, hooks: comfy.hooks.HookGroup):
if self.current_hooks == hooks:
return
self.patch_hooks(hooks=hooks)
for callback in self.get_callbacks(CallbacksMP.ON_APPLY_HOOKS):
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)

def patch_hooks(self, hooks: comfy.hooks.HookGroup):
Expand Down
4 changes: 2 additions & 2 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def finalize_default_conds(hooked_to_run: Dict[comfy.hooks.HookGroup,List[Tuple[
def calc_cond_batch(model: 'BaseModel', conds: List[List[Dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.model_patcher.WrapperExecutor.new_executor(
outer_calc_cond_batch,
model.current_patcher.get_wrappers(comfy.model_patcher.WrappersMP.CALC_COND_BATCH)
model.current_patcher.get_all_wrappers(comfy.model_patcher.WrappersMP.CALC_COND_BATCH)
)
return executor._execute(model, conds, x_in, timestep, model_options)

Expand Down Expand Up @@ -808,7 +808,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds)
executor = comfy.model_patcher.WrapperClassExecutor.new_executor(
self.outer_sample,
self.model_patcher.get_wrappers(comfy.model_patcher.WrappersMP.OUTER_SAMPLE)
self.model_patcher.get_all_wrappers(comfy.model_patcher.WrappersMP.OUTER_SAMPLE)
)
output = executor._execute(self, noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
Expand Down

0 comments on commit 7c86407

Please sign in to comment.