From 53621d0db877e07b37ea6eba7c585322fdcafdaa Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 11 Nov 2022 21:55:45 +0100 Subject: [PATCH 1/3] Move validators into separate directory --- connexion/validators.py | 321 ------------------------------ connexion/validators/__init__.py | 26 +++ connexion/validators/form_data.py | 147 ++++++++++++++ connexion/validators/json.py | 155 +++++++++++++++ 4 files changed, 328 insertions(+), 321 deletions(-) delete mode 100644 connexion/validators.py create mode 100644 connexion/validators/__init__.py create mode 100644 connexion/validators/form_data.py create mode 100644 connexion/validators/json.py diff --git a/connexion/validators.py b/connexion/validators.py deleted file mode 100644 index e4f6b179d..000000000 --- a/connexion/validators.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -Contains validator classes used by the validation middleware. -""" -import json -import logging -import typing as t - -from jsonschema import Draft4Validator, ValidationError, draft4_format_checker -from starlette.datastructures import FormData, Headers, UploadFile -from starlette.formparsers import FormParser, MultiPartParser -from starlette.types import Receive, Scope, Send - -from connexion.datastructures import MediaTypeDict -from connexion.decorators.uri_parsing import AbstractURIParser -from connexion.decorators.validation import ( - ParameterValidator, - TypeValidationError, - coerce_type, -) -from connexion.exceptions import ( - BadRequestProblem, - ExtraParameterProblem, - NonConformingResponseBody, -) -from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator -from connexion.utils import is_null - -logger = logging.getLogger("connexion.middleware.validators") - - -class JSONRequestBodyValidator: - """Request body validator for json content types.""" - - def __init__( - self, - scope: Scope, - receive: Receive, - *, - schema: dict, - validator: t.Type[Draft4Validator] = Draft4RequestValidator, - nullable=False, - encoding: str, - **kwargs, - ) -> None: - self._scope = scope - self._receive = receive - self.schema = schema - self.has_default = schema.get("default", False) - self.nullable = nullable - self.validator = validator(schema, format_checker=draft4_format_checker) - self.encoding = encoding - self._messages: t.List[t.MutableMapping[str, t.Any]] = [] - - @classmethod - def _error_path_message(cls, exception): - error_path = ".".join(str(item) for item in exception.path) - error_path_msg = f" - '{error_path}'" if error_path else "" - return error_path_msg - - def validate(self, body: dict): - try: - self.validator.validate(body) - except ValidationError as exception: - error_path_msg = self._error_path_message(exception=exception) - logger.error( - f"Validation error: {exception.message}{error_path_msg}", - extra={"validator": "body"}, - ) - raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") - - @staticmethod - def parse(body: str) -> dict: - try: - return json.loads(body) - except json.decoder.JSONDecodeError as e: - raise BadRequestProblem(str(e)) - - async def wrapped_receive(self) -> Receive: - more_body = True - while more_body: - message = await self._receive() - self._messages.append(message) - more_body = message.get("more_body", False) - - bytes_body = b"".join([message.get("body", b"") for message in self._messages]) - decoded_body = bytes_body.decode(self.encoding) - - if decoded_body and not (self.nullable and is_null(decoded_body)): - body = self.parse(decoded_body) - self.validate(body) - - async def receive() -> t.MutableMapping[str, t.Any]: - while self._messages: - return self._messages.pop(0) - return await self._receive() - - return receive - - -class JSONResponseBodyValidator: - """Response body validator for json content types.""" - - def __init__( - self, - scope: Scope, - send: Send, - *, - schema: dict, - validator: t.Type[Draft4Validator] = Draft4ResponseValidator, - nullable=False, - encoding: str, - ) -> None: - self._scope = scope - self._send = send - self.schema = schema - self.has_default = schema.get("default", False) - self.nullable = nullable - self.validator = validator(schema, format_checker=draft4_format_checker) - self.encoding = encoding - self._messages: t.List[t.MutableMapping[str, t.Any]] = [] - - @classmethod - def _error_path_message(cls, exception): - error_path = ".".join(str(item) for item in exception.path) - error_path_msg = f" - '{error_path}'" if error_path else "" - return error_path_msg - - def validate(self, body: dict): - try: - self.validator.validate(body) - except ValidationError as exception: - error_path_msg = self._error_path_message(exception=exception) - logger.error( - f"Validation error: {exception.message}{error_path_msg}", - extra={"validator": "body"}, - ) - raise NonConformingResponseBody( - message=f"{exception.message}{error_path_msg}" - ) - - @staticmethod - def parse(body: str) -> dict: - try: - return json.loads(body) - except json.decoder.JSONDecodeError as e: - raise BadRequestProblem(str(e)) - - async def send(self, message: t.MutableMapping[str, t.Any]) -> None: - self._messages.append(message) - - if message["type"] == "http.response.start" or message.get("more_body", False): - return - - bytes_body = b"".join([message.get("body", b"") for message in self._messages]) - decoded_body = bytes_body.decode(self.encoding) - - if decoded_body and not (self.nullable and is_null(decoded_body)): - body = self.parse(decoded_body) - self.validate(body) - - while self._messages: - await self._send(self._messages.pop(0)) - - -class TextResponseBodyValidator(JSONResponseBodyValidator): - @staticmethod - def parse(body: str) -> str: # type: ignore - try: - return json.loads(body) - except json.decoder.JSONDecodeError: - return body - - -class FormDataValidator: - """Request body validator for form content types.""" - - def __init__( - self, - scope: Scope, - receive: Receive, - *, - schema: dict, - validator: t.Type[Draft4Validator] = None, - uri_parser: t.Optional[AbstractURIParser] = None, - nullable=False, - encoding: str, - strict_validation: bool, - ) -> None: - self._scope = scope - self._receive = receive - self.schema = schema - self.has_default = schema.get("default", False) - self.nullable = nullable - validator_cls = validator or Draft4RequestValidator - self.validator = validator_cls(schema, format_checker=draft4_format_checker) - self.uri_parser = uri_parser - self.encoding = encoding - self._messages: t.List[t.MutableMapping[str, t.Any]] = [] - self.headers = Headers(scope=scope) - self.strict_validation = strict_validation - self.check_empty() - - @property - def form_parser_cls(self): - return FormParser - - def check_empty(self): - """`receive` is never called if body is empty, so we need to check this case at - initialization.""" - if not int(self.headers.get("content-length", 0)) and self.schema.get( - "required", [] - ): - self._validate({}) - - @classmethod - def _error_path_message(cls, exception): - error_path = ".".join(str(item) for item in exception.path) - error_path_msg = f" - '{error_path}'" if error_path else "" - return error_path_msg - - def _validate(self, data: dict) -> None: - try: - self.validator.validate(data) - except ValidationError as exception: - error_path_msg = self._error_path_message(exception=exception) - logger.error( - f"Validation error: {exception.message}{error_path_msg}", - extra={"validator": "body"}, - ) - raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") - - def validate(self, data: FormData) -> None: - if self.strict_validation: - form_params = data.keys() - spec_params = self.schema.get("properties", {}).keys() - errors = set(form_params).difference(set(spec_params)) - if errors: - raise ExtraParameterProblem(errors, []) - - props = self.schema.get("properties", {}) - errs = [] - if self.uri_parser is not None: - # Don't parse file_data - form_data = {} - file_data = {} - for k, v in data.items(): - if isinstance(v, str): - form_data[k] = data.getlist(k) - elif isinstance(v, UploadFile): - file_data[k] = data.getlist(k) - - data = self.uri_parser.resolve_form(form_data) - # Add the files again - data.update(file_data) - else: - data = {k: data.getlist(k) for k in data} - - for k, param_defn in props.items(): - if k in data: - if param_defn.get("format", "") == "binary": - # Replace files with empty strings for validation - data[k] = "" - continue - - try: - data[k] = coerce_type(param_defn, data[k], "requestBody", k) - except TypeValidationError as e: - logger.exception(e) - errs += [str(e)] - - if errs: - raise BadRequestProblem(detail=errs) - - self._validate(data) - - async def wrapped_receive(self) -> Receive: - async def stream() -> t.AsyncGenerator[bytes, None]: - more_body = True - while more_body: - message = await self._receive() - self._messages.append(message) - more_body = message.get("more_body", False) - yield message.get("body", b"") - yield b"" - - form_parser = self.form_parser_cls(self.headers, stream()) - form = await form_parser.parse() - - if form and not (self.nullable and is_null(form)): - self.validate(form) - - async def receive() -> t.MutableMapping[str, t.Any]: - while self._messages: - return self._messages.pop(0) - return await self._receive() - - return receive - - -class MultiPartFormDataValidator(FormDataValidator): - @property - def form_parser_cls(self): - return MultiPartParser - - -VALIDATOR_MAP = { - "parameter": ParameterValidator, - "body": MediaTypeDict( - { - "*/*json": JSONRequestBodyValidator, - "application/x-www-form-urlencoded": FormDataValidator, - "multipart/form-data": MultiPartFormDataValidator, - } - ), - "response": MediaTypeDict( - { - "*/*json": JSONResponseBodyValidator, - "text/plain": TextResponseBodyValidator, - } - ), -} diff --git a/connexion/validators/__init__.py b/connexion/validators/__init__.py new file mode 100644 index 000000000..1c608917b --- /dev/null +++ b/connexion/validators/__init__.py @@ -0,0 +1,26 @@ +from connexion.datastructures import MediaTypeDict +from connexion.decorators.validation import ParameterValidator + +from .form_data import FormDataValidator, MultiPartFormDataValidator +from .json import ( + JSONRequestBodyValidator, + JSONResponseBodyValidator, + TextResponseBodyValidator, +) + +VALIDATOR_MAP = { + "parameter": ParameterValidator, + "body": MediaTypeDict( + { + "*/*json": JSONRequestBodyValidator, + "application/x-www-form-urlencoded": FormDataValidator, + "multipart/form-data": MultiPartFormDataValidator, + } + ), + "response": MediaTypeDict( + { + "*/*json": JSONResponseBodyValidator, + "text/plain": TextResponseBodyValidator, + } + ), +} diff --git a/connexion/validators/form_data.py b/connexion/validators/form_data.py new file mode 100644 index 000000000..dd825ae30 --- /dev/null +++ b/connexion/validators/form_data.py @@ -0,0 +1,147 @@ +import logging +import typing as t + +from jsonschema import Draft4Validator, ValidationError, draft4_format_checker +from starlette.datastructures import FormData, Headers, UploadFile +from starlette.formparsers import FormParser, MultiPartParser +from starlette.types import Receive, Scope + +from connexion.decorators.uri_parsing import AbstractURIParser +from connexion.decorators.validation import TypeValidationError, coerce_type +from connexion.exceptions import BadRequestProblem, ExtraParameterProblem +from connexion.json_schema import Draft4RequestValidator +from connexion.utils import is_null + +logger = logging.getLogger("connexion.validators.form_data") + + +class FormDataValidator: + """Request body validator for form content types.""" + + def __init__( + self, + scope: Scope, + receive: Receive, + *, + schema: dict, + validator: t.Type[Draft4Validator] = None, + uri_parser: t.Optional[AbstractURIParser] = None, + nullable=False, + encoding: str, + strict_validation: bool, + ) -> None: + self._scope = scope + self._receive = receive + self.schema = schema + self.has_default = schema.get("default", False) + self.nullable = nullable + validator_cls = validator or Draft4RequestValidator + self.validator = validator_cls(schema, format_checker=draft4_format_checker) + self.uri_parser = uri_parser + self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] + self.headers = Headers(scope=scope) + self.strict_validation = strict_validation + self.check_empty() + + @property + def form_parser_cls(self): + return FormParser + + def check_empty(self): + """`receive` is never called if body is empty, so we need to check this case at + initialization.""" + if not int(self.headers.get("content-length", 0)) and self.schema.get( + "required", [] + ): + self._validate({}) + + @classmethod + def _error_path_message(cls, exception): + error_path = ".".join(str(item) for item in exception.path) + error_path_msg = f" - '{error_path}'" if error_path else "" + return error_path_msg + + def _validate(self, data: dict) -> None: + try: + self.validator.validate(data) + except ValidationError as exception: + error_path_msg = self._error_path_message(exception=exception) + logger.error( + f"Validation error: {exception.message}{error_path_msg}", + extra={"validator": "body"}, + ) + raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") + + def validate(self, data: FormData) -> None: + if self.strict_validation: + form_params = data.keys() + spec_params = self.schema.get("properties", {}).keys() + errors = set(form_params).difference(set(spec_params)) + if errors: + raise ExtraParameterProblem(errors, []) + + props = self.schema.get("properties", {}) + errs = [] + if self.uri_parser is not None: + # Don't parse file_data + form_data = {} + file_data = {} + for k, v in data.items(): + if isinstance(v, str): + form_data[k] = data.getlist(k) + elif isinstance(v, UploadFile): + file_data[k] = data.getlist(k) + + data = self.uri_parser.resolve_form(form_data) + # Add the files again + data.update(file_data) + else: + data = {k: data.getlist(k) for k in data} + + for k, param_defn in props.items(): + if k in data: + if param_defn.get("format", "") == "binary": + # Replace files with empty strings for validation + data[k] = "" + continue + + try: + data[k] = coerce_type(param_defn, data[k], "requestBody", k) + except TypeValidationError as e: + logger.exception(e) + errs += [str(e)] + + if errs: + raise BadRequestProblem(detail=errs) + + self._validate(data) + + async def wrapped_receive(self) -> Receive: + async def stream() -> t.AsyncGenerator[bytes, None]: + more_body = True + while more_body: + message = await self._receive() + self._messages.append(message) + more_body = message.get("more_body", False) + yield message.get("body", b"") + yield b"" + + form_parser = self.form_parser_cls(self.headers, stream()) + form = await form_parser.parse() + + if form and not (self.nullable and is_null(form)): + self.validate(form) + + async def receive() -> t.MutableMapping[str, t.Any]: + while self._messages: + return self._messages.pop(0) + return await self._receive() + + return receive + + +class MultiPartFormDataValidator(FormDataValidator): + @property + def form_parser_cls(self): + return MultiPartParser diff --git a/connexion/validators/json.py b/connexion/validators/json.py new file mode 100644 index 000000000..229215c45 --- /dev/null +++ b/connexion/validators/json.py @@ -0,0 +1,155 @@ +import json +import logging +import typing as t + +from jsonschema import Draft4Validator, ValidationError, draft4_format_checker +from starlette.types import Receive, Scope, Send + +from connexion.exceptions import BadRequestProblem, NonConformingResponseBody +from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator +from connexion.utils import is_null + +logger = logging.getLogger("connexion.validators.json") + + +class JSONRequestBodyValidator: + """Request body validator for json content types.""" + + def __init__( + self, + scope: Scope, + receive: Receive, + *, + schema: dict, + validator: t.Type[Draft4Validator] = Draft4RequestValidator, + nullable=False, + encoding: str, + **kwargs, + ) -> None: + self._scope = scope + self._receive = receive + self.schema = schema + self.has_default = schema.get("default", False) + self.nullable = nullable + self.validator = validator(schema, format_checker=draft4_format_checker) + self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] + + @classmethod + def _error_path_message(cls, exception): + error_path = ".".join(str(item) for item in exception.path) + error_path_msg = f" - '{error_path}'" if error_path else "" + return error_path_msg + + def validate(self, body: dict): + try: + self.validator.validate(body) + except ValidationError as exception: + error_path_msg = self._error_path_message(exception=exception) + logger.error( + f"Validation error: {exception.message}{error_path_msg}", + extra={"validator": "body"}, + ) + raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") + + @staticmethod + def parse(body: str) -> dict: + try: + return json.loads(body) + except json.decoder.JSONDecodeError as e: + raise BadRequestProblem(str(e)) + + async def wrapped_receive(self) -> Receive: + more_body = True + while more_body: + message = await self._receive() + self._messages.append(message) + more_body = message.get("more_body", False) + + bytes_body = b"".join([message.get("body", b"") for message in self._messages]) + decoded_body = bytes_body.decode(self.encoding) + + if decoded_body and not (self.nullable and is_null(decoded_body)): + body = self.parse(decoded_body) + self.validate(body) + + async def receive() -> t.MutableMapping[str, t.Any]: + while self._messages: + return self._messages.pop(0) + return await self._receive() + + return receive + + +class JSONResponseBodyValidator: + """Response body validator for json content types.""" + + def __init__( + self, + scope: Scope, + send: Send, + *, + schema: dict, + validator: t.Type[Draft4Validator] = Draft4ResponseValidator, + nullable=False, + encoding: str, + ) -> None: + self._scope = scope + self._send = send + self.schema = schema + self.has_default = schema.get("default", False) + self.nullable = nullable + self.validator = validator(schema, format_checker=draft4_format_checker) + self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] + + @classmethod + def _error_path_message(cls, exception): + error_path = ".".join(str(item) for item in exception.path) + error_path_msg = f" - '{error_path}'" if error_path else "" + return error_path_msg + + def validate(self, body: dict): + try: + self.validator.validate(body) + except ValidationError as exception: + error_path_msg = self._error_path_message(exception=exception) + logger.error( + f"Validation error: {exception.message}{error_path_msg}", + extra={"validator": "body"}, + ) + raise NonConformingResponseBody( + message=f"{exception.message}{error_path_msg}" + ) + + @staticmethod + def parse(body: str) -> dict: + try: + return json.loads(body) + except json.decoder.JSONDecodeError as e: + raise BadRequestProblem(str(e)) + + async def send(self, message: t.MutableMapping[str, t.Any]) -> None: + self._messages.append(message) + + if message["type"] == "http.response.start" or message.get("more_body", False): + return + + bytes_body = b"".join([message.get("body", b"") for message in self._messages]) + decoded_body = bytes_body.decode(self.encoding) + + if decoded_body and not (self.nullable and is_null(decoded_body)): + body = self.parse(decoded_body) + self.validate(body) + + while self._messages: + await self._send(self._messages.pop(0)) + + +class TextResponseBodyValidator(JSONResponseBodyValidator): + @staticmethod + def parse(body: str) -> str: # type: ignore + try: + return json.loads(body) + except json.decoder.JSONDecodeError: + return body From 2581a7e4c44601cb284e3b58197f603b034649bd Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 14 Nov 2022 23:15:31 +0100 Subject: [PATCH 2/3] Move parameter validation to middleware --- connexion/apis/abstract.py | 6 - connexion/apps/abstract.py | 1 - connexion/decorators/uri_parsing.py | 20 +- connexion/decorators/validation.py | 215 --------------------- connexion/exceptions.py | 19 ++ connexion/middleware/request_validation.py | 20 +- connexion/middleware/routing.py | 4 + connexion/operations/abstract.py | 32 --- connexion/operations/openapi.py | 4 - connexion/operations/swagger2.py | 4 - connexion/utils.py | 52 +++++ connexion/validators/__init__.py | 2 +- connexion/validators/form_data.py | 9 +- connexion/validators/parameter.py | 148 ++++++++++++++ tests/api/test_parameters.py | 12 +- tests/api/test_responses.py | 2 +- tests/decorators/test_validation.py | 9 +- tests/test_validation.py | 121 +++++++----- 18 files changed, 342 insertions(+), 338 deletions(-) delete mode 100644 connexion/decorators/validation.py create mode 100644 connexion/validators/parameter.py diff --git a/connexion/apis/abstract.py b/connexion/apis/abstract.py index 9c02a47d6..5706b6aed 100644 --- a/connexion/apis/abstract.py +++ b/connexion/apis/abstract.py @@ -186,7 +186,6 @@ def __init__( resolver=None, debug=False, resolver_error_handler=None, - validator_map=None, pythonic_params=False, options=None, **kwargs, @@ -194,15 +193,11 @@ def __init__( """ :type validate_responses: bool :type strict_validation: bool - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :type resolver_error_handler: callable | None :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool """ - self.validator_map = validator_map - logger.debug("Validate Responses: %s", str(validate_responses)) self.validate_responses = validate_responses @@ -245,7 +240,6 @@ def add_operation(self, path, method): method, self.resolver, validate_responses=self.validate_responses, - validator_map=self.validator_map, strict_validation=self.strict_validation, pythonic_params=self.pythonic_params, uri_parser_class=self.options.uri_parser_class, diff --git a/connexion/apps/abstract.py b/connexion/apps/abstract.py index 78a842f88..adb4c53c6 100644 --- a/connexion/apps/abstract.py +++ b/connexion/apps/abstract.py @@ -213,7 +213,6 @@ def add_api( strict_validation=strict_validation, auth_all_paths=auth_all_paths, debug=self.debug, - validator_map=validator_map, pythonic_params=pythonic_params, options=api_options.as_dict(), ) diff --git a/connexion/decorators/uri_parsing.py b/connexion/decorators/uri_parsing.py index 8a7542cc8..7697fb997 100644 --- a/connexion/decorators/uri_parsing.py +++ b/connexion/decorators/uri_parsing.py @@ -8,8 +8,9 @@ import logging import re -from .. import utils -from .decorator import BaseDecorator +from connexion.decorators.decorator import BaseDecorator +from connexion.exceptions import TypeValidationError +from connexion.utils import all_json, coerce_type, deep_merge, is_null, is_nullable logger = logging.getLogger("connexion.decorators.uri_parsing") @@ -119,6 +120,15 @@ def resolve_params(self, params, _in): else: resolved_param[k] = values[-1] + if not (is_nullable(param_defn) and is_null(resolved_param[k])): + try: + # TODO: coerce types in a single place + resolved_param[k] = coerce_type( + param_defn, resolved_param[k], "parameter", k + ) + except TypeValidationError: + pass + return resolved_param def __call__(self, function): @@ -182,9 +192,7 @@ def resolve_form(self, form_data): ) if defn and defn["type"] == "array": form_data[k] = self._split(form_data[k], encoding, "form") - elif "contentType" in encoding and utils.all_json( - [encoding.get("contentType")] - ): + elif "contentType" in encoding and all_json([encoding.get("contentType")]): form_data[k] = json.loads(form_data[k]) return form_data @@ -231,7 +239,7 @@ def _preprocess_deep_objects(self, query_data): ret = dict.fromkeys(root_keys, [{}]) for k, v, is_deep_object in deep: if is_deep_object: - ret[k] = [utils.deep_merge(v[0], ret[k][0])] + ret[k] = [deep_merge(v[0], ret[k][0])] else: ret[k] = v return ret diff --git a/connexion/decorators/validation.py b/connexion/decorators/validation.py deleted file mode 100644 index f96fc33ca..000000000 --- a/connexion/decorators/validation.py +++ /dev/null @@ -1,215 +0,0 @@ -""" -This module defines view function decorators to validate request and response parameters and bodies. -""" - -import collections -import copy -import functools -import logging - -from jsonschema import Draft4Validator, ValidationError - -from ..exceptions import BadRequestProblem, ExtraParameterProblem -from ..utils import boolean, is_null, is_nullable - -logger = logging.getLogger("connexion.decorators.validation") - -TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict} - -try: - draft4_format_checker = Draft4Validator.FORMAT_CHECKER # type: ignore -except AttributeError: # jsonschema < 4.5.0 - from jsonschema import draft4_format_checker - - -class TypeValidationError(Exception): - def __init__(self, schema_type, parameter_type, parameter_name): - """ - Exception raise when type validation fails - - :type schema_type: str - :type parameter_type: str - :type parameter_name: str - :return: - """ - self.schema_type = schema_type - self.parameter_type = parameter_type - self.parameter_name = parameter_name - - def __str__(self): - msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'" - return msg.format(**vars(self)) - - -def coerce_type(param, value, parameter_type, parameter_name=None): - def make_type(value, type_literal): - type_func = TYPE_MAP.get(type_literal) - return type_func(value) - - param_schema = param.get("schema", param) - if is_nullable(param_schema) and is_null(value): - return None - - param_type = param_schema.get("type") - parameter_name = parameter_name if parameter_name else param.get("name") - if param_type == "array": - converted_params = [] - if parameter_type == "header": - value = value.split(",") - for v in value: - try: - converted = make_type(v, param_schema["items"]["type"]) - except (ValueError, TypeError): - converted = v - converted_params.append(converted) - return converted_params - elif param_type == "object": - if param_schema.get("properties"): - - def cast_leaves(d, schema): - if type(d) is not dict: - try: - return make_type(d, schema["type"]) - except (ValueError, TypeError): - return d - for k, v in d.items(): - if k in schema["properties"]: - d[k] = cast_leaves(v, schema["properties"][k]) - return d - - return cast_leaves(value, param_schema) - return value - else: - try: - return make_type(value, param_type) - except ValueError: - raise TypeValidationError(param_type, parameter_type, parameter_name) - except TypeError: - return value - - -def validate_parameter_list(request_params, spec_params): - request_params = set(request_params) - spec_params = set(spec_params) - - return request_params.difference(spec_params) - - -class ParameterValidator: - def __init__(self, parameters, api, strict_validation=False): - """ - :param parameters: List of request parameter dictionaries - :param api: api that the validator is attached to - :param strict_validation: Flag indicating if parameters not in spec are allowed - """ - self.parameters = collections.defaultdict(list) - for p in parameters: - self.parameters[p["in"]].append(p) - - self.api = api - self.strict_validation = strict_validation - - @staticmethod - def validate_parameter(parameter_type, value, param, param_name=None): - if value is not None: - if is_nullable(param) and is_null(value): - return - - try: - converted_value = coerce_type(param, value, parameter_type, param_name) - except TypeValidationError as e: - return str(e) - - param = copy.deepcopy(param) - param = param.get("schema", param) - if "required" in param: - del param["required"] - try: - Draft4Validator(param, format_checker=draft4_format_checker).validate( - converted_value - ) - except ValidationError as exception: - debug_msg = ( - "Error while converting value {converted_value} from param " - "{type_converted_value} of type real type {param_type} to the declared type {param}" - ) - fmt_params = dict( - converted_value=str(converted_value), - type_converted_value=type(converted_value), - param_type=param.get("type"), - param=param, - ) - logger.info(debug_msg.format(**fmt_params)) - return str(exception) - - elif param.get("required"): - return "Missing {parameter_type} parameter '{param[name]}'".format( - **locals() - ) - - def validate_query_parameter_list(self, request): - request_params = request.query.keys() - spec_params = [x["name"] for x in self.parameters.get("query", [])] - return validate_parameter_list(request_params, spec_params) - - def validate_query_parameter(self, param, request): - """ - Validate a single query parameter (request.args in Flask) - - :type param: dict - :rtype: str - """ - val = request.query.get(param["name"]) - return self.validate_parameter("query", val, param) - - def validate_path_parameter(self, param, request): - val = request.path_params.get(param["name"].replace("-", "_")) - return self.validate_parameter("path", val, param) - - def validate_header_parameter(self, param, request): - val = request.headers.get(param["name"]) - return self.validate_parameter("header", val, param) - - def validate_cookie_parameter(self, param, request): - val = request.cookies.get(param["name"]) - return self.validate_parameter("cookie", val, param) - - def __call__(self, function): - """ - :type function: types.FunctionType - :rtype: types.FunctionType - """ - - @functools.wraps(function) - def wrapper(request): - logger.debug("%s validating parameters...", request.url) - - if self.strict_validation: - query_errors = self.validate_query_parameter_list(request) - - if query_errors: - raise ExtraParameterProblem([], query_errors) - - for param in self.parameters.get("query", []): - error = self.validate_query_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - for param in self.parameters.get("path", []): - error = self.validate_path_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - for param in self.parameters.get("header", []): - error = self.validate_header_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - for param in self.parameters.get("cookie", []): - error = self.validate_cookie_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - return function(request) - - return wrapper diff --git a/connexion/exceptions.py b/connexion/exceptions.py index 8ab7d0bda..3f4e77dac 100644 --- a/connexion/exceptions.py +++ b/connexion/exceptions.py @@ -206,3 +206,22 @@ def __init__( ) super().__init__(title=title, detail=detail, **kwargs) + + +class TypeValidationError(Exception): + def __init__(self, schema_type, parameter_type, parameter_name): + """ + Exception raise when type validation fails + + :type schema_type: str + :type parameter_type: str + :type parameter_name: str + :return: + """ + self.schema_type = schema_type + self.parameter_type = parameter_type + self.parameter_name = parameter_name + + def __str__(self): + msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'" + return msg.format(**vars(self)) diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index 26ec72400..d42585cd3 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -8,7 +8,6 @@ from connexion import utils from connexion.datastructures import MediaTypeDict -from connexion.decorators.uri_parsing import AbstractURIParser from connexion.exceptions import UnsupportedMediaTypeProblem from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation @@ -25,14 +24,12 @@ def __init__( operation: AbstractOperation, strict_validation: bool = False, validator_map: t.Optional[dict] = None, - uri_parser_class: t.Optional[AbstractURIParser] = None, ) -> None: self.next_app = next_app self._operation = operation self.strict_validation = strict_validation self._validator_map = VALIDATOR_MAP.copy() self._validator_map.update(validator_map or {}) - self.uri_parser_class = uri_parser_class def extract_content_type( self, headers: t.List[t.Tuple[bytes, bytes]] @@ -73,12 +70,24 @@ def validate_mime_type(self, mime_type: str) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send): receive_fn = receive + # Validate parameters & headers + uri_parser_class = self._operation._uri_parser_class + uri_parser = uri_parser_class( + self._operation.parameters, self._operation.body_definition() + ) + parameter_validator_cls = self._validator_map["parameter"] + parameter_validator = parameter_validator_cls( # type: ignore + self._operation.parameters, + uri_parser=uri_parser, + strict_validation=self.strict_validation, + ) + parameter_validator.validate(scope) + + # Extract content type headers = scope["headers"] mime_type, encoding = self.extract_content_type(headers) self.validate_mime_type(mime_type) - # TODO: Validate parameters - # Validate body schema = self._operation.body_schema(mime_type) if schema: @@ -137,7 +146,6 @@ def make_operation( operation=operation, strict_validation=self.strict_validation, validator_map=self.validator_map, - uri_parser_class=self.uri_parser_class, ) diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index 905ecb315..94197127d 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -25,6 +25,10 @@ def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp): async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """Attach operation to scope and pass it to the next app""" original_scope = _scope.get() + # Pass resolved path params along + original_scope.setdefault("path_params", {}).update( + scope.get("path_params", {}) + ) api_base_path = scope.get("root_path", "")[ len(original_scope.get("root_path", "")) : diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index fb40123bb..fa0355f5b 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -9,17 +9,12 @@ from ..decorators.decorator import RequestResponseDecorator from ..decorators.parameter import parameter_to_arg from ..decorators.produces import BaseSerializer, Produces -from ..decorators.validation import ParameterValidator from ..utils import all_json logger = logging.getLogger("connexion.operations.abstract") DEFAULT_MIMETYPE = "application/json" -VALIDATOR_MAP = { - "parameter": ParameterValidator, -} - class AbstractOperation(metaclass=abc.ABCMeta): @@ -52,7 +47,6 @@ def __init__( validate_responses=False, strict_validation=False, randomize_endpoint=None, - validator_map=None, pythonic_params=False, uri_parser_class=None, ): @@ -77,8 +71,6 @@ def __init__( :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool @@ -104,9 +96,6 @@ def __init__( self._responses = self._operation.get("responses", {}) - self._validator_map = dict(VALIDATOR_MAP) - self._validator_map.update(validator_map or {}) - @property def api(self): return self._api @@ -140,13 +129,6 @@ def responses(self): """ return self._responses - @property - def validator_map(self): - """ - Validators to use for parameter, body, and response validation - """ - return self._validator_map - @property def operation_id(self): """ @@ -388,9 +370,6 @@ def function(self): logger.debug("... Adding produces decorator (%r)", produces_decorator) function = produces_decorator(function) - for validation_decorator in self.__validation_decorators: - function = validation_decorator(function) - uri_parsing_decorator = self._uri_parsing_decorator function = uri_parsing_decorator(function) @@ -442,17 +421,6 @@ def __content_type_decorator(self): else: return BaseSerializer() - @property - def __validation_decorators(self): - """ - :rtype: types.FunctionType - """ - ParameterValidator = self.validator_map["parameter"] - if self.parameters: - yield ParameterValidator( - self.parameters, self.api, strict_validation=self.strict_validation - ) - def json_loads(self, data): """ A wrapper for calling the API specific JSON loader. diff --git a/connexion/operations/openapi.py b/connexion/operations/openapi.py index fcd96a7e3..7f301b827 100644 --- a/connexion/operations/openapi.py +++ b/connexion/operations/openapi.py @@ -35,7 +35,6 @@ def __init__( validate_responses=False, strict_validation=False, randomize_endpoint=None, - validator_map=None, pythonic_params=False, uri_parser_class=None, ): @@ -72,8 +71,6 @@ def __init__( :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool @@ -97,7 +94,6 @@ def __init__( validate_responses=validate_responses, strict_validation=strict_validation, randomize_endpoint=randomize_endpoint, - validator_map=validator_map, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, ) diff --git a/connexion/operations/swagger2.py b/connexion/operations/swagger2.py index 1124d4ba0..f1211ea4d 100644 --- a/connexion/operations/swagger2.py +++ b/connexion/operations/swagger2.py @@ -52,7 +52,6 @@ def __init__( validate_responses=False, strict_validation=False, randomize_endpoint=None, - validator_map=None, pythonic_params=False, uri_parser_class=None, ): @@ -87,8 +86,6 @@ def __init__( :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool @@ -110,7 +107,6 @@ def __init__( validate_responses=validate_responses, strict_validation=strict_validation, randomize_endpoint=randomize_endpoint, - validator_map=validator_map, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, ) diff --git a/connexion/utils.py b/connexion/utils.py index cc6cdd632..e37dba768 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -9,6 +9,8 @@ import yaml +from connexion.exceptions import TypeValidationError + def boolean(s): """ @@ -296,3 +298,53 @@ def extract_content_type( mime_type = content_type break return mime_type, encoding + + +def coerce_type(param, value, parameter_type, parameter_name=None): + # TODO: clean up + TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict} + + def make_type(value, type_literal): + type_func = TYPE_MAP.get(type_literal) + return type_func(value) + + param_schema = param.get("schema", param) + if is_nullable(param_schema) and is_null(value): + return None + + param_type = param_schema.get("type") + parameter_name = parameter_name if parameter_name else param.get("name") + if param_type == "array": + converted_params = [] + if parameter_type == "header": + value = value.split(",") + for v in value: + try: + converted = make_type(v, param_schema["items"]["type"]) + except (ValueError, TypeError): + converted = v + converted_params.append(converted) + return converted_params + elif param_type == "object": + if param_schema.get("properties"): + + def cast_leaves(d, schema): + if type(d) is not dict: + try: + return make_type(d, schema["type"]) + except (ValueError, TypeError): + return d + for k, v in d.items(): + if k in schema["properties"]: + d[k] = cast_leaves(v, schema["properties"][k]) + return d + + return cast_leaves(value, param_schema) + return value + else: + try: + return make_type(value, param_type) + except ValueError: + raise TypeValidationError(param_type, parameter_type, parameter_name) + except TypeError: + return value diff --git a/connexion/validators/__init__.py b/connexion/validators/__init__.py index 1c608917b..fa1840bcc 100644 --- a/connexion/validators/__init__.py +++ b/connexion/validators/__init__.py @@ -1,5 +1,4 @@ from connexion.datastructures import MediaTypeDict -from connexion.decorators.validation import ParameterValidator from .form_data import FormDataValidator, MultiPartFormDataValidator from .json import ( @@ -7,6 +6,7 @@ JSONResponseBodyValidator, TextResponseBodyValidator, ) +from .parameter import ParameterValidator VALIDATOR_MAP = { "parameter": ParameterValidator, diff --git a/connexion/validators/form_data.py b/connexion/validators/form_data.py index dd825ae30..23a3d1121 100644 --- a/connexion/validators/form_data.py +++ b/connexion/validators/form_data.py @@ -7,10 +7,13 @@ from starlette.types import Receive, Scope from connexion.decorators.uri_parsing import AbstractURIParser -from connexion.decorators.validation import TypeValidationError, coerce_type -from connexion.exceptions import BadRequestProblem, ExtraParameterProblem +from connexion.exceptions import ( + BadRequestProblem, + ExtraParameterProblem, + TypeValidationError, +) from connexion.json_schema import Draft4RequestValidator -from connexion.utils import is_null +from connexion.utils import coerce_type, is_null logger = logging.getLogger("connexion.validators.form_data") diff --git a/connexion/validators/parameter.py b/connexion/validators/parameter.py new file mode 100644 index 000000000..4c02f6249 --- /dev/null +++ b/connexion/validators/parameter.py @@ -0,0 +1,148 @@ +import collections +import copy +import logging + +from jsonschema import Draft4Validator, ValidationError +from starlette.requests import Request + +from connexion.exceptions import ( + BadRequestProblem, + ExtraParameterProblem, + TypeValidationError, +) +from connexion.utils import boolean, coerce_type, is_null, is_nullable + +logger = logging.getLogger("connexion.validators.parameter") + +TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict} + +try: + draft4_format_checker = Draft4Validator.FORMAT_CHECKER # type: ignore +except AttributeError: # jsonschema < 4.5.0 + from jsonschema import draft4_format_checker + + +class ParameterValidator: + def __init__(self, parameters, uri_parser, strict_validation=False): + """ + :param parameters: List of request parameter dictionaries + :param uri_parser: class to use for uri parsing + :param strict_validation: Flag indicating if parameters not in spec are allowed + """ + self.parameters = collections.defaultdict(list) + for p in parameters: + self.parameters[p["in"]].append(p) + + self.uri_parser = uri_parser + self.strict_validation = strict_validation + + @staticmethod + def validate_parameter(parameter_type, value, param, param_name=None): + if value is not None: + if is_nullable(param) and is_null(value): + return + + try: + converted_value = coerce_type(param, value, parameter_type, param_name) + except TypeValidationError as e: + return str(e) + + param = copy.deepcopy(param) + param = param.get("schema", param) + if "required" in param: + del param["required"] + try: + Draft4Validator(param, format_checker=draft4_format_checker).validate( + converted_value + ) + except ValidationError as exception: + debug_msg = ( + "Error while converting value {converted_value} from param " + "{type_converted_value} of type real type {param_type} to the declared type {param}" + ) + fmt_params = dict( + converted_value=str(converted_value), + type_converted_value=type(converted_value), + param_type=param.get("type"), + param=param, + ) + logger.info(debug_msg.format(**fmt_params)) + return str(exception) + + elif param.get("required"): + return "Missing {parameter_type} parameter '{param[name]}'".format( + **locals() + ) + + @staticmethod + def validate_parameter_list(request_params, spec_params): + request_params = set(request_params) + spec_params = set(spec_params) + + return request_params.difference(spec_params) + + def validate_query_parameter_list(self, request): + request_params = request.query_params.keys() + spec_params = [x["name"] for x in self.parameters.get("query", [])] + return self.validate_parameter_list(request_params, spec_params) + + def validate_query_parameter(self, param, request): + """ + Validate a single query parameter (request.args in Flask) + + :type param: dict + :rtype: str + """ + # Convert to dict of lists + query_params = { + k: request.query_params.getlist(k) for k in request.query_params + } + query_params = self.uri_parser.resolve_query(query_params) + val = query_params.get(param["name"]) + return self.validate_parameter("query", val, param) + + def validate_path_parameter(self, param, request): + val = request.path_params.get(param["name"].replace("-", "_")) + return self.validate_parameter("path", val, param) + + def validate_header_parameter(self, param, request): + val = request.headers.get(param["name"]) + return self.validate_parameter("header", val, param) + + def validate_cookie_parameter(self, param, request): + val = request.cookies.get(param["name"]) + return self.validate_parameter("cookie", val, param) + + def validate(self, scope): + logger.debug("%s validating parameters...", scope.get("path")) + + request = Request(scope) + self.validate_request(request) + + def validate_request(self, request): + + if self.strict_validation: + query_errors = self.validate_query_parameter_list(request) + + if query_errors: + raise ExtraParameterProblem([], query_errors) + + for param in self.parameters.get("query", []): + error = self.validate_query_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) + + for param in self.parameters.get("path", []): + error = self.validate_path_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) + + for param in self.parameters.get("header", []): + error = self.validate_header_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) + + for param in self.parameters.get("cookie", []): + error = self.validate_cookie_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) diff --git a/tests/api/test_parameters.py b/tests/api/test_parameters.py index 62a4c94a1..2219dedcb 100644 --- a/tests/api/test_parameters.py +++ b/tests/api/test_parameters.py @@ -172,7 +172,7 @@ def test_path_parameter_someint__bad(simple_app): # non-integer values will not match Flask route app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-int-path/foo") # type: flask.Response - assert resp.status_code == 404 + assert resp.status_code == 400, resp.text @pytest.mark.parametrize( @@ -205,7 +205,7 @@ def test_path_parameter_somefloat__bad(simple_app): # non-float values will not match Flask route app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-float-path/123,45") # type: flask.Response - assert resp.status_code == 404 + assert resp.status_code == 400, resp.text def test_default_param(strict_app): @@ -348,19 +348,19 @@ def test_bool_param(simple_app): def test_bool_array_param(simple_app): app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-bool-array-param?thruthiness=true,true,true") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text response = json.loads(resp.data.decode("utf-8", "replace")) assert response is True app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-bool-array-param?thruthiness=true,true,false") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text response = json.loads(resp.data.decode("utf-8", "replace")) assert response is False app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-bool-array-param") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text def test_required_param_miss_config(simple_app): @@ -557,7 +557,7 @@ def test_parameters_snake_case(snake_case_app): resp = app_client.get( "/v1.0/test-get-camel-case-version?truthiness=true&orderBy=asc" ) - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text assert resp.get_json() == {"truthiness": True, "order_by": "asc"} resp = app_client.get("/v1.0/test-get-camel-case-version?truthiness=5") assert resp.status_code == 400 diff --git a/tests/api/test_responses.py b/tests/api/test_responses.py index c017c08e7..127245c6e 100644 --- a/tests/api/test_responses.py +++ b/tests/api/test_responses.py @@ -171,7 +171,7 @@ def test_exploded_deep_object_param_endpoint_openapi_multiple_data_types( response = app_client.get( "/v1.0/exploded-deep-object-param?id[foo]=bar&id[fooint]=2&id[fooboo]=false" ) # type: flask.Response - assert response.status_code == 200 + assert response.status_code == 200, response.text response_data = json.loads(response.data.decode("utf-8", "replace")) assert response_data == { "foo": "bar", diff --git a/tests/decorators/test_validation.py b/tests/decorators/test_validation.py index d8cb5a162..4f9c0aa06 100644 --- a/tests/decorators/test_validation.py +++ b/tests/decorators/test_validation.py @@ -1,9 +1,8 @@ from unittest.mock import MagicMock import pytest -from connexion.apis.flask_api import FlaskApi -from connexion.decorators.validation import ParameterValidator from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator +from connexion.validators.parameter import ParameterValidator from jsonschema import ValidationError @@ -75,7 +74,7 @@ def test_get_valid_parameter_with_enum_array_header(): def test_invalid_type(monkeypatch): logger = MagicMock() - monkeypatch.setattr("connexion.decorators.validation.logger", logger) + monkeypatch.setattr("connexion.validators.parameter.logger", logger) result = ParameterValidator.validate_parameter( "formdata", 20, {"type": "string", "name": "foo"} ) @@ -91,8 +90,6 @@ def test_invalid_type(monkeypatch): def test_invalid_type_value_error(monkeypatch): - logger = MagicMock() - monkeypatch.setattr("connexion.decorators.validation.logger", logger) value = {"test": 1, "second": 2} result = ParameterValidator.validate_parameter( "formdata", value, {"type": "boolean", "name": "foo"} @@ -101,8 +98,6 @@ def test_invalid_type_value_error(monkeypatch): def test_enum_error(monkeypatch): - logger = MagicMock() - monkeypatch.setattr("connexion.decorators.validation.logger", logger) value = "INVALID" param = {"schema": {"type": "string", "enum": ["valid"]}, "name": "test_path_param"} result = ParameterValidator.validate_parameter("path", value, param) diff --git a/tests/test_validation.py b/tests/test_validation.py index d6e64670e..eb0a60664 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,27 +1,14 @@ from unittest.mock import MagicMock +from urllib.parse import quote_plus -import flask import pytest -from connexion.apis.flask_api import FlaskApi -from connexion.decorators.validation import ParameterValidator +from connexion.decorators.uri_parsing import Swagger2URIParser from connexion.exceptions import BadRequestProblem +from connexion.validators.parameter import ParameterValidator +from starlette.datastructures import QueryParams def test_parameter_validator(monkeypatch): - request = MagicMock(name="request") - request.args = {} - request.headers = {} - request.cookies = {} - request.params = {} - app = MagicMock(name="app") - - app.response_class = flask.Response - monkeypatch.setattr("flask.request", request) - monkeypatch.setattr("flask.current_app", app) - - def orig_handler(*args, **kwargs): - return "OK" - params = [ {"name": "p1", "in": "path", "type": "integer", "required": True}, {"name": "h1", "in": "header", "type": "string", "enum": ["a", "b"]}, @@ -36,78 +23,120 @@ def orig_handler(*args, **kwargs): "items": {"type": "integer", "minimum": 0}, }, ] - validator = ParameterValidator(params, FlaskApi) - handler = validator(orig_handler) - kwargs = {"query": {}, "headers": {}, "cookies": {}} + uri_parser = Swagger2URIParser(params, {}) + validator = ParameterValidator(params, uri_parser=uri_parser) + + kwargs = {"query_params": {}, "headers": {}, "cookies": {}} request = MagicMock(path_params={}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Missing path parameter 'p1'" + request = MagicMock(path_params={"p1": "123"}, **kwargs) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) + request = MagicMock(path_params={"p1": ""}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'" + request = MagicMock(path_params={"p1": "foo"}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'" + request = MagicMock(path_params={"p1": "1.2"}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'" - request = MagicMock(path_params={"p1": 1}, query={"q1": "4"}, headers={}) + request = MagicMock( + path_params={"p1": 1}, query_params=QueryParams("q1=4"), headers={}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("4 is greater than the maximum of 3") + request = MagicMock( - path_params={"p1": 1}, query={"q1": "3"}, headers={}, cookies={} + path_params={"p1": 1}, query_params=QueryParams("q1=3"), headers={}, cookies={} ) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) + query_params = QueryParams(f"a1={quote_plus('1,2')}") request = MagicMock( - path_params={"p1": 1}, query={"a1": ["1", "2"]}, headers={}, cookies={} + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} + ) + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) + + query_params = QueryParams(f"a1={quote_plus('1,a')}") + request = MagicMock( + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} ) - assert handler(request) == "OK" - request = MagicMock(path_params={"p1": 1}, query={"a1": ["1", "a"]}, headers={}) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("'a' is not of type 'integer'") + request = MagicMock( - path_params={"p1": "123"}, query={}, headers={}, cookies={"c1": "b"} + path_params={"p1": "123"}, query_params={}, headers={}, cookies={"c1": "b"} ) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) request = MagicMock( path_params={"p1": "123"}, query={}, headers={}, cookies={"c1": "x"} ) with pytest.raises(BadRequestProblem) as exc: - assert handler(request) + assert validator.validate_request(request) + assert exc.value.detail.startswith("'x' is not one of ['a', 'b']") - request = MagicMock(path_params={"p1": 1}, query={"a1": ["1", "-1"]}, headers={}) + + query_params = QueryParams(f"a1={quote_plus('1,-1')}") + request = MagicMock( + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("-1 is less than the minimum of 0") - request = MagicMock(path_params={"p1": 1}, query={"a1": ["1"]}, headers={}) + + query_params = QueryParams("a1=1") + request = MagicMock( + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("[1] is too short") + + query_params = QueryParams(f"a1={quote_plus('1,2,3,4')}") request = MagicMock( - path_params={"p1": 1}, query={"a1": ["1", "2", "3", "4"]}, headers={} + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("[1, 2, 3, 4] is too long") request = MagicMock( - path_params={"p1": "123"}, query={}, headers={"h1": "a"}, cookies={} + path_params={"p1": "123"}, query_params={}, headers={"h1": "a"}, cookies={} ) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) - request = MagicMock(path_params={"p1": "123"}, query={}, headers={"h1": "x"}) + request = MagicMock( + path_params={"p1": "123"}, query_params={}, headers={"h1": "x"}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("'x' is not one of ['a', 'b']") From 825c682086683bf59e8b50d586a9e0b45e4b2e33 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 14 Nov 2022 23:31:09 +0100 Subject: [PATCH 3/3] Remove strict_validation and validate_responses from API and Operation classes --- connexion/apis/abstract.py | 12 ------------ connexion/operations/abstract.py | 23 ----------------------- connexion/operations/openapi.py | 8 -------- connexion/operations/swagger2.py | 9 --------- tests/test_operation2.py | 3 +-- 5 files changed, 1 insertion(+), 54 deletions(-) diff --git a/connexion/apis/abstract.py b/connexion/apis/abstract.py index 5706b6aed..06128f901 100644 --- a/connexion/apis/abstract.py +++ b/connexion/apis/abstract.py @@ -181,8 +181,6 @@ def __init__( specification, base_path=None, arguments=None, - validate_responses=False, - strict_validation=False, resolver=None, debug=False, resolver_error_handler=None, @@ -191,19 +189,11 @@ def __init__( **kwargs, ): """ - :type validate_responses: bool - :type strict_validation: bool :type resolver_error_handler: callable | None :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool """ - logger.debug("Validate Responses: %s", str(validate_responses)) - self.validate_responses = validate_responses - - logger.debug("Strict Request Validation: %s", str(strict_validation)) - self.strict_validation = strict_validation - logger.debug("Pythonic params: %s", str(pythonic_params)) self.pythonic_params = pythonic_params @@ -239,8 +229,6 @@ def add_operation(self, path, method): path, method, self.resolver, - validate_responses=self.validate_responses, - strict_validation=self.strict_validation, pythonic_params=self.pythonic_params, uri_parser_class=self.options.uri_parser_class, ) diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index fa0355f5b..eec082e27 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -44,8 +44,6 @@ def __init__( resolver, app_security=None, security_schemes=None, - validate_responses=False, - strict_validation=False, randomize_endpoint=None, pythonic_params=False, uri_parser_class=None, @@ -65,10 +63,6 @@ def __init__( :param security_schemes: `Security Definitions Object `_ :type security_schemes: dict - :param validate_responses: True enables validation. Validation errors generate HTTP 500 responses. - :type validate_responses: bool - :param strict_validation: True enables validation on invalid request parameters - :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended @@ -84,8 +78,6 @@ def __init__( self._resolver = resolver self._security = operation.get("security", app_security) self._security_schemes = security_schemes - self._validate_responses = validate_responses - self._strict_validation = strict_validation self._pythonic_params = pythonic_params self._uri_parser_class = uri_parser_class self._randomize_endpoint = randomize_endpoint @@ -150,13 +142,6 @@ def router_controller(self): """ return self._router_controller - @property - def strict_validation(self): - """ - If True, validate all requests against the spec - """ - return self._strict_validation - @property def pythonic_params(self): """ @@ -164,14 +149,6 @@ def pythonic_params(self): """ return self._pythonic_params - @property - def validate_responses(self): - """ - If True, check the response against the response schema, and return an - error if the response does not validate. - """ - return self._validate_responses - @staticmethod def _get_file_arguments(files, arguments, has_kwargs=False): return {k: v for k, v in files.items() if k in arguments or has_kwargs} diff --git a/connexion/operations/openapi.py b/connexion/operations/openapi.py index 7f301b827..3d253dc17 100644 --- a/connexion/operations/openapi.py +++ b/connexion/operations/openapi.py @@ -32,8 +32,6 @@ def __init__( app_security=None, security_schemes=None, components=None, - validate_responses=False, - strict_validation=False, randomize_endpoint=None, pythonic_params=False, uri_parser_class=None, @@ -65,10 +63,6 @@ def __init__( :param components: `Components Object `_ :type components: dict - :param validate_responses: True enables validation. Validation errors generate HTTP 500 responses. - :type validate_responses: bool - :param strict_validation: True enables validation on invalid request parameters - :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended @@ -91,8 +85,6 @@ def __init__( resolver=resolver, app_security=app_security, security_schemes=security_schemes, - validate_responses=validate_responses, - strict_validation=strict_validation, randomize_endpoint=randomize_endpoint, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, diff --git a/connexion/operations/swagger2.py b/connexion/operations/swagger2.py index f1211ea4d..ad5c1046b 100644 --- a/connexion/operations/swagger2.py +++ b/connexion/operations/swagger2.py @@ -49,8 +49,6 @@ def __init__( app_security=None, security_schemes=None, definitions=None, - validate_responses=False, - strict_validation=False, randomize_endpoint=None, pythonic_params=False, uri_parser_class=None, @@ -80,10 +78,6 @@ def __init__( :param definitions: `Definitions Object `_ :type definitions: dict - :param validate_responses: True enables validation. Validation errors generate HTTP 500 responses. - :type validate_responses: bool - :param strict_validation: True enables validation on invalid request parameters - :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended @@ -104,8 +98,6 @@ def __init__( resolver=resolver, app_security=app_security, security_schemes=security_schemes, - validate_responses=validate_responses, - strict_validation=strict_validation, randomize_endpoint=randomize_endpoint, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, @@ -314,7 +306,6 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict: "type": "object", "properties": properties, "required": required, - "additionalProperties": not self.strict_validation, } } diff --git a/tests/test_operation2.py b/tests/test_operation2.py index 2e169979f..2ac20839e 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -706,7 +706,7 @@ def test_oauth_scopes_in_or(security_handler_factory): def test_form_transformation(api): - mock_self = mock.Mock(strict_validation=True) + mock_self = mock.Mock() swagger_form_parameters = [ { @@ -747,7 +747,6 @@ def test_form_transformation(api): }, }, "required": ["param"], - "additionalProperties": False, }, "encoding": { "array_param": {