forked from shiimizu/ComfyUI-TiledDiffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
158 lines (134 loc) · 5.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class Store:
    """Mutable attribute bag shared by the hooks in this module.

    Attributes are added and removed ad hoc (for example ``store.sigmas`` or
    the saved originals such as ``store.KSampler_sample``). ``repr`` lists the
    attributes sorted by name, and two stores are equal when their attribute
    dicts are equal.
    """

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in sorted(self.__dict__.items()))
        return f"{type(self).__name__}({fields})"

    def __eq__(self, other):
        # Returning NotImplemented (instead of reading ``other.__dict__``) lets
        # comparisons such as ``store == None`` or ``store == 0`` fall back to
        # identity and yield False rather than raising AttributeError.
        if not isinstance(other, Store):
            return NotImplemented
        return self.__dict__ == other.__dict__


store = Store()
# ==================== Hook into sampling functions for ControlNet ====================
import comfy.samplers
def KSAMPLER_sample(*args, **kwargs):
    """Hook for ``comfy.samplers.KSAMPLER.sample``.

    If ``extra_args['model_options']`` contains ``'tiled_diffusion'``, publish
    the sigma schedule, the model options and ``extra_args`` on ``store`` for
    the other hooks. Otherwise remove those three attributes from ``store``.
    In both cases the call is forwarded unchanged to the saved original
    (``store.KSAMPLER_sample``).

    Positional fallbacks assume ``sigmas`` is ``args[2]`` and ``extra_args`` is
    ``args[3]`` (counting ``self``); those are the only indices used here.
    """
    forward = store.KSAMPLER_sample
    extra = None
    options = None
    try:
        extra = kwargs['extra_args'] if 'extra_args' in kwargs else args[3]
        options = extra['model_options']
    except Exception:
        # Best-effort: without usable extra_args this is treated as a non-tiled run.
        pass

    if options is None or 'tiled_diffusion' not in options or extra is None:
        for attr in ('sigmas', 'model_options', 'extra_args'):
            _delattr(store, attr)
        return forward(*args, **kwargs)

    passed = kwargs['sigmas'] if 'sigmas' in kwargs else args[2]
    # A schedule stashed in model_options['sigmas'] (KSampler_sample puts one
    # there) takes precedence; pop() also removes it from the dict.
    stashed = options.pop('sigmas', None)
    store.sigmas = passed if stashed is None else stashed
    store.model_options = options
    store.extra_args = extra
    return forward(*args, **kwargs)
def KSampler_sample(*args, **kwargs):
    """Hook for ``comfy.samplers.KSampler.sample``.

    For models whose patcher ``model_options`` contain ``'tiled_diffusion'``,
    store the sigma schedule under ``'sigmas'`` in a shallow copy of those
    options and install that copy on ``self.model``. ``KSAMPLER_sample`` later
    pops it. The call is then forwarded unchanged to the saved original
    (``store.KSampler_sample``).

    The schedule is taken from the ``sigmas`` argument (keyword, or positional
    index 10 counting ``self``). If that is absent or None, it falls back to
    ``self.sigmas``. If no schedule is found, nothing is changed.
    """
    orig_fn = store.KSampler_sample
    self = args[0]
    model_patcher = getattr(self, 'model', None)
    model_options = getattr(model_patcher, 'model_options', None)
    if model_options is not None and 'tiled_diffusion' in model_options:
        # Explicit length check instead of a blanket ``except Exception``:
        # a missing positional argument is the only failure case here.
        if 'sigmas' in kwargs:
            sigmas = kwargs['sigmas']
        elif len(args) > 10:
            sigmas = args[10]
        else:
            sigmas = None
        if sigmas is None:
            sigmas = getattr(self, 'sigmas', None)
        if sigmas is not None:
            # Copy so the dict the patcher held before is not mutated.
            model_options = model_options.copy()
            model_options['sigmas'] = sigmas
            model_patcher.model_options = model_options
    return orig_fn(*args, **kwargs)
def get_area_and_mult(*args, **kwargs):
    """Hook for ``comfy.samplers.get_area_and_mult`` that switches ControlNet dispatch.

    While a tiled-diffusion run is active (``store.model_options`` is set and
    contains ``'tiled_diffusion'``), the cond's ControlNet has its
    ``get_control`` replaced by a stub that returns the ControlNet object
    itself. The original callable is kept as ``get_control_orig``, presumably
    so the tiled sampler can call it per tile (confirm against the caller).
    Outside such a run, the original ``get_control`` is put back.
    The call is then forwarded unchanged to ``store.get_area_and_mult``.
    """
    conds = kwargs['conds'] if 'conds' in kwargs else args[0]
    model_options = getattr(store, 'model_options', None)
    tiled = model_options is not None and 'tiled_diffusion' in model_options
    if 'control' in conds:
        control = conds['control']
        if tiled:
            if not hasattr(control, 'get_control_orig'):
                control.get_control_orig = control.get_control
            # Install the stub whenever the original is active. Testing only for
            # get_control_orig is not enough: after a non-tiled run restored the
            # original, that attribute still exists and the stub would never be
            # re-installed for the next tiled run.
            if control.get_control == control.get_control_orig:
                control.get_control = lambda *a, **kw: control
        elif hasattr(control, 'get_control_orig') and control.get_control != control.get_control_orig:
            control.get_control = control.get_control_orig
    return store.get_area_and_mult(*args, **kwargs)
def _delattr(obj, attr):
    """Remove ``attr`` from ``obj`` if it is present; never raises."""
    try:
        if not hasattr(obj, attr):
            return
        delattr(obj, attr)
    except Exception:
        # Deliberately best-effort (e.g. read-only attributes are left alone).
        return
def register_hooks():
    """Install this module's sampling hooks into ``comfy.samplers``.

    For each ``(parent, fn_name, fn_patch)``:
    - the currently installed ``parent.<fn_name>`` is saved on ``store`` under
      ``fn_patch.__name__``, which the hook calls through to;
    - it is also saved under ``_<fn_name>`` unless ``parent`` already has that
      attribute;
    - ``parent.<fn_name>`` is then replaced with ``fn_patch``.

    Hooks that are already installed are skipped. Re-saving one would make it
    its own "original" and the hook would recurse forever.
    """
    patches = [
        (comfy.samplers.KSampler, 'sample', KSampler_sample),
        (comfy.samplers.KSAMPLER, 'sample', KSAMPLER_sample),
        (comfy.samplers, 'get_area_and_mult', get_area_and_mult),
    ]
    for parent, fn_name, fn_patch in patches:
        current = getattr(parent, fn_name)
        if current is fn_patch:
            continue
        # NOTE(review): this guard inspects ``parent`` but the backup goes to
        # ``store``. Both 'sample' entries therefore write ``store._sample``, and
        # the last one (KSAMPLER.sample) wins. Kept as-is in case other code reads it.
        if not hasattr(parent, f"_{fn_name}"):
            setattr(store, f"_{fn_name}", current)
        setattr(store, fn_patch.__name__, current)
        setattr(parent, fn_name, fn_patch)


register_hooks()
# ==================== Patch pre_run_control ====================
# TODO: Check whether this patch is still needed with current ComfyUI versions.
def pre_run_control(model, conds):
    """Reset and prime every ControlNet attached to ``conds``.

    Installed as ``comfy.samplers.pre_run_control``. For each cond dict that
    has a ``'control'`` entry, first call its ``cleanup()``. This is
    best-effort: any exception from it is ignored. Then call
    ``pre_run(model, percent_to_timestep_function)``, where the converter maps
    a percentage through ``model.model_sampling.percent_to_sigma``.

    Args:
        model: object exposing ``model_sampling`` (read once, up front).
        conds: sequence of cond dicts.
    """
    s = model.model_sampling

    # Loop-invariant: the converter only closes over ``s``.
    def percent_to_timestep_function(a):
        return s.percent_to_sigma(a)

    for x in conds:
        if 'control' not in x:
            continue
        control = x['control']
        try:
            control.cleanup()
        except Exception:
            # Best-effort; presumably clears state from an earlier run.
            pass
        control.pre_run(model, percent_to_timestep_function)
# Rebind comfy.samplers.pre_run_control to the version above. Only lookups
# through the module attribute see it; names bound earlier via
# `from comfy.samplers import pre_run_control` keep the old function.
comfy.samplers.pre_run_control = pre_run_control
# ==================== Patch SAG ====================
from math import sqrt
import torch.nn.functional as F
import comfy_extras.nodes_sag
from comfy_extras.nodes_sag import gaussian_blur_2d
def calc_closest_factors(a):
    """Return the most-square factor pair ``(small, large)`` of ``a``.

    The search walks down from ``int(sqrt(a))``, so ``small`` is the largest
    divisor not exceeding the square root and ``small * large == a``.
    Returns None when no candidate exists (e.g. ``a == 0``).
    """
    candidate = int(sqrt(a))
    while candidate > 0:
        quotient, remainder = divmod(a, candidate)
        if remainder == 0:
            return (candidate, quotient)
        candidate -= 1
    return None
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blend ``x0`` with a blurred copy wherever the attention mask fires.

    Replacement for ``comfy_extras.nodes_sag.create_blur_map``. The grid used
    to lay out the attention tokens comes from ``calc_closest_factors`` rather
    than from the latent size. The mask is built by:
    - averaging ``attn`` over axis 1 after regrouping by batch;
    - summing over the query axis;
    - thresholding the result.
    The mask is then reshaped to that grid, resized to ``x0``'s spatial size,
    and used to mix a Gaussian blur of ``x0`` (kernel 9, ``sigma``) into ``x0``.
    """
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    # Regroup the flat leading axis per batch item (presumably batch*heads).
    attn = attn.reshape(b, -1, hw1, hw2)
    # -> boolean (b, hw2) mask; reshaping it to an hw1-sized grid below
    # assumes hw1 == hw2 (self-attention) — confirm.
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    # Longer side of the factor pair goes along the longer latent side.
    # NOTE(review): this only matches the real token grid when that grid is
    # the most-square factorisation of hw1. For example, an 8x24 grid has 192
    # tokens, but 192 factors most squarely as 12x16.
    short_side, long_side = calc_closest_factors(hw1)
    if lh > lw:
        mid_shape = (long_side, short_side)
    else:
        mid_shape = (short_side, long_side)

    mask = mask.reshape(b, *mid_shape).unsqueeze(1).type(attn.dtype)
    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    return blurred * mask + x0 * (1 - mask)


comfy_extras.nodes_sag.create_blur_map = create_blur_map
# ==================== Patch Gligen ====================
def _set_position(self, boxes, masks, positive_embeddings):
    """Build the GLIGEN patch function, tiling object embeddings to the input batch.

    ``self.position_net`` encodes the grounding inputs once. The returned
    ``func(x, extra_options)`` runs
    ``self.module_list[extra_options["transformer_index"]]`` on ``x`` and those
    embeddings, cast to ``x``'s device and dtype. When ``x`` has a larger batch
    than the embeddings, the embeddings are repeated along dim 0
    ``ceil(x_batch / objs_batch)`` times. ``repeat`` takes three factors here,
    so the embeddings are assumed to be 3-D.
    """
    objs = self.position_net(boxes, masks, positive_embeddings)

    def func(x, extra_options):
        module = self.module_list[extra_options["transformer_index"]]
        x_batch, objs_batch = x.shape[0], objs.shape[0]
        if x_batch > objs_batch:
            # -(a // -b) is ceil(a / b): enough copies to cover every row of x.
            # NOTE(review): if x_batch is not a multiple of objs_batch this
            # yields more rows than x has; confirm that cannot happen.
            batch_objs = objs.repeat(-(x_batch // -objs_batch), 1, 1)
        else:
            batch_objs = objs
        return module(x, batch_objs.to(device=x.device, dtype=x.dtype))

    return func
import comfy.gligen
# Replace the method on the Gligen class itself, so every Gligen instance uses
# the _set_position defined above, which repeats the object embeddings up to
# the input batch size.
comfy.gligen.Gligen._set_position = _set_position