Batch folder fix #2093

Open
wants to merge 1 commit into base: main
86 changes: 62 additions & 24 deletions extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@@ -340,16 +340,11 @@ def optional_tqdm(iterable, use_tqdm):
if input_mask is not None:
control_masks.append(input_mask)

if len(input_list) > 1 and not preprocessor_output_is_image:
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
break

if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH

alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
if (
(is_high_res and hr_option.high_res_enabled) or
@@ -364,16 +359,20 @@ def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))

params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
control_cond = numpy_to_pytorch(control_cond).movedim(-1, 1)
params.control_cond.append(control_cond)

if has_high_res_fix:
for preprocessor_output in preprocessor_outputs:
if has_high_res_fix and hr_option != HiResFixOption.LOW_RES_ONLY:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True)
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
control_cond_for_hr_fix = numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1)
params.control_cond_for_hr_fix.append(control_cond_for_hr_fix)
elif has_high_res_fix:
params.control_cond_for_hr_fix.append(control_cond)

params.control_cond = torch.cat(params.control_cond, dim=0)
if has_high_res_fix:
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)
else:
params.control_cond_for_hr_fix = params.control_cond
else:
@@ -392,15 +391,17 @@ def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)

if has_high_res_fix:
if has_high_res_fix and hr_option != HiResFixOption.LOW_RES_ONLY:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
elif has_high_res_fix:
params.control_mask_for_hr_fix.append(control_mask)

params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous()
params.control_mask = torch.cat(params.control_mask, dim=0)
if has_high_res_fix:
params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous()
params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)
else:
params.control_mask_for_hr_fix = params.control_mask
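
For context, the batching pattern both of these hunks rely on can be sketched in isolation: each preprocessor output becomes a [1, C, H, W] tensor and the per-image tensors are concatenated along dim 0, so the result holds one control image per batch-folder input instead of being tiled to p.batch_size. The `numpy_to_pytorch` helper below is a simplified stand-in for the one used by the extension, not its actual implementation.

```python
import numpy as np
import torch

def numpy_to_pytorch(x: np.ndarray) -> torch.Tensor:
    # Simplified stand-in: HWC uint8 image -> [1, H, W, C] float tensor in [0, 1].
    y = torch.from_numpy(np.ascontiguousarray(x)).float() / 255.0
    return y.unsqueeze(0)

# Pretend these are the preprocessor outputs for three batch-folder images.
preprocessor_outputs = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(3)]

conds = []
for out in preprocessor_outputs:
    cond = numpy_to_pytorch(out).movedim(-1, 1)  # [1, C, H, W]
    conds.append(cond)

control_cond = torch.cat(conds, dim=0)  # [B, C, H, W], one entry per input image
print(control_cond.shape)  # torch.Size([3, 3, 64, 64])
```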

@@ -423,13 +424,15 @@ def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):

@torch.no_grad()
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
p: StableDiffusionProcessing,
unit: ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):

is_hr_pass = getattr(p, 'is_hr_pass', False)

current_iteration = getattr(p, 'iteration', 0)
batch_size = getattr(p, 'batch_size', 1)

has_high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
@@ -440,21 +443,34 @@ def process_unit_before_every_sampling(self,
else:
hr_option = HiResFixOption.BOTH

if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled):
if has_high_res_fix and is_hr_pass and hr_option == HiResFixOption.LOW_RES_ONLY:
logger.info(f"ControlNet Skipped High-res pass.")
return

if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled):
if has_high_res_fix and (not is_hr_pass) and hr_option == HiResFixOption.HIGH_RES_ONLY:
logger.info(f"ControlNet Skipped Low-res pass.")
return
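
As a reading aid, here is a minimal sketch of the gating that the two early returns above implement when the hi-res fix is enabled; the enum values are assumed from the names used in this diff and may not match the real `external_code.HiResFixOption` exactly.

```python
from enum import Enum

class HiResFixOption(Enum):
    # Assumed minimal version of the option referenced in the diff.
    BOTH = "Both"
    LOW_RES_ONLY = "Low res only"
    HIGH_RES_ONLY = "High res only"

def should_skip(is_hr_pass: bool, hr_option: HiResFixOption) -> bool:
    # Skip the hi-res pass for units that only apply to the low-res pass,
    # and skip the low-res pass for units that only apply to the hi-res pass.
    if is_hr_pass and hr_option == HiResFixOption.LOW_RES_ONLY:
        return True
    if (not is_hr_pass) and hr_option == HiResFixOption.HIGH_RES_ONLY:
        return True
    return False

assert should_skip(True, HiResFixOption.LOW_RES_ONLY)
assert not should_skip(False, HiResFixOption.BOTH)
```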

if is_hr_pass:
if is_hr_pass and hr_option != HiResFixOption.LOW_RES_ONLY:
cond = params.control_cond_for_hr_fix
mask = params.control_mask_for_hr_fix
else:
cond = params.control_cond
mask = params.control_mask

if isinstance(cond, torch.Tensor) and len(cond.shape) == 4: # Tensor Batch [B, C, H, W]
total_images = cond.shape[0]
generation_index = current_iteration // batch_size
if generation_index < total_images:
cond = cond[generation_index:generation_index+1]
if mask is not None:
mask = mask[generation_index:generation_index+1]
else:
logger.warning(f"Generation index {generation_index} exceeds available images {total_images}")
cond = cond[-1:]
if mask is not None:
mask = mask[-1:]
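
For clarity, a self-contained sketch of the per-iteration slicing introduced here: one [1, C, H, W] control image is picked out of the [B, C, H, W] batch for the current generation, falling back to the last image when the index runs past the available ones. The function name and the example tensor are illustrative, not part of the extension.

```python
import torch

def select_control_for_iteration(cond: torch.Tensor,
                                 current_iteration: int,
                                 batch_size: int) -> torch.Tensor:
    # Map the current iteration to an image index in the batched cond tensor.
    total_images = cond.shape[0]
    generation_index = current_iteration // batch_size
    if generation_index >= total_images:
        generation_index = total_images - 1
    return cond[generation_index:generation_index + 1]

cond = torch.zeros(3, 3, 64, 64)  # three control images from a batch folder
print(select_control_for_iteration(cond, current_iteration=2, batch_size=1).shape)
# torch.Size([1, 3, 64, 64]) -> the third image on the third iteration
```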

kwargs.update(dict(
unit=unit,
params=params,
@@ -506,7 +522,7 @@ def process_unit_before_every_sampling(self,

params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)

logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
logger.info(f"ControlNet Method {params.preprocessor.name} patched for iteration {current_iteration} with image index {generation_index if 'generation_index' in locals() else 'N/A'}")
return

@staticmethod
@@ -548,11 +564,33 @@ def process(self, p, *args, **kwargs):
self.current_params = {}
enabled_units = self.get_enabled_units(args)
Infotext.write_infotext(enabled_units, p)

# Find the maximum number of batch images across all enabled units
max_batch_count = 1
for unit in enabled_units:
if unit.input_mode == external_code.InputMode.BATCH:
batch_image_files = shared.listfiles(unit.batch_image_dir)
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
batch_count = len([f for f in batch_image_files if any(f.lower().endswith(ext) for ext in image_extensions)])
max_batch_count = max(max_batch_count, batch_count)
elif unit.input_mode == external_code.InputMode.MERGE:
batch_count = len(unit.batch_input_gallery)
max_batch_count = max(max_batch_count, batch_count)

# Set the number of iterations for the process
initial_batch_count = getattr(p, 'n_iter', 1)
p.n_iter = initial_batch_count * max_batch_count

# Save the original batch index for later recovery
original_batch_index = getattr(p, 'batch_index', 0)
p.batch_index = original_batch_index % max_batch_count
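
To make the scheduling change concrete, here is a hypothetical sketch of the counting and n_iter scaling this block performs. `count_batch_images`, the folder path, and the use of `os.listdir` in place of `shared.listfiles` are placeholders, not the extension's actual helpers.

```python
import os

def count_batch_images(batch_image_dir: str) -> int:
    # Count the control images a unit's batch folder contributes.
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    try:
        files = os.listdir(batch_image_dir)
    except OSError:
        return 0
    return len([f for f in files if f.lower().endswith(image_extensions)])

# Hypothetical example: one unit pointing at a folder of 4 control images,
# with the user having asked for 2 batch runs (p.n_iter == 2).
max_batch_count = max(1, count_batch_images("/path/to/control_images"))
initial_batch_count = 2
n_iter = initial_batch_count * max_batch_count
print(n_iter)  # 8 total generations, one per control image per run
```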

for i, unit in enumerate(enabled_units):
self.bound_check_params(unit)
params = ControlNetCachedParameters()
self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
self.current_params[i] = params

return

@torch.no_grad()