Skip to content

Commit

Permalink
Controlnet code refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 7, 2024
1 parent 1c08bf3 commit c19dcd3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
36 changes: 24 additions & 12 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,16 @@ def get_control(self, x_noisy, t, cond, batched_number):
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
extra = self.extra_args.copy()
for c in ["y", "guidance"]: #TODO
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype)

def copy(self):
Expand Down Expand Up @@ -338,12 +341,8 @@ def get_models(self):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
def controlnet_config(sd):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)

supported_inference_dtypes = model_config.supported_inference_dtypes

Expand All @@ -356,14 +355,27 @@ def load_controlnet_mmdit(sd):
else:
operations = comfy.ops.disable_weight_init

control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
return model_config, operations, load_device, unet_dtype, manual_cast_dtype

def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)

if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))

if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model

def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]

control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)

latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
Expand Down
4 changes: 2 additions & 2 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = 19
dit_config["depth_single_blocks"] = 38
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
Expand Down

0 comments on commit c19dcd3

Please sign in to comment.