From d5cddbda5cc15dfae2c8af75f594923c9822c92b Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Sun, 12 May 2024 13:04:24 +0200 Subject: [PATCH] Protect against huge image sizes --- .../nodes/properties/outputs/numpy_outputs.py | 5 ++++- backend/src/nodes/utils/utils.py | 21 +++++++++++++++++++ .../image_dimension/resize/resize.py | 6 ++++-- .../resize/resize_pixel_art.py | 5 +++++ .../image_dimension/resize/resize_to_side.py | 6 ++++-- src/common/types/chainner-scope.ts | 12 +++++++++++ 6 files changed, 50 insertions(+), 5 deletions(-) diff --git a/backend/src/nodes/properties/outputs/numpy_outputs.py b/backend/src/nodes/properties/outputs/numpy_outputs.py index 719ef9caab..b3acf37032 100644 --- a/backend/src/nodes/properties/outputs/numpy_outputs.py +++ b/backend/src/nodes/properties/outputs/numpy_outputs.py @@ -11,7 +11,7 @@ from ...impl.image_utils import normalize, to_uint8 from ...impl.resize import ResizeFilter, resize from ...utils.format import format_image_with_channels -from ...utils.utils import get_h_w_c, round_half_up +from ...utils.utils import IMAGE_SIZE_LIMIT, get_h_w_c, round_half_up class NumPyOutput(BaseOutput[np.ndarray]): @@ -60,6 +60,9 @@ def __init__( ) if shape_as is not None: image_type = navi.intersect_with_error(image_type, f"Input{shape_as}") + image_type = navi.fn( + "assert_image_size", image_type, navi.literal(IMAGE_SIZE_LIMIT) + ) super().__init__(image_type, label, kind=kind, has_handle=has_handle) diff --git a/backend/src/nodes/utils/utils.py b/backend/src/nodes/utils/utils.py index 7c85dc19c3..bd460187e0 100644 --- a/backend/src/nodes/utils/utils.py +++ b/backend/src/nodes/utils/utils.py @@ -9,6 +9,7 @@ from typing import Tuple import numpy as np +import psutil from sanic.log import logger Size = Tuple[int, int] @@ -40,6 +41,26 @@ def get_h_w_c(image: np.ndarray) -> tuple[int, int, int]: return h, w, c +IMAGE_SIZE_LIMIT = int(psutil.virtual_memory().total * 0.5) +""" +The maximum size of an image in bytes that can be processed by the backend. +""" + + +def assert_image_dimensions(shape: tuple[int, int] | tuple[int, int, int]): + h, w = shape[:2] + c = 1 if len(shape) == 2 else shape[2] + + size_in_bytes = h * w * c * 4 + + if size_in_bytes > IMAGE_SIZE_LIMIT: + size_format = f"{round(size_in_bytes / 1024 / 1024 / 1024, 1)} GB" + raise AssertionError( + f"Your machine does not have enough RAM for a {w}x{h}x{c} image ({size_format}). " + f"Please reduce the size of the image." + ) + + def alphanumeric_sort(value: str) -> list[str | int]: """Key function to sort strings containing numbers by proper numerical order.""" diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py index a2b7428d58..a4554974d5 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py @@ -15,7 +15,7 @@ ResizeFilterInput, ) from nodes.properties.outputs import ImageOutput -from nodes.utils.utils import get_h_w_c, round_half_up +from nodes.utils.utils import assert_image_dimensions, get_h_w_c, round_half_up from .. import resize_group @@ -114,7 +114,7 @@ def resize_node( filter: ResizeFilter, separate_alpha: bool, ) -> np.ndarray: - h, w, _ = get_h_w_c(img) + h, w, c = get_h_w_c(img) out_dims: tuple[int, int] if mode == ImageResizeMode.PERCENTAGE: @@ -125,6 +125,8 @@ def resize_node( else: out_dims = (width, height) + assert_image_dimensions((out_dims[1], out_dims[0], c)) + return resize( img, out_dims, diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_pixel_art.py b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_pixel_art.py index 301789a64e..fd2609b79b 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_pixel_art.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_pixel_art.py @@ -10,6 +10,7 @@ ImageInput, ) from nodes.properties.outputs import ImageOutput +from nodes.utils.utils import assert_image_dimensions, get_h_w_c from .. import resize_group @@ -106,4 +107,8 @@ def resize_pixel_art_node( img: np.ndarray, algorithm: ResizeAlgorithm, ) -> np.ndarray: + h, w, c = get_h_w_c(img) + + assert_image_dimensions((h * algorithm.scale, w * algorithm.scale, c)) + return pixel_art_upscale(img, algorithm.algorithm, algorithm.scale) diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py index 3ba744acd1..6490fa20cc 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py @@ -12,7 +12,7 @@ ResizeFilterInput, ) from nodes.properties.outputs import ImageOutput -from nodes.utils.utils import get_h_w_c, round_half_up +from nodes.utils.utils import assert_image_dimensions, get_h_w_c, round_half_up from .. import resize_group @@ -173,7 +173,9 @@ def resize_to_side_node( condition: ResizeCondition, filter: ResizeFilter, ) -> np.ndarray: - h, w, _ = get_h_w_c(img) + h, w, c = get_h_w_c(img) out_dims = resize_to_side_conditional(w, h, target, side, condition) + assert_image_dimensions((out_dims[1], out_dims[0], c)) + return resize(img, out_dims, filter) diff --git a/src/common/types/chainner-scope.ts b/src/common/types/chainner-scope.ts index 39558e10c6..51c7ad3d91 100644 --- a/src/common/types/chainner-scope.ts +++ b/src/common/types/chainner-scope.ts @@ -47,6 +47,18 @@ struct Image { channels: int(1..), } struct Color { channels: int(1..) } +def assert_image_size(value: any, max_size: uint): any { + match value { + Image as image => { + if image.width * image.height * image.channels * 4 > max_size { + error("The output image of this operation is too large for your machine. Please reduce the size of the image.") + } else { + image + } + }, + _ => value, + } +} struct Video;