diff --git a/scripts/depth-image-io.py b/scripts/depth-image-io.py
index a81370a..becd16e 100644
--- a/scripts/depth-image-io.py
+++ b/scripts/depth-image-io.py
@@ -88,6 +88,12 @@ def concat_processed_outputs(p1, p2):
     p1.images += p2.images
     return p1
 
+def sanitize_pil_image_mode(img):
+    # Set of PIL image modes for which, in a greyscale image, the first channel would NOT be a good representation of brightness:
+    invalid_modes = {'P', 'CMYK', 'HSV'}
+    if img.mode in invalid_modes:
+        img = img.convert(mode='RGB')
+    return img
 
 class Script(scripts.Script):
     def title(self):
@@ -98,7 +104,7 @@ def ui(self, is_img2img):
         with gr.Accordion("Notes and Hints (click to expand)", open=False):
             gr.Markdown(instructions)
         gr.Markdown("---\n\nPut depth image here ⤵")
-        input_depth_img = gr.Image(source='upload', type="pil")
+        input_depth_img = gr.Image(source='upload', type="pil", image_mode=None) # If we don't say image_mode is None, Gradio will auto-convert to RGB and potentially destroy data.
         with gr.Accordion("Batch Processing (Experimental)", open=False):
             batch_many_to_many = gr.Checkbox(False, label="Batch each depth image against every single color image. (Warning: Use cautiously with large batches!)")
             batch_img_input = gr.File(file_types=['image'], file_count='multiple', label="Input Color Images")
@@ -118,12 +124,12 @@ def run_inner(self, p, input_depth_img, show_depth):
         def alt_depth_image_conditioning(self, source_image):
             conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
             if use_custom_depth_input:
-                depth_data = rearrange(np.array(input_depth_img.convert("RGB")), "h w c -> c 1 1 h w")[0] # Rearrange and discard anything past the first channel.
-                # Curious minds may want to know why we convert to 'RGB'; this is because it's
-                # theoretically possible the image was for instance palletized. Converting to 'L' mode
-                # is also possible but the image is LIKELY to in fact be a black-and-white RGB image
-                # (because almost no one uses single-channel images) so this is likely to be more efficient.
-                depth_data = torch.from_numpy(depth_data).to(device=shared.device).to(dtype=torch.float32) # That the range is 1.0-255.0 doesn't matter; we're going to normalize it anyway.
+                depth_data = np.array(sanitize_pil_image_mode(input_depth_img))
+                if len(np.shape(depth_data)) == 2: # that is, if it's a single-channel image with only width and height.
+                    depth_data = rearrange(depth_data, "h w -> 1 1 h w")
+                else:
+                    depth_data = rearrange(depth_data, "h w c -> c 1 1 h w")[0] # Rearrange and discard anything past the first channel.
+                depth_data = torch.from_numpy(depth_data).to(device=shared.device).to(dtype=torch.float32) # Whatever the color range was (e.g. 0 to 255 for 8-bit) doesn't matter; we're going to normalize it anyway.
                 depth_data = repeat(depth_data, "1 ... -> n ...", n=self.batch_size)
             else:
                 transformer = AddMiDaS(model_type="dpt_hybrid")
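
For reference, below is a minimal standalone sketch of the depth-loading path this patch introduces, outside the webui's plumbing. It assumes only PIL, numpy, einops, and torch; the function name load_depth_as_tensor and the path/batch_size/device parameters are hypothetical stand-ins for the script's own arguments, not part of the patch.

# Illustrative sketch only, not part of the patch.
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image

def sanitize_pil_image_mode(img):
    # Modes whose first channel is not a usable brightness proxy (palettized, CMYK, HSV).
    if img.mode in {'P', 'CMYK', 'HSV'}:
        img = img.convert(mode='RGB')
    return img

def load_depth_as_tensor(path, batch_size=1, device='cpu'):
    depth_data = np.array(sanitize_pil_image_mode(Image.open(path)))
    if depth_data.ndim == 2:
        # Single-channel image: only height and width.
        depth_data = rearrange(depth_data, "h w -> 1 1 h w")
    else:
        # Multi-channel image: keep only the first channel.
        depth_data = rearrange(depth_data, "h w c -> c 1 1 h w")[0]
    # Absolute range (e.g. 0-255 for 8-bit) doesn't matter if downstream code normalizes it.
    depth_data = torch.from_numpy(depth_data).to(device=device, dtype=torch.float32)
    # Duplicate the single depth map across the batch dimension.
    return repeat(depth_data, "1 ... -> n ...", n=batch_size)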