From 965416b4c554f2e81fc0b8edd3d75a564d26bc92 Mon Sep 17 00:00:00 2001 From: Nziner Date: Wed, 16 Oct 2024 20:04:35 +0600 Subject: [PATCH] Batch folder fix --- .../sd_forge_controlnet/scripts/controlnet.py | 86 +++++++++++++------ 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 353601b51..e9647981e 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -340,16 +340,11 @@ def optional_tqdm(iterable, use_tqdm): if input_mask is not None: control_masks.append(input_mask) - if len(input_list) > 1 and not preprocessor_output_is_image: - logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!') - break - if has_high_res_fix: hr_option = HiResFixOption.from_value(unit.hr_option) else: hr_option = HiResFixOption.BOTH - alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False): if ( (is_high_res and hr_option.high_res_enabled) or @@ -364,16 +359,20 @@ def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False): for preprocessor_output in preprocessor_outputs: control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w) attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond)) - params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1)) - - params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous() + control_cond = numpy_to_pytorch(control_cond).movedim(-1, 1) + params.control_cond.append(control_cond) - if has_high_res_fix: - for preprocessor_output in preprocessor_outputs: + if has_high_res_fix and hr_option != HiResFixOption.LOW_RES_ONLY: control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x) attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True) - params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1)) - params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous() + control_cond_for_hr_fix = numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1) + params.control_cond_for_hr_fix.append(control_cond_for_hr_fix) + elif has_high_res_fix: + params.control_cond_for_hr_fix.append(control_cond) + + params.control_cond = torch.cat(params.control_cond, dim=0) + if has_high_res_fix: + params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0) else: params.control_cond_for_hr_fix = params.control_cond else: @@ -392,15 +391,17 @@ def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False): control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] params.control_mask.append(control_mask) - if has_high_res_fix: + if has_high_res_fix and hr_option != HiResFixOption.LOW_RES_ONLY: control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True) control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1] params.control_mask_for_hr_fix.append(control_mask_for_hr_fix) + elif has_high_res_fix: + params.control_mask_for_hr_fix.append(control_mask) - params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous() + params.control_mask = torch.cat(params.control_mask, dim=0) if has_high_res_fix: - params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous() + params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0) else: params.control_mask_for_hr_fix = params.control_mask @@ -423,13 +424,15 @@ def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False): @torch.no_grad() def process_unit_before_every_sampling(self, - p: StableDiffusionProcessing, - unit: ControlNetUnit, - params: ControlNetCachedParameters, - *args, **kwargs): + p: StableDiffusionProcessing, + unit: ControlNetUnit, + params: ControlNetCachedParameters, + *args, **kwargs): is_hr_pass = getattr(p, 'is_hr_pass', False) - + current_iteration = getattr(p, 'iteration', 0) + batch_size = getattr(p, 'batch_size', 1) + has_high_res_fix = ( isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) @@ -440,21 +443,34 @@ def process_unit_before_every_sampling(self, else: hr_option = HiResFixOption.BOTH - if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled): + if has_high_res_fix and is_hr_pass and hr_option == HiResFixOption.LOW_RES_ONLY: logger.info(f"ControlNet Skipped High-res pass.") return - if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled): + if has_high_res_fix and (not is_hr_pass) and hr_option == HiResFixOption.HIGH_RES_ONLY: logger.info(f"ControlNet Skipped Low-res pass.") return - if is_hr_pass: + if is_hr_pass and hr_option != HiResFixOption.LOW_RES_ONLY: cond = params.control_cond_for_hr_fix mask = params.control_mask_for_hr_fix else: cond = params.control_cond mask = params.control_mask + if isinstance(cond, torch.Tensor) and len(cond.shape) == 4: # Tensor Batch [B, C, H, W] + total_images = cond.shape[0] + generation_index = current_iteration // batch_size + if generation_index < total_images: + cond = cond[generation_index:generation_index+1] + if mask is not None: + mask = mask[generation_index:generation_index+1] + else: + logger.warning(f"Generation index {generation_index} exceeds available images {total_images}") + cond = cond[-1:] + if mask is not None: + mask = mask[-1:] + kwargs.update(dict( unit=unit, params=params, @@ -506,7 +522,7 @@ def process_unit_before_every_sampling(self, params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs) - logger.info(f"ControlNet Method {params.preprocessor.name} patched.") + logger.info(f"ControlNet Method {params.preprocessor.name} patched for iteration {current_iteration} with image index {generation_index if 'generation_index' in locals() else 'N/A'}") return @staticmethod @@ -548,11 +564,33 @@ def process(self, p, *args, **kwargs): self.current_params = {} enabled_units = self.get_enabled_units(args) Infotext.write_infotext(enabled_units, p) + + # Find the maximum number of images in the batches among all units + max_batch_count = 1 + for unit in enabled_units: + if unit.input_mode == external_code.InputMode.BATCH: + batch_image_files = shared.listfiles(unit.batch_image_dir) + image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] + batch_count = len([f for f in batch_image_files if any(f.lower().endswith(ext) for ext in image_extensions)]) + max_batch_count = max(max_batch_count, batch_count) + elif unit.input_mode == external_code.InputMode.MERGE: + batch_count = len(unit.batch_input_gallery) + max_batch_count = max(max_batch_count, batch_count) + + # Set the number of iterations for the process + initial_batch_count = getattr(p, 'n_iter', 1) + p.n_iter = initial_batch_count * max_batch_count + + # Save the original batches index for later recovery + original_batch_index = getattr(p, 'batch_index', 0) + p.batch_index = original_batch_index % max_batch_count + for i, unit in enumerate(enabled_units): self.bound_check_params(unit) params = ControlNetCachedParameters() self.process_unit_after_click_generate(p, unit, params, *args, **kwargs) self.current_params[i] = params + return @torch.no_grad()