Skip to content

Commit

Permalink
Added forward_timestep_embed_patch type, added helper functions on Mo…
Browse files Browse the repository at this point in the history
…delPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks
  • Loading branch information
Kosinkadink committed Sep 24, 2024
1 parent 7c86407 commit da6c045
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 8 deletions.
13 changes: 12 additions & 1 deletion comfy/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ class EnumWeightTarget(enum.Enum):
class _HookRef:
pass

# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', target: EnumWeightTarget):
return True


class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None,
hook_keyframe: 'HookKeyframeGroup'=None):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register

@property
def strength(self):
Expand All @@ -57,15 +63,18 @@ def clone(self, subtype: Callable=None):
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_keyframe = self.hook_keyframe
c.custom_should_register = self.custom_should_register
return c

def should_register(self, model: 'ModelPatcher', target: EnumWeightTarget):
return self.custom_should_register(self, model, target)

def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref

def __hash__(self):
return hash(self.hook_ref)


class WeightHook(Hook):
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
Expand All @@ -85,6 +94,8 @@ def strength_clip(self):
return self._strength_clip * self.strength

def add_hook_patches(self, model: 'ModelPatcher', target: EnumWeightTarget):
if not self.should_register(model, target):
return
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
Expand Down
9 changes: 9 additions & 0 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x)
return x

Expand Down
40 changes: 34 additions & 6 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,12 @@ def set_model_input_block_patch_after_skip(self, patch):
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")

def set_model_emb_patch(self, patch):
self.set_model_patch(patch, "emb_patch")

def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch")

def add_object_patch(self, name, obj):
self.object_patches[name] = obj

Expand Down Expand Up @@ -769,12 +775,6 @@ def cleanup(self):
for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
callback(self)

def get_all_additional_models(self):
all_models = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models

def add_callback(self, call_type: str, callback: Callable):
self.add_callback_with_key(call_type, None, callback)

Expand All @@ -784,6 +784,11 @@ def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
c = self.callbacks[call_type].setdefault(key, [])
c.append(callback)

def remove_callbacks_with_key(self, call_type: str, key: str):
c = self.callbacks.get(call_type, {})
if key in c:
c.pop(key)

def get_callbacks(self, call_type: str, key: str):
return self.callbacks.get(call_type, {}).get(key, [])

Expand All @@ -802,6 +807,11 @@ def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers[wrapper_type].setdefault(key, [])
w.append(wrapper)

def remove_wrappers_with_key(self, wrapper_type: str, key: str):
w = self.wrappers.get(wrapper_type, {})
if key in w:
w.pop(key)

def get_wrappers(self, wrapper_type: str, key: str):
return self.wrappers.get(wrapper_type, {}).get(key, [])

Expand All @@ -814,12 +824,30 @@ def get_all_wrappers(self, wrapper_type: str):
def set_attachments(self, key: str, attachment):
self.attachments[key] = attachment

def remove_attachments(self, key: str):
if key in self.attachments:
self.attachments.pop(key)

def set_injections(self, key: str, injections: List[PatcherInjection]):
self.injections[key] = injections

def remove_injections(self, key: str):
if key in self.injections:
self.injections.pop(key)

def set_additional_models(self, key: str, models: List['ModelPatcher']):
self.additional_models[key] = models

def remove_additional_models(self, key: str):
if key in self.additional_models:
self.additional_models.pop(key)

def get_all_additional_models(self):
all_models = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models

def use_ejected(self, skip_and_inject_on_exit_only=False):
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)

Expand Down
1 change: 0 additions & 1 deletion comfy_extras/nodes_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def INPUT_TYPES(s):
}

RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("positive", "negative")
CATEGORY = "advanced/hooks/cond single"
FUNCTION = "set_properties"

Expand Down

0 comments on commit da6c045

Please sign in to comment.