diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py index 0339d10a6..952a2f153 100644 --- a/lmdeploy/vl/templates.py +++ b/lmdeploy/vl/templates.py @@ -9,7 +9,7 @@ from lmdeploy.model import BaseModel from lmdeploy.utils import get_logger from lmdeploy.vl.constants import IMAGE_TOKEN -from lmdeploy.vl.utils import encode_image_base64, load_image +from lmdeploy.vl.utils import load_image logger = get_logger('lmdeploy') @@ -43,15 +43,11 @@ def prompt_to_messages(self, prompt: VLPromptType): # 'image_url': means url or local path to image. # 'image_data': means PIL.Image.Image object. if isinstance(image, str): - image_base64_data = encode_image_base64(image) - if image_base64_data == '': - logger.error(f'failed to load file {image}') - continue + image = load_image(image) item = { - 'type': 'image_url', - 'image_url': { - 'url': - f'data:image/jpeg;base64,{image_base64_data}' + 'type': 'image_data', + 'image_data': { + 'data': image } } elif isinstance(image, PIL.Image.Image): diff --git a/lmdeploy/vl/utils.py b/lmdeploy/vl/utils.py index a0a41e304..e42641523 100644 --- a/lmdeploy/vl/utils.py +++ b/lmdeploy/vl/utils.py @@ -5,36 +5,44 @@ from typing import Union import requests -from PIL import Image +from PIL import Image, ImageFile + +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') def encode_image_base64(image: Union[str, Image.Image]) -> str: """encode raw date to base64 format.""" - res = '' - if isinstance(image, str): - url_or_path = image - if url_or_path.startswith('http'): - FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10)) - headers = { - 'User-Agent': - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' - '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' - } - try: + buffered = BytesIO() + FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10)) + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + try: + if isinstance(image, str): + url_or_path = image + if url_or_path.startswith('http'): response = requests.get(url_or_path, headers=headers, timeout=FETCH_TIMEOUT) response.raise_for_status() - res = base64.b64encode(response.content).decode('utf-8') - except Exception: - pass - elif os.path.exists(url_or_path): - with open(url_or_path, 'rb') as image_file: - res = base64.b64encode(image_file.read()).decode('utf-8') - elif isinstance(image, Image.Image): - buffered = BytesIO() + buffered.write(response.content) + elif os.path.exists(url_or_path): + with open(url_or_path, 'rb') as image_file: + buffered.write(image_file.read()) + elif isinstance(image, Image.Image): + image.save(buffered, format='PNG') + except Exception as error: + if isinstance(image, str) and len(image) > 100: + image = image[:100] + ' ...' + logger.error(f'{error}, image={image}') + # use dummy image + image = Image.new('RGB', (32, 32)) image.save(buffered, format='PNG') - res = base64.b64encode(buffered.getvalue()).decode('utf-8') + res = base64.b64encode(buffered.getvalue()).decode('utf-8') return res @@ -45,27 +53,35 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: def load_image(image_url: Union[str, Image.Image]) -> Image.Image: """load image from url, local path or openai GPT4V.""" - if isinstance(image_url, Image.Image): - return image_url - FETCH_TIMEOUT = int(os.environ.get('LMDEPLOY_FETCH_TIMEOUT', 10)) headers = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' } - if image_url.startswith('http'): - response = requests.get(image_url, - headers=headers, - timeout=FETCH_TIMEOUT) - response.raise_for_status() + try: + ImageFile.LOAD_TRUNCATED_IMAGES = True + if isinstance(image_url, Image.Image): + img = image_url + elif image_url.startswith('http'): + response = requests.get(image_url, + headers=headers, + timeout=FETCH_TIMEOUT) + response.raise_for_status() + img = Image.open(BytesIO(response.content)) + elif image_url.startswith('data:image'): + img = load_image_from_base64(image_url.split(',')[1]) + else: + # Load image from local path + img = Image.open(image_url) - # Open the image using PIL - img = Image.open(BytesIO(response.content)) - elif image_url.startswith('data:image'): - img = load_image_from_base64(image_url.split(',')[1]) - else: - # Load image from local path - img = Image.open(image_url) + # check image valid + img = img.convert('RGB') + except Exception as error: + if isinstance(image_url, str) and len(image_url) > 100: + image_url = image_url[:100] + ' ...' + logger.error(f'{error}, image_url={image_url}') + # use dummy image + img = Image.new('RGB', (32, 32)) return img diff --git a/tests/test_lmdeploy/test_vl_encode.py b/tests/test_lmdeploy/test_vl_encode.py index 0a4a0a723..3d58994bb 100644 --- a/tests/test_lmdeploy/test_vl_encode.py +++ b/tests/test_lmdeploy/test_vl_encode.py @@ -7,4 +7,31 @@ def test_encode_image_base64(): im1 = load_image(url) base64 = encode_image_base64(url) im2 = load_image_from_base64(base64) - assert im1 == im2 + assert im1 == im2.convert('RGB') + + +def test_load_truncated_image(): + url = 'https://github.com/irexyc/lmdeploy/releases/download/v0.0.1/tr.jpeg' + im = load_image(url) + assert im.width == 1638 + assert im.height == 2048 + + +def test_load_invalid_url(): + url = ('https://raw.githubusercontent.com/open-mmlab/' + 'mmdeploy/main/tests/data/tiger.jpeg') + # invalid + im1 = load_image(url[:-1]) + assert im1.width == 32 + assert im1.height == 32 + # valid + im2 = load_image(url) + assert im2.height == 182 + assert im2.width == 278 + + +def test_load_invalid_base64(): + base64 = '' + im = load_image(base64) + assert im.width == 32 + assert im.height == 32