Skip to content

Commit

Permalink
add images.read to automatically fix all jpeg/png weirdness
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Mar 4, 2024
1 parent 5625ce1 commit 09b5ce6
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 71 deletions.
6 changes: 2 additions & 4 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,15 @@ def decode_base64_to_image(encoding):
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
image = images.apply_exif_orientation(image)
image = images.read(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e

if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
image = images.apply_exif_orientation(image)
image = images.read(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
Expand Down
64 changes: 18 additions & 46 deletions modules/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import piexif
import piexif.helper
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
import string
import json
import hashlib
Expand Down Expand Up @@ -551,12 +551,6 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
else:
pnginfo_data = None

# Error handling for unsupported transparency in RGB mode
if (image.mode == "RGB" and
"transparency" in image.info and
isinstance(image.info["transparency"], bytes)):
del image.info["transparency"]

image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

elif extension.lower() in (".jpg", ".jpeg", ".webp"):
Expand Down Expand Up @@ -779,7 +773,7 @@ def image_data(data):
import gradio as gr

try:
image = Image.open(io.BytesIO(data))
image = read(io.BytesIO(data))
textinfo, _ = read_info_from_image(image)
return textinfo, None
except Exception:
Expand Down Expand Up @@ -807,51 +801,29 @@ def flatten(img, bgcolor):
return img.convert('RGB')


# https://www.exiv2.org/tags.html
_EXIF_ORIENT = 274 # exif 'Orientation' tag

def apply_exif_orientation(image):
"""
Applies the exif orientation correctly.
This code exists per the bug:
https://github.com/python-pillow/Pillow/issues/3973
with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
various methods, especially `tobytes`
def read(fp, **kwargs):
image = Image.open(fp, **kwargs)
image = fix_image(image)

Function based on:
https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
return image

Args:
image (PIL.Image): a PIL image

Returns:
(PIL.Image): the PIL image with exif orientation applied, if applicable
"""
if not hasattr(image, "getexif"):
return image
def fix_image(image: Image.Image):
if image is None:
return None

try:
exif = image.getexif()
except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
exif = None
image = ImageOps.exif_transpose(image)
image = fix_png_transparency(image)
except Exception:
pass

if exif is None:
return image
return image

orientation = exif.get(_EXIF_ORIENT)

method = {
2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}.get(orientation)
def fix_png_transparency(image: Image.Image):
if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
return image

if method is not None:
return image.transpose(method)
image = image.convert("RGBA")
return image
25 changes: 12 additions & 13 deletions modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr

from modules import images as imgutil
from modules import images
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
Expand All @@ -21,7 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
output_dir = output_dir.strip()
processing.fix_seed(p)

images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
batch_images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))

is_inpaint_batch = False
if inpaint_mask_dir:
Expand All @@ -31,9 +31,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")

state.job_count = len(images) * p.n_iter
state.job_count = len(batch_images) * p.n_iter

# extract "default" params to use in case getting png info fails
prompt = p.prompt
Expand All @@ -46,16 +46,16 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
batch_results = None
discard_further_results = False
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
for i, image in enumerate(batch_images):
state.job = f"{i+1} out of {len(batch_images)}"
if state.skipped:
state.skipped = False

if state.interrupted or state.stopping_generation:
break

try:
img = Image.open(image)
img = images.read(image)
except UnidentifiedImageError as e:
print(e)
continue
Expand Down Expand Up @@ -86,16 +86,16 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
# otherwise user has many masks with the same name but different extensions
mask_image_path = masks_found[0]

mask_image = Image.open(mask_image_path)
mask_image = images.read(mask_image_path)
p.image_mask = mask_image

if use_png_info:
try:
info_img = img
if png_info_dir:
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
info_img = Image.open(info_img_path)
geninfo, _ = imgutil.read_info_from_image(info_img)
info_img = images.read(info_img_path)
geninfo, _ = images.read_info_from_image(info_img)
parsed_parameters = parse_generation_parameters(geninfo)
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
except Exception:
Expand Down Expand Up @@ -175,9 +175,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
image = None
mask = None

# Use the EXIF orientation of photos taken by smartphones.
if image is not None:
image = ImageOps.exif_transpose(image)
image = images.fix_image(image)
mask = images.fix_image(mask)

if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected"
Expand Down
6 changes: 3 additions & 3 deletions modules/infotext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images
from PIL import Image

sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
Expand Down Expand Up @@ -83,7 +83,7 @@ def image_from_url_text(filedata):
assert is_in_right_dir, 'trying to open image file outside of allowed directories'

filename = filename.rsplit('?', 1)[0]
return Image.open(filename)
return images.read(filename)

if type(filedata) == list:
if len(filedata) == 0:
Expand All @@ -95,7 +95,7 @@ def image_from_url_text(filedata):
filedata = filedata[len("data:image/png;base64,"):]

filedata = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filedata))
image = images.read(io.BytesIO(filedata))
return image


Expand Down
6 changes: 3 additions & 3 deletions modules/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def get_images(extras_mode, image, image_folder, input_dir):
if extras_mode == 1:
for img in image_folder:
if isinstance(img, Image.Image):
image = img
image = images.fix_image(img)
fn = ''
else:
image = Image.open(os.path.abspath(img.name))
image = images.read(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0]
yield image, fn
elif extras_mode == 2:
Expand Down Expand Up @@ -56,7 +56,7 @@ def get_images(extras_mode, image, image_folder, input_dir):

if isinstance(image_placeholder, str):
try:
image_data = Image.open(image_placeholder)
image_data = images.read(image_placeholder)
except Exception:
continue
else:
Expand Down
4 changes: 2 additions & 2 deletions modules/textual_inversion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import random
import tqdm
from modules import devices, shared
from modules import devices, shared, images
import re

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_to
if shared.state.interrupted:
raise Exception("interrupted")
try:
image = Image.open(path)
image = images.read(path)
#Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that
if use_weight and 'A' in image.getbands():
Expand Down

0 comments on commit 09b5ce6

Please sign in to comment.