Fixed (at least some of the) problems with 16-bit depth image input formats.
AnonymousCervine authored and AnonymousCervine committed Feb 4, 2023
1 parent 5d49f71 commit eb22863
Showing 1 changed file with 13 additions and 7 deletions: scripts/depth-image-io.py
@@ -88,6 +88,12 @@ def concat_processed_outputs(p1, p2):
     p1.images += p2.images
     return p1
 
+def sanitize_pil_image_mode(img):
+    # Set of PIL image modes for which, in a greyscale image, the first channel would NOT be a good representation of brightness:
+    invalid_modes = {'P', 'CMYK', 'HSV'}
+    if img.mode in invalid_modes:
+        img = img.convert(mode='RGB')
+    return img
 
 class Script(scripts.Script):
     def title(self):
@@ -98,7 +104,7 @@ def ui(self, is_img2img):
         with gr.Accordion("Notes and Hints (click to expand)", open=False):
             gr.Markdown(instructions)
         gr.Markdown("---\n\nPut depth image here ⤵")
-        input_depth_img = gr.Image(source='upload', type="pil")
+        input_depth_img = gr.Image(source='upload', type="pil", image_mode=None) # If we don't say image_mode is None, Gradio will auto-convert to RGB and potentially destroy data.
         with gr.Accordion("Batch Processing (Experimental)", open=False):
             batch_many_to_many = gr.Checkbox(False, label="Batch each depth image against every single color image. (Warning: Use cautiously with large batches!)")
             batch_img_input = gr.File(file_types=['image'], file_count='multiple', label="Input Color Images")
@@ -118,12 +124,12 @@ def run_inner(self, p, input_depth_img, show_depth):
         def alt_depth_image_conditioning(self, source_image):
             conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
             if use_custom_depth_input:
-                depth_data = rearrange(np.array(input_depth_img.convert("RGB")), "h w c -> c 1 1 h w")[0] # Rearrange and discard anything past the first channel.
-                # Curious minds may want to know why we convert to 'RGB'; this is because it's
-                # theoretically possible the image was, for instance, palettized. Converting to 'L' mode
-                # is also possible, but the image is in fact LIKELY to be a black-and-white RGB image
-                # (because almost no one uses single-channel images), so this is likely to be more efficient.
-                depth_data = torch.from_numpy(depth_data).to(device=shared.device).to(dtype=torch.float32) # That the range is 1.0-255.0 doesn't matter; we're going to normalize it anyway.
+                depth_data = np.array(sanitize_pil_image_mode(input_depth_img))
+                if len(np.shape(depth_data)) == 2: # That is, if it's a single-channel image with only width and height.
+                    depth_data = rearrange(depth_data, "h w -> 1 1 h w")
+                else:
+                    depth_data = rearrange(depth_data, "h w c -> c 1 1 h w")[0] # Rearrange and discard anything past the first channel.
+                depth_data = torch.from_numpy(depth_data).to(device=shared.device).to(dtype=torch.float32) # Whatever the color range was (e.g. 0 to 255 for 8-bit) doesn't matter; we're going to normalize it anyway.
                 depth_data = repeat(depth_data, "1 ... -> n ...", n=self.batch_size)
             else:
                 transformer = AddMiDaS(model_type="dpt_hybrid")
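For context on what this fixes (an editor's sketch, not part of the commit): a 16-bit greyscale PNG opens in PIL as a single-channel image (a mode such as 'I;16' or 'I'), whose numpy array is 2-D with no channel axis, and the old code's convert("RGB") would crush its 16-bit values down to 8 bits per channel (Gradio's default loading could do the same even earlier). The new 2-D branch keeps the full-precision single-channel array. The in-memory images below are hypothetical stand-ins for real depth files; numpy, Pillow, and einops are assumed installed.

import numpy as np
from PIL import Image
from einops import rearrange

# A 16-bit greyscale gradient built in memory (stand-in for a 16-bit depth PNG).
gradient = np.linspace(0, 65535, 64 * 64, dtype=np.uint16).reshape(64, 64)
img16 = Image.fromarray(gradient)            # Pillow picks a 16-bit mode, e.g. 'I;16'
print(img16.mode, np.array(img16).shape)     # e.g. I;16 (64, 64) -- 2-D, no channel axis

# An 8-bit RGB image, by contrast, yields a 3-D array; only channel 0 is wanted.
img_rgb = Image.new('RGB', (64, 64))
print(img_rgb.mode, np.array(img_rgb).shape) # RGB (64, 64, 3)

# The patched branching handles both layouts and lands at (1, 1, h, w) either way.
for img in (img16, img_rgb):
    depth_data = np.array(img)
    if depth_data.ndim == 2:                 # single-channel: no channel axis to discard
        depth_data = rearrange(depth_data, "h w -> 1 1 h w")
    else:                                    # multi-channel: keep only the first channel
        depth_data = rearrange(depth_data, "h w c -> c 1 1 h w")[0]
    print(depth_data.shape)                  # (1, 1, 64, 64)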
