Skip to content

Commit

Permalink
Move parameter validation to middleware (#1610)
Browse files Browse the repository at this point in the history
Fixes #1525

This PR follows up on #1588 and #1591 and moves the last part of
validation, the parameter validation, to the middleware,
  • Loading branch information
RobbeSneyders authored Dec 23, 2022
2 parents a829536 + 825c682 commit 1af4733
Show file tree
Hide file tree
Showing 21 changed files with 667 additions and 709 deletions.
18 changes: 0 additions & 18 deletions connexion/apis/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,34 +181,19 @@ def __init__(
specification,
base_path=None,
arguments=None,
validate_responses=False,
strict_validation=False,
resolver=None,
debug=False,
resolver_error_handler=None,
validator_map=None,
pythonic_params=False,
options=None,
**kwargs,
):
"""
:type validate_responses: bool
:type strict_validation: bool
:param validator_map: Custom validators for the types "parameter", "body" and "response".
:type validator_map: dict
:type resolver_error_handler: callable | None
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended
to any shadowed built-ins
:type pythonic_params: bool
"""
self.validator_map = validator_map

logger.debug("Validate Responses: %s", str(validate_responses))
self.validate_responses = validate_responses

logger.debug("Strict Request Validation: %s", str(strict_validation))
self.strict_validation = strict_validation

logger.debug("Pythonic params: %s", str(pythonic_params))
self.pythonic_params = pythonic_params

Expand Down Expand Up @@ -244,9 +229,6 @@ def add_operation(self, path, method):
path,
method,
self.resolver,
validate_responses=self.validate_responses,
validator_map=self.validator_map,
strict_validation=self.strict_validation,
pythonic_params=self.pythonic_params,
uri_parser_class=self.options.uri_parser_class,
)
Expand Down
1 change: 0 additions & 1 deletion connexion/apps/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def add_api(
strict_validation=strict_validation,
auth_all_paths=auth_all_paths,
debug=self.debug,
validator_map=validator_map,
pythonic_params=pythonic_params,
options=api_options.as_dict(),
)
Expand Down
20 changes: 14 additions & 6 deletions connexion/decorators/uri_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import logging
import re

from .. import utils
from .decorator import BaseDecorator
from connexion.decorators.decorator import BaseDecorator
from connexion.exceptions import TypeValidationError
from connexion.utils import all_json, coerce_type, deep_merge, is_null, is_nullable

logger = logging.getLogger("connexion.decorators.uri_parsing")

Expand Down Expand Up @@ -119,6 +120,15 @@ def resolve_params(self, params, _in):
else:
resolved_param[k] = values[-1]

if not (is_nullable(param_defn) and is_null(resolved_param[k])):
try:
# TODO: coerce types in a single place
resolved_param[k] = coerce_type(
param_defn, resolved_param[k], "parameter", k
)
except TypeValidationError:
pass

return resolved_param

def __call__(self, function):
Expand Down Expand Up @@ -182,9 +192,7 @@ def resolve_form(self, form_data):
)
if defn and defn["type"] == "array":
form_data[k] = self._split(form_data[k], encoding, "form")
elif "contentType" in encoding and utils.all_json(
[encoding.get("contentType")]
):
elif "contentType" in encoding and all_json([encoding.get("contentType")]):
form_data[k] = json.loads(form_data[k])
return form_data

Expand Down Expand Up @@ -231,7 +239,7 @@ def _preprocess_deep_objects(self, query_data):
ret = dict.fromkeys(root_keys, [{}])
for k, v, is_deep_object in deep:
if is_deep_object:
ret[k] = [utils.deep_merge(v[0], ret[k][0])]
ret[k] = [deep_merge(v[0], ret[k][0])]
else:
ret[k] = v
return ret
Expand Down
215 changes: 0 additions & 215 deletions connexion/decorators/validation.py

This file was deleted.

19 changes: 19 additions & 0 deletions connexion/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,22 @@ def __init__(
)

super().__init__(title=title, detail=detail, **kwargs)


class TypeValidationError(Exception):
def __init__(self, schema_type, parameter_type, parameter_name):
"""
Exception raise when type validation fails
:type schema_type: str
:type parameter_type: str
:type parameter_name: str
:return:
"""
self.schema_type = schema_type
self.parameter_type = parameter_type
self.parameter_name = parameter_name

def __str__(self):
msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'"
return msg.format(**vars(self))
20 changes: 14 additions & 6 deletions connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from connexion import utils
from connexion.datastructures import MediaTypeDict
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import UnsupportedMediaTypeProblem
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation
Expand All @@ -25,14 +24,12 @@ def __init__(
operation: AbstractOperation,
strict_validation: bool = False,
validator_map: t.Optional[dict] = None,
uri_parser_class: t.Optional[AbstractURIParser] = None,
) -> None:
self.next_app = next_app
self._operation = operation
self.strict_validation = strict_validation
self._validator_map = VALIDATOR_MAP.copy()
self._validator_map.update(validator_map or {})
self.uri_parser_class = uri_parser_class

def extract_content_type(
self, headers: t.List[t.Tuple[bytes, bytes]]
Expand Down Expand Up @@ -73,12 +70,24 @@ def validate_mime_type(self, mime_type: str) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send):
receive_fn = receive

# Validate parameters & headers
uri_parser_class = self._operation._uri_parser_class
uri_parser = uri_parser_class(
self._operation.parameters, self._operation.body_definition()
)
parameter_validator_cls = self._validator_map["parameter"]
parameter_validator = parameter_validator_cls( # type: ignore
self._operation.parameters,
uri_parser=uri_parser,
strict_validation=self.strict_validation,
)
parameter_validator.validate(scope)

# Extract content type
headers = scope["headers"]
mime_type, encoding = self.extract_content_type(headers)
self.validate_mime_type(mime_type)

# TODO: Validate parameters

# Validate body
schema = self._operation.body_schema(mime_type)
if schema:
Expand Down Expand Up @@ -137,7 +146,6 @@ def make_operation(
operation=operation,
strict_validation=self.strict_validation,
validator_map=self.validator_map,
uri_parser_class=self.uri_parser_class,
)


Expand Down
Loading

0 comments on commit 1af4733

Please sign in to comment.