From 2581a7e4c44601cb284e3b58197f603b034649bd Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 14 Nov 2022 23:15:31 +0100 Subject: [PATCH] Move parameter validation to middleware --- connexion/apis/abstract.py | 6 - connexion/apps/abstract.py | 1 - connexion/decorators/uri_parsing.py | 20 +- connexion/decorators/validation.py | 215 --------------------- connexion/exceptions.py | 19 ++ connexion/middleware/request_validation.py | 20 +- connexion/middleware/routing.py | 4 + connexion/operations/abstract.py | 32 --- connexion/operations/openapi.py | 4 - connexion/operations/swagger2.py | 4 - connexion/utils.py | 52 +++++ connexion/validators/__init__.py | 2 +- connexion/validators/form_data.py | 9 +- connexion/validators/parameter.py | 148 ++++++++++++++ tests/api/test_parameters.py | 12 +- tests/api/test_responses.py | 2 +- tests/decorators/test_validation.py | 9 +- tests/test_validation.py | 121 +++++++----- 18 files changed, 342 insertions(+), 338 deletions(-) delete mode 100644 connexion/decorators/validation.py create mode 100644 connexion/validators/parameter.py diff --git a/connexion/apis/abstract.py b/connexion/apis/abstract.py index 9c02a47d6..5706b6aed 100644 --- a/connexion/apis/abstract.py +++ b/connexion/apis/abstract.py @@ -186,7 +186,6 @@ def __init__( resolver=None, debug=False, resolver_error_handler=None, - validator_map=None, pythonic_params=False, options=None, **kwargs, @@ -194,15 +193,11 @@ def __init__( """ :type validate_responses: bool :type strict_validation: bool - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :type resolver_error_handler: callable | None :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool """ - self.validator_map = validator_map - logger.debug("Validate Responses: %s", str(validate_responses)) self.validate_responses = validate_responses @@ -245,7 +240,6 @@ def add_operation(self, path, method): method, self.resolver, validate_responses=self.validate_responses, - validator_map=self.validator_map, strict_validation=self.strict_validation, pythonic_params=self.pythonic_params, uri_parser_class=self.options.uri_parser_class, diff --git a/connexion/apps/abstract.py b/connexion/apps/abstract.py index 78a842f88..adb4c53c6 100644 --- a/connexion/apps/abstract.py +++ b/connexion/apps/abstract.py @@ -213,7 +213,6 @@ def add_api( strict_validation=strict_validation, auth_all_paths=auth_all_paths, debug=self.debug, - validator_map=validator_map, pythonic_params=pythonic_params, options=api_options.as_dict(), ) diff --git a/connexion/decorators/uri_parsing.py b/connexion/decorators/uri_parsing.py index 8a7542cc8..7697fb997 100644 --- a/connexion/decorators/uri_parsing.py +++ b/connexion/decorators/uri_parsing.py @@ -8,8 +8,9 @@ import logging import re -from .. import utils -from .decorator import BaseDecorator +from connexion.decorators.decorator import BaseDecorator +from connexion.exceptions import TypeValidationError +from connexion.utils import all_json, coerce_type, deep_merge, is_null, is_nullable logger = logging.getLogger("connexion.decorators.uri_parsing") @@ -119,6 +120,15 @@ def resolve_params(self, params, _in): else: resolved_param[k] = values[-1] + if not (is_nullable(param_defn) and is_null(resolved_param[k])): + try: + # TODO: coerce types in a single place + resolved_param[k] = coerce_type( + param_defn, resolved_param[k], "parameter", k + ) + except TypeValidationError: + pass + return resolved_param def __call__(self, function): @@ -182,9 +192,7 @@ def resolve_form(self, form_data): ) if defn and defn["type"] == "array": form_data[k] = self._split(form_data[k], encoding, "form") - elif "contentType" in encoding and utils.all_json( - [encoding.get("contentType")] - ): + elif "contentType" in encoding and all_json([encoding.get("contentType")]): form_data[k] = json.loads(form_data[k]) return form_data @@ -231,7 +239,7 @@ def _preprocess_deep_objects(self, query_data): ret = dict.fromkeys(root_keys, [{}]) for k, v, is_deep_object in deep: if is_deep_object: - ret[k] = [utils.deep_merge(v[0], ret[k][0])] + ret[k] = [deep_merge(v[0], ret[k][0])] else: ret[k] = v return ret diff --git a/connexion/decorators/validation.py b/connexion/decorators/validation.py deleted file mode 100644 index f96fc33ca..000000000 --- a/connexion/decorators/validation.py +++ /dev/null @@ -1,215 +0,0 @@ -""" -This module defines view function decorators to validate request and response parameters and bodies. -""" - -import collections -import copy -import functools -import logging - -from jsonschema import Draft4Validator, ValidationError - -from ..exceptions import BadRequestProblem, ExtraParameterProblem -from ..utils import boolean, is_null, is_nullable - -logger = logging.getLogger("connexion.decorators.validation") - -TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict} - -try: - draft4_format_checker = Draft4Validator.FORMAT_CHECKER # type: ignore -except AttributeError: # jsonschema < 4.5.0 - from jsonschema import draft4_format_checker - - -class TypeValidationError(Exception): - def __init__(self, schema_type, parameter_type, parameter_name): - """ - Exception raise when type validation fails - - :type schema_type: str - :type parameter_type: str - :type parameter_name: str - :return: - """ - self.schema_type = schema_type - self.parameter_type = parameter_type - self.parameter_name = parameter_name - - def __str__(self): - msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'" - return msg.format(**vars(self)) - - -def coerce_type(param, value, parameter_type, parameter_name=None): - def make_type(value, type_literal): - type_func = TYPE_MAP.get(type_literal) - return type_func(value) - - param_schema = param.get("schema", param) - if is_nullable(param_schema) and is_null(value): - return None - - param_type = param_schema.get("type") - parameter_name = parameter_name if parameter_name else param.get("name") - if param_type == "array": - converted_params = [] - if parameter_type == "header": - value = value.split(",") - for v in value: - try: - converted = make_type(v, param_schema["items"]["type"]) - except (ValueError, TypeError): - converted = v - converted_params.append(converted) - return converted_params - elif param_type == "object": - if param_schema.get("properties"): - - def cast_leaves(d, schema): - if type(d) is not dict: - try: - return make_type(d, schema["type"]) - except (ValueError, TypeError): - return d - for k, v in d.items(): - if k in schema["properties"]: - d[k] = cast_leaves(v, schema["properties"][k]) - return d - - return cast_leaves(value, param_schema) - return value - else: - try: - return make_type(value, param_type) - except ValueError: - raise TypeValidationError(param_type, parameter_type, parameter_name) - except TypeError: - return value - - -def validate_parameter_list(request_params, spec_params): - request_params = set(request_params) - spec_params = set(spec_params) - - return request_params.difference(spec_params) - - -class ParameterValidator: - def __init__(self, parameters, api, strict_validation=False): - """ - :param parameters: List of request parameter dictionaries - :param api: api that the validator is attached to - :param strict_validation: Flag indicating if parameters not in spec are allowed - """ - self.parameters = collections.defaultdict(list) - for p in parameters: - self.parameters[p["in"]].append(p) - - self.api = api - self.strict_validation = strict_validation - - @staticmethod - def validate_parameter(parameter_type, value, param, param_name=None): - if value is not None: - if is_nullable(param) and is_null(value): - return - - try: - converted_value = coerce_type(param, value, parameter_type, param_name) - except TypeValidationError as e: - return str(e) - - param = copy.deepcopy(param) - param = param.get("schema", param) - if "required" in param: - del param["required"] - try: - Draft4Validator(param, format_checker=draft4_format_checker).validate( - converted_value - ) - except ValidationError as exception: - debug_msg = ( - "Error while converting value {converted_value} from param " - "{type_converted_value} of type real type {param_type} to the declared type {param}" - ) - fmt_params = dict( - converted_value=str(converted_value), - type_converted_value=type(converted_value), - param_type=param.get("type"), - param=param, - ) - logger.info(debug_msg.format(**fmt_params)) - return str(exception) - - elif param.get("required"): - return "Missing {parameter_type} parameter '{param[name]}'".format( - **locals() - ) - - def validate_query_parameter_list(self, request): - request_params = request.query.keys() - spec_params = [x["name"] for x in self.parameters.get("query", [])] - return validate_parameter_list(request_params, spec_params) - - def validate_query_parameter(self, param, request): - """ - Validate a single query parameter (request.args in Flask) - - :type param: dict - :rtype: str - """ - val = request.query.get(param["name"]) - return self.validate_parameter("query", val, param) - - def validate_path_parameter(self, param, request): - val = request.path_params.get(param["name"].replace("-", "_")) - return self.validate_parameter("path", val, param) - - def validate_header_parameter(self, param, request): - val = request.headers.get(param["name"]) - return self.validate_parameter("header", val, param) - - def validate_cookie_parameter(self, param, request): - val = request.cookies.get(param["name"]) - return self.validate_parameter("cookie", val, param) - - def __call__(self, function): - """ - :type function: types.FunctionType - :rtype: types.FunctionType - """ - - @functools.wraps(function) - def wrapper(request): - logger.debug("%s validating parameters...", request.url) - - if self.strict_validation: - query_errors = self.validate_query_parameter_list(request) - - if query_errors: - raise ExtraParameterProblem([], query_errors) - - for param in self.parameters.get("query", []): - error = self.validate_query_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - for param in self.parameters.get("path", []): - error = self.validate_path_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - for param in self.parameters.get("header", []): - error = self.validate_header_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - for param in self.parameters.get("cookie", []): - error = self.validate_cookie_parameter(param, request) - if error: - raise BadRequestProblem(detail=error) - - return function(request) - - return wrapper diff --git a/connexion/exceptions.py b/connexion/exceptions.py index 8ab7d0bda..3f4e77dac 100644 --- a/connexion/exceptions.py +++ b/connexion/exceptions.py @@ -206,3 +206,22 @@ def __init__( ) super().__init__(title=title, detail=detail, **kwargs) + + +class TypeValidationError(Exception): + def __init__(self, schema_type, parameter_type, parameter_name): + """ + Exception raise when type validation fails + + :type schema_type: str + :type parameter_type: str + :type parameter_name: str + :return: + """ + self.schema_type = schema_type + self.parameter_type = parameter_type + self.parameter_name = parameter_name + + def __str__(self): + msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'" + return msg.format(**vars(self)) diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index 26ec72400..d42585cd3 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -8,7 +8,6 @@ from connexion import utils from connexion.datastructures import MediaTypeDict -from connexion.decorators.uri_parsing import AbstractURIParser from connexion.exceptions import UnsupportedMediaTypeProblem from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation @@ -25,14 +24,12 @@ def __init__( operation: AbstractOperation, strict_validation: bool = False, validator_map: t.Optional[dict] = None, - uri_parser_class: t.Optional[AbstractURIParser] = None, ) -> None: self.next_app = next_app self._operation = operation self.strict_validation = strict_validation self._validator_map = VALIDATOR_MAP.copy() self._validator_map.update(validator_map or {}) - self.uri_parser_class = uri_parser_class def extract_content_type( self, headers: t.List[t.Tuple[bytes, bytes]] @@ -73,12 +70,24 @@ def validate_mime_type(self, mime_type: str) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send): receive_fn = receive + # Validate parameters & headers + uri_parser_class = self._operation._uri_parser_class + uri_parser = uri_parser_class( + self._operation.parameters, self._operation.body_definition() + ) + parameter_validator_cls = self._validator_map["parameter"] + parameter_validator = parameter_validator_cls( # type: ignore + self._operation.parameters, + uri_parser=uri_parser, + strict_validation=self.strict_validation, + ) + parameter_validator.validate(scope) + + # Extract content type headers = scope["headers"] mime_type, encoding = self.extract_content_type(headers) self.validate_mime_type(mime_type) - # TODO: Validate parameters - # Validate body schema = self._operation.body_schema(mime_type) if schema: @@ -137,7 +146,6 @@ def make_operation( operation=operation, strict_validation=self.strict_validation, validator_map=self.validator_map, - uri_parser_class=self.uri_parser_class, ) diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index 905ecb315..94197127d 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -25,6 +25,10 @@ def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp): async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """Attach operation to scope and pass it to the next app""" original_scope = _scope.get() + # Pass resolved path params along + original_scope.setdefault("path_params", {}).update( + scope.get("path_params", {}) + ) api_base_path = scope.get("root_path", "")[ len(original_scope.get("root_path", "")) : diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index fb40123bb..fa0355f5b 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -9,17 +9,12 @@ from ..decorators.decorator import RequestResponseDecorator from ..decorators.parameter import parameter_to_arg from ..decorators.produces import BaseSerializer, Produces -from ..decorators.validation import ParameterValidator from ..utils import all_json logger = logging.getLogger("connexion.operations.abstract") DEFAULT_MIMETYPE = "application/json" -VALIDATOR_MAP = { - "parameter": ParameterValidator, -} - class AbstractOperation(metaclass=abc.ABCMeta): @@ -52,7 +47,6 @@ def __init__( validate_responses=False, strict_validation=False, randomize_endpoint=None, - validator_map=None, pythonic_params=False, uri_parser_class=None, ): @@ -77,8 +71,6 @@ def __init__( :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool @@ -104,9 +96,6 @@ def __init__( self._responses = self._operation.get("responses", {}) - self._validator_map = dict(VALIDATOR_MAP) - self._validator_map.update(validator_map or {}) - @property def api(self): return self._api @@ -140,13 +129,6 @@ def responses(self): """ return self._responses - @property - def validator_map(self): - """ - Validators to use for parameter, body, and response validation - """ - return self._validator_map - @property def operation_id(self): """ @@ -388,9 +370,6 @@ def function(self): logger.debug("... Adding produces decorator (%r)", produces_decorator) function = produces_decorator(function) - for validation_decorator in self.__validation_decorators: - function = validation_decorator(function) - uri_parsing_decorator = self._uri_parsing_decorator function = uri_parsing_decorator(function) @@ -442,17 +421,6 @@ def __content_type_decorator(self): else: return BaseSerializer() - @property - def __validation_decorators(self): - """ - :rtype: types.FunctionType - """ - ParameterValidator = self.validator_map["parameter"] - if self.parameters: - yield ParameterValidator( - self.parameters, self.api, strict_validation=self.strict_validation - ) - def json_loads(self, data): """ A wrapper for calling the API specific JSON loader. diff --git a/connexion/operations/openapi.py b/connexion/operations/openapi.py index fcd96a7e3..7f301b827 100644 --- a/connexion/operations/openapi.py +++ b/connexion/operations/openapi.py @@ -35,7 +35,6 @@ def __init__( validate_responses=False, strict_validation=False, randomize_endpoint=None, - validator_map=None, pythonic_params=False, uri_parser_class=None, ): @@ -72,8 +71,6 @@ def __init__( :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool @@ -97,7 +94,6 @@ def __init__( validate_responses=validate_responses, strict_validation=strict_validation, randomize_endpoint=randomize_endpoint, - validator_map=validator_map, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, ) diff --git a/connexion/operations/swagger2.py b/connexion/operations/swagger2.py index 1124d4ba0..f1211ea4d 100644 --- a/connexion/operations/swagger2.py +++ b/connexion/operations/swagger2.py @@ -52,7 +52,6 @@ def __init__( validate_responses=False, strict_validation=False, randomize_endpoint=None, - validator_map=None, pythonic_params=False, uri_parser_class=None, ): @@ -87,8 +86,6 @@ def __init__( :type strict_validation: bool :param randomize_endpoint: number of random characters to append to operation name :type randomize_endpoint: integer - :param validator_map: Custom validators for the types "parameter", "body" and "response". - :type validator_map: dict :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool @@ -110,7 +107,6 @@ def __init__( validate_responses=validate_responses, strict_validation=strict_validation, randomize_endpoint=randomize_endpoint, - validator_map=validator_map, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, ) diff --git a/connexion/utils.py b/connexion/utils.py index cc6cdd632..e37dba768 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -9,6 +9,8 @@ import yaml +from connexion.exceptions import TypeValidationError + def boolean(s): """ @@ -296,3 +298,53 @@ def extract_content_type( mime_type = content_type break return mime_type, encoding + + +def coerce_type(param, value, parameter_type, parameter_name=None): + # TODO: clean up + TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict} + + def make_type(value, type_literal): + type_func = TYPE_MAP.get(type_literal) + return type_func(value) + + param_schema = param.get("schema", param) + if is_nullable(param_schema) and is_null(value): + return None + + param_type = param_schema.get("type") + parameter_name = parameter_name if parameter_name else param.get("name") + if param_type == "array": + converted_params = [] + if parameter_type == "header": + value = value.split(",") + for v in value: + try: + converted = make_type(v, param_schema["items"]["type"]) + except (ValueError, TypeError): + converted = v + converted_params.append(converted) + return converted_params + elif param_type == "object": + if param_schema.get("properties"): + + def cast_leaves(d, schema): + if type(d) is not dict: + try: + return make_type(d, schema["type"]) + except (ValueError, TypeError): + return d + for k, v in d.items(): + if k in schema["properties"]: + d[k] = cast_leaves(v, schema["properties"][k]) + return d + + return cast_leaves(value, param_schema) + return value + else: + try: + return make_type(value, param_type) + except ValueError: + raise TypeValidationError(param_type, parameter_type, parameter_name) + except TypeError: + return value diff --git a/connexion/validators/__init__.py b/connexion/validators/__init__.py index 1c608917b..fa1840bcc 100644 --- a/connexion/validators/__init__.py +++ b/connexion/validators/__init__.py @@ -1,5 +1,4 @@ from connexion.datastructures import MediaTypeDict -from connexion.decorators.validation import ParameterValidator from .form_data import FormDataValidator, MultiPartFormDataValidator from .json import ( @@ -7,6 +6,7 @@ JSONResponseBodyValidator, TextResponseBodyValidator, ) +from .parameter import ParameterValidator VALIDATOR_MAP = { "parameter": ParameterValidator, diff --git a/connexion/validators/form_data.py b/connexion/validators/form_data.py index dd825ae30..23a3d1121 100644 --- a/connexion/validators/form_data.py +++ b/connexion/validators/form_data.py @@ -7,10 +7,13 @@ from starlette.types import Receive, Scope from connexion.decorators.uri_parsing import AbstractURIParser -from connexion.decorators.validation import TypeValidationError, coerce_type -from connexion.exceptions import BadRequestProblem, ExtraParameterProblem +from connexion.exceptions import ( + BadRequestProblem, + ExtraParameterProblem, + TypeValidationError, +) from connexion.json_schema import Draft4RequestValidator -from connexion.utils import is_null +from connexion.utils import coerce_type, is_null logger = logging.getLogger("connexion.validators.form_data") diff --git a/connexion/validators/parameter.py b/connexion/validators/parameter.py new file mode 100644 index 000000000..4c02f6249 --- /dev/null +++ b/connexion/validators/parameter.py @@ -0,0 +1,148 @@ +import collections +import copy +import logging + +from jsonschema import Draft4Validator, ValidationError +from starlette.requests import Request + +from connexion.exceptions import ( + BadRequestProblem, + ExtraParameterProblem, + TypeValidationError, +) +from connexion.utils import boolean, coerce_type, is_null, is_nullable + +logger = logging.getLogger("connexion.validators.parameter") + +TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict} + +try: + draft4_format_checker = Draft4Validator.FORMAT_CHECKER # type: ignore +except AttributeError: # jsonschema < 4.5.0 + from jsonschema import draft4_format_checker + + +class ParameterValidator: + def __init__(self, parameters, uri_parser, strict_validation=False): + """ + :param parameters: List of request parameter dictionaries + :param uri_parser: class to use for uri parsing + :param strict_validation: Flag indicating if parameters not in spec are allowed + """ + self.parameters = collections.defaultdict(list) + for p in parameters: + self.parameters[p["in"]].append(p) + + self.uri_parser = uri_parser + self.strict_validation = strict_validation + + @staticmethod + def validate_parameter(parameter_type, value, param, param_name=None): + if value is not None: + if is_nullable(param) and is_null(value): + return + + try: + converted_value = coerce_type(param, value, parameter_type, param_name) + except TypeValidationError as e: + return str(e) + + param = copy.deepcopy(param) + param = param.get("schema", param) + if "required" in param: + del param["required"] + try: + Draft4Validator(param, format_checker=draft4_format_checker).validate( + converted_value + ) + except ValidationError as exception: + debug_msg = ( + "Error while converting value {converted_value} from param " + "{type_converted_value} of type real type {param_type} to the declared type {param}" + ) + fmt_params = dict( + converted_value=str(converted_value), + type_converted_value=type(converted_value), + param_type=param.get("type"), + param=param, + ) + logger.info(debug_msg.format(**fmt_params)) + return str(exception) + + elif param.get("required"): + return "Missing {parameter_type} parameter '{param[name]}'".format( + **locals() + ) + + @staticmethod + def validate_parameter_list(request_params, spec_params): + request_params = set(request_params) + spec_params = set(spec_params) + + return request_params.difference(spec_params) + + def validate_query_parameter_list(self, request): + request_params = request.query_params.keys() + spec_params = [x["name"] for x in self.parameters.get("query", [])] + return self.validate_parameter_list(request_params, spec_params) + + def validate_query_parameter(self, param, request): + """ + Validate a single query parameter (request.args in Flask) + + :type param: dict + :rtype: str + """ + # Convert to dict of lists + query_params = { + k: request.query_params.getlist(k) for k in request.query_params + } + query_params = self.uri_parser.resolve_query(query_params) + val = query_params.get(param["name"]) + return self.validate_parameter("query", val, param) + + def validate_path_parameter(self, param, request): + val = request.path_params.get(param["name"].replace("-", "_")) + return self.validate_parameter("path", val, param) + + def validate_header_parameter(self, param, request): + val = request.headers.get(param["name"]) + return self.validate_parameter("header", val, param) + + def validate_cookie_parameter(self, param, request): + val = request.cookies.get(param["name"]) + return self.validate_parameter("cookie", val, param) + + def validate(self, scope): + logger.debug("%s validating parameters...", scope.get("path")) + + request = Request(scope) + self.validate_request(request) + + def validate_request(self, request): + + if self.strict_validation: + query_errors = self.validate_query_parameter_list(request) + + if query_errors: + raise ExtraParameterProblem([], query_errors) + + for param in self.parameters.get("query", []): + error = self.validate_query_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) + + for param in self.parameters.get("path", []): + error = self.validate_path_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) + + for param in self.parameters.get("header", []): + error = self.validate_header_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) + + for param in self.parameters.get("cookie", []): + error = self.validate_cookie_parameter(param, request) + if error: + raise BadRequestProblem(detail=error) diff --git a/tests/api/test_parameters.py b/tests/api/test_parameters.py index 62a4c94a1..2219dedcb 100644 --- a/tests/api/test_parameters.py +++ b/tests/api/test_parameters.py @@ -172,7 +172,7 @@ def test_path_parameter_someint__bad(simple_app): # non-integer values will not match Flask route app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-int-path/foo") # type: flask.Response - assert resp.status_code == 404 + assert resp.status_code == 400, resp.text @pytest.mark.parametrize( @@ -205,7 +205,7 @@ def test_path_parameter_somefloat__bad(simple_app): # non-float values will not match Flask route app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-float-path/123,45") # type: flask.Response - assert resp.status_code == 404 + assert resp.status_code == 400, resp.text def test_default_param(strict_app): @@ -348,19 +348,19 @@ def test_bool_param(simple_app): def test_bool_array_param(simple_app): app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-bool-array-param?thruthiness=true,true,true") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text response = json.loads(resp.data.decode("utf-8", "replace")) assert response is True app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-bool-array-param?thruthiness=true,true,false") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text response = json.loads(resp.data.decode("utf-8", "replace")) assert response is False app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-bool-array-param") - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text def test_required_param_miss_config(simple_app): @@ -557,7 +557,7 @@ def test_parameters_snake_case(snake_case_app): resp = app_client.get( "/v1.0/test-get-camel-case-version?truthiness=true&orderBy=asc" ) - assert resp.status_code == 200 + assert resp.status_code == 200, resp.text assert resp.get_json() == {"truthiness": True, "order_by": "asc"} resp = app_client.get("/v1.0/test-get-camel-case-version?truthiness=5") assert resp.status_code == 400 diff --git a/tests/api/test_responses.py b/tests/api/test_responses.py index c017c08e7..127245c6e 100644 --- a/tests/api/test_responses.py +++ b/tests/api/test_responses.py @@ -171,7 +171,7 @@ def test_exploded_deep_object_param_endpoint_openapi_multiple_data_types( response = app_client.get( "/v1.0/exploded-deep-object-param?id[foo]=bar&id[fooint]=2&id[fooboo]=false" ) # type: flask.Response - assert response.status_code == 200 + assert response.status_code == 200, response.text response_data = json.loads(response.data.decode("utf-8", "replace")) assert response_data == { "foo": "bar", diff --git a/tests/decorators/test_validation.py b/tests/decorators/test_validation.py index d8cb5a162..4f9c0aa06 100644 --- a/tests/decorators/test_validation.py +++ b/tests/decorators/test_validation.py @@ -1,9 +1,8 @@ from unittest.mock import MagicMock import pytest -from connexion.apis.flask_api import FlaskApi -from connexion.decorators.validation import ParameterValidator from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator +from connexion.validators.parameter import ParameterValidator from jsonschema import ValidationError @@ -75,7 +74,7 @@ def test_get_valid_parameter_with_enum_array_header(): def test_invalid_type(monkeypatch): logger = MagicMock() - monkeypatch.setattr("connexion.decorators.validation.logger", logger) + monkeypatch.setattr("connexion.validators.parameter.logger", logger) result = ParameterValidator.validate_parameter( "formdata", 20, {"type": "string", "name": "foo"} ) @@ -91,8 +90,6 @@ def test_invalid_type(monkeypatch): def test_invalid_type_value_error(monkeypatch): - logger = MagicMock() - monkeypatch.setattr("connexion.decorators.validation.logger", logger) value = {"test": 1, "second": 2} result = ParameterValidator.validate_parameter( "formdata", value, {"type": "boolean", "name": "foo"} @@ -101,8 +98,6 @@ def test_invalid_type_value_error(monkeypatch): def test_enum_error(monkeypatch): - logger = MagicMock() - monkeypatch.setattr("connexion.decorators.validation.logger", logger) value = "INVALID" param = {"schema": {"type": "string", "enum": ["valid"]}, "name": "test_path_param"} result = ParameterValidator.validate_parameter("path", value, param) diff --git a/tests/test_validation.py b/tests/test_validation.py index d6e64670e..eb0a60664 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,27 +1,14 @@ from unittest.mock import MagicMock +from urllib.parse import quote_plus -import flask import pytest -from connexion.apis.flask_api import FlaskApi -from connexion.decorators.validation import ParameterValidator +from connexion.decorators.uri_parsing import Swagger2URIParser from connexion.exceptions import BadRequestProblem +from connexion.validators.parameter import ParameterValidator +from starlette.datastructures import QueryParams def test_parameter_validator(monkeypatch): - request = MagicMock(name="request") - request.args = {} - request.headers = {} - request.cookies = {} - request.params = {} - app = MagicMock(name="app") - - app.response_class = flask.Response - monkeypatch.setattr("flask.request", request) - monkeypatch.setattr("flask.current_app", app) - - def orig_handler(*args, **kwargs): - return "OK" - params = [ {"name": "p1", "in": "path", "type": "integer", "required": True}, {"name": "h1", "in": "header", "type": "string", "enum": ["a", "b"]}, @@ -36,78 +23,120 @@ def orig_handler(*args, **kwargs): "items": {"type": "integer", "minimum": 0}, }, ] - validator = ParameterValidator(params, FlaskApi) - handler = validator(orig_handler) - kwargs = {"query": {}, "headers": {}, "cookies": {}} + uri_parser = Swagger2URIParser(params, {}) + validator = ParameterValidator(params, uri_parser=uri_parser) + + kwargs = {"query_params": {}, "headers": {}, "cookies": {}} request = MagicMock(path_params={}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Missing path parameter 'p1'" + request = MagicMock(path_params={"p1": "123"}, **kwargs) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) + request = MagicMock(path_params={"p1": ""}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'" + request = MagicMock(path_params={"p1": "foo"}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'" + request = MagicMock(path_params={"p1": "1.2"}, **kwargs) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'" - request = MagicMock(path_params={"p1": 1}, query={"q1": "4"}, headers={}) + request = MagicMock( + path_params={"p1": 1}, query_params=QueryParams("q1=4"), headers={}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("4 is greater than the maximum of 3") + request = MagicMock( - path_params={"p1": 1}, query={"q1": "3"}, headers={}, cookies={} + path_params={"p1": 1}, query_params=QueryParams("q1=3"), headers={}, cookies={} ) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) + query_params = QueryParams(f"a1={quote_plus('1,2')}") request = MagicMock( - path_params={"p1": 1}, query={"a1": ["1", "2"]}, headers={}, cookies={} + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} + ) + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) + + query_params = QueryParams(f"a1={quote_plus('1,a')}") + request = MagicMock( + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} ) - assert handler(request) == "OK" - request = MagicMock(path_params={"p1": 1}, query={"a1": ["1", "a"]}, headers={}) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("'a' is not of type 'integer'") + request = MagicMock( - path_params={"p1": "123"}, query={}, headers={}, cookies={"c1": "b"} + path_params={"p1": "123"}, query_params={}, headers={}, cookies={"c1": "b"} ) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) request = MagicMock( path_params={"p1": "123"}, query={}, headers={}, cookies={"c1": "x"} ) with pytest.raises(BadRequestProblem) as exc: - assert handler(request) + assert validator.validate_request(request) + assert exc.value.detail.startswith("'x' is not one of ['a', 'b']") - request = MagicMock(path_params={"p1": 1}, query={"a1": ["1", "-1"]}, headers={}) + + query_params = QueryParams(f"a1={quote_plus('1,-1')}") + request = MagicMock( + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("-1 is less than the minimum of 0") - request = MagicMock(path_params={"p1": 1}, query={"a1": ["1"]}, headers={}) + + query_params = QueryParams("a1=1") + request = MagicMock( + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("[1] is too short") + + query_params = QueryParams(f"a1={quote_plus('1,2,3,4')}") request = MagicMock( - path_params={"p1": 1}, query={"a1": ["1", "2", "3", "4"]}, headers={} + path_params={"p1": 1}, query_params=query_params, headers={}, cookies={} ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("[1, 2, 3, 4] is too long") request = MagicMock( - path_params={"p1": "123"}, query={}, headers={"h1": "a"}, cookies={} + path_params={"p1": "123"}, query_params={}, headers={"h1": "a"}, cookies={} ) - assert handler(request) == "OK" + try: + validator.validate_request(request) + except Exception as e: + pytest.fail(str(e)) - request = MagicMock(path_params={"p1": "123"}, query={}, headers={"h1": "x"}) + request = MagicMock( + path_params={"p1": "123"}, query_params={}, headers={"h1": "x"}, cookies={} + ) with pytest.raises(BadRequestProblem) as exc: - handler(request) + validator.validate_request(request) assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")