# Post-patch implementation reconstructed from commit 2f6c705
# ("fix images validation in _parse_image method") against
# src/api/models/bedrock.py.
# NOTE(review): both functions take ``self`` and appear to be methods of a
# class whose header is not visible in the patch; re-indent them when
# splicing them back into that class.


def _get_supported_image_types(self) -> list[str]:
    """Return the image formats (MIME subtypes) accepted by ``_parse_image``."""
    return ["png", "jpeg", "gif", "webp"]


def _parse_image(self, image_url: str) -> tuple[bytes, str]:
    """Resolve an image reference to its raw bytes and content type.

    Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html

    Args:
        image_url: Either a base64 data URI
            (``data:image/jpeg;base64,<payload>``) or an ``http://`` /
            ``https://`` URL that serves an image.

    Returns:
        A tuple ``(image_bytes, content_type)`` where ``content_type`` is the
        lower-cased media type without parameters, e.g. ``"image/png"``.

    Raises:
        ValueError: If ``image_url`` is neither a base64 data URI nor an
            http(s) URL, if the data URI is malformed or its payload is not
            valid base64 (``binascii.Error`` is a ``ValueError`` subclass),
            or if the content type is not one of the supported image formats.
        requests.RequestException: If the image URL cannot be fetched or the
            server answers with an error status.
    """
    supported_image_formats = self._get_supported_image_types()

    def checked_content_type(raw_content_type: str) -> str:
        # Drop parameters such as "; charset=binary" and normalise case so
        # that "image/PNG" and "image/png; q=1" both validate as "image/png".
        media_type = raw_content_type.split(";", 1)[0].strip().lower()
        main_type, _, image_format = media_type.partition("/")
        if main_type != "image" or image_format not in supported_image_formats:
            raise ValueError(
                f"Unsupported image format: {raw_content_type}. "
                f"Supported formats: {supported_image_formats}"
            )
        return media_type

    # Dispatch on the scheme prefix rather than on a substring test: an
    # ordinary URL may legitimately contain the text "base64" in its path.
    if image_url.startswith("data:"):
        # Data URIs look like "data:image/jpeg;base64,<payload>".
        # partition() never raises, unlike tuple-unpacking the result of split().
        image_metadata, separator, image_data = image_url.partition(",")
        content_type_match = re.match(r"data:(.*?);base64", image_metadata)
        if not separator or content_type_match is None:
            raise ValueError(
                "Unsupported image type. Expected either image url or "
                "base64 encoded string - e.g. 'data:image/jpeg;base64,'"
            )
        content_type = checked_content_type(content_type_match.group(1))
        return base64.b64decode(image_data), content_type

    if image_url.startswith(("http://", "https://")):
        # NOTE(security): if image_url comes from an untrusted request, this
        # server-side fetch can be pointed at internal endpoints (SSRF).
        # Consider an allow-list or network-level egress controls.
        # A timeout keeps an unresponsive host from blocking the caller forever.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()

        content_type = response.headers.get("Content-Type")
        if not content_type or "image" not in content_type:
            raise ValueError(
                "URL does not point to a valid image content_type, "
                "such as: (content-type: 'image/jpeg')"
            )
        # Apply the same format validation as the data-URI branch.
        return response.content, checked_content_type(content_type)

    raise ValueError(
        "Unsupported image type. Expected either image url or "
        "base64 encoded string - e.g. 'data:image/jpeg;base64,'"
    )