Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle invalid images #2312

Merged
merged 6 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions lmdeploy/vl/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lmdeploy.model import BaseModel
from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import IMAGE_TOKEN
from lmdeploy.vl.utils import encode_image_base64, load_image
from lmdeploy.vl.utils import load_image

logger = get_logger('lmdeploy')

Expand Down Expand Up @@ -43,15 +43,11 @@ def prompt_to_messages(self, prompt: VLPromptType):
# 'image_url': means url or local path to image.
# 'image_data': means PIL.Image.Image object.
if isinstance(image, str):
image_base64_data = encode_image_base64(image)
if image_base64_data == '':
logger.error(f'failed to load file {image}')
continue
image = load_image(image)
item = {
'type': 'image_url',
'image_url': {
'url':
f'data:image/jpeg;base64,{image_base64_data}'
'type': 'image_data',
'image_data': {
'data': image
}
}
elif isinstance(image, PIL.Image.Image):
Expand Down
88 changes: 52 additions & 36 deletions lmdeploy/vl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,44 @@
from typing import Union

import requests
from PIL import Image
from PIL import Image, ImageFile

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def encode_image_base64(image: Union[str, Image.Image]) -> str:
"""encode raw date to base64 format."""
res = ''
if isinstance(image, str):
url_or_path = image
if url_or_path.startswith('http'):
FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10))
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
'(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
}
try:
buffered = BytesIO()
FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10))
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
'(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
}
try:
if isinstance(image, str):
url_or_path = image
if url_or_path.startswith('http'):
response = requests.get(url_or_path,
headers=headers,
timeout=FETCH_TIMEOUT)
response.raise_for_status()
res = base64.b64encode(response.content).decode('utf-8')
except Exception:
pass
elif os.path.exists(url_or_path):
with open(url_or_path, 'rb') as image_file:
res = base64.b64encode(image_file.read()).decode('utf-8')
elif isinstance(image, Image.Image):
buffered = BytesIO()
buffered.write(response.content)
elif os.path.exists(url_or_path):
with open(url_or_path, 'rb') as image_file:
buffered.write(image_file.read())
elif isinstance(image, Image.Image):
image.save(buffered, format='PNG')
except Exception as error:
if isinstance(image, str) and len(image) > 100:
image = image[:100] + ' ...'
logger.error(f'{error}, image={image}')
# use dummy image
image = Image.new('RGB', (32, 32))
image.save(buffered, format='PNG')
res = base64.b64encode(buffered.getvalue()).decode('utf-8')
res = base64.b64encode(buffered.getvalue()).decode('utf-8')
return res


Expand All @@ -45,27 +53,35 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:

def load_image(image_url: Union[str, Image.Image]) -> Image.Image:
"""load image from url, local path or openai GPT4V."""
if isinstance(image_url, Image.Image):
return image_url

FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10))
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
'(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
}
if image_url.startswith('http'):
response = requests.get(image_url,
headers=headers,
timeout=FETCH_TIMEOUT)
response.raise_for_status()
try:
ImageFile.LOAD_TRUNCATED_IMAGES = True
if isinstance(image_url, Image.Image):
img = image_url
elif image_url.startswith('http'):
response = requests.get(image_url,
headers=headers,
timeout=FETCH_TIMEOUT)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
elif image_url.startswith('data:image'):
img = load_image_from_base64(image_url.split(',')[1])
else:
# Load image from local path
img = Image.open(image_url)

# Open the image using PIL
img = Image.open(BytesIO(response.content))
elif image_url.startswith('data:image'):
img = load_image_from_base64(image_url.split(',')[1])
else:
# Load image from local path
img = Image.open(image_url)
# check image valid
img = img.convert('RGB')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to validate the image?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, all models convert the image to RGB anyway, so performing the check here does not add any extra time cost.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to validate the image?

I'm not sure whether the load() function would satisfy this need. Besides, if convert('RGB') is a must for all the models, we could do it here and remove those calls from the models' forward functions.

except Exception as error:
if isinstance(image_url, str) and len(image_url) > 100:
image_url = image_url[:100] + ' ...'
logger.error(f'{error}, image_url={image_url}')
# use dummy image
img = Image.new('RGB', (32, 32))

return img
2 changes: 1 addition & 1 deletion tests/test_lmdeploy/test_vl_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def test_encode_image_base64():
im1 = load_image(url)
base64 = encode_image_base64(url)
im2 = load_image_from_base64(base64)
assert im1 == im2
assert im1 == im2.convert('RGB')
Loading