diff --git a/dev-requirements-py3.txt b/dev-requirements-py3.txt index 2eaa657d..80c8efd4 100644 --- a/dev-requirements-py3.txt +++ b/dev-requirements-py3.txt @@ -5,4 +5,6 @@ flake8-bugbear==18.8.0 # aiohttp aiohttp>=3.0.0 webtest-aiohttp==2.0.0 +sanic>=0.8.0 +webtest-sanic>=0.1.5 pytest-aiohttp>=0.3.0 diff --git a/tests/apps/sanic_app.py b/tests/apps/sanic_app.py new file mode 100644 index 00000000..dc0becf4 --- /dev/null +++ b/tests/apps/sanic_app.py @@ -0,0 +1,197 @@ +from sanic import Sanic +from sanic.response import json as J +from sanic.views import HTTPMethodView + +import marshmallow as ma +from webargs import fields, ValidationError, missing +from webargs.sanicparser import parser, use_args, use_kwargs, HandleValidationError +from webargs.core import MARSHMALLOW_VERSION_INFO + + +class TestAppConfig: + TESTING = True + + +hello_args = {"name": fields.Str(missing="World", validate=lambda n: len(n) >= 3)} +hello_multiple = {"name": fields.List(fields.Str())} + + +class HelloSchema(ma.Schema): + name = fields.Str(missing="World", validate=lambda n: len(n) >= 3) + + +strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {} +hello_many_schema = HelloSchema(many=True, **strict_kwargs) + +app = Sanic(__name__) +app.config.from_object(TestAppConfig) + + +@app.route("/echo", methods=["GET", "POST"]) +async def echo(request): + parsed = await parser.parse(hello_args, request) + return J(parsed) + + +@app.route("/echo_query") +async def echo_query(request): + parsed = await parser.parse(hello_args, request, locations=("query",)) + return J(parsed) + + +@app.route("/echo_use_args", methods=["GET", "POST"]) +@use_args(hello_args) +async def echo_use_args(request, args): + return J(args) + + +@app.route("/echo_use_args_validated", methods=["GET", "POST"]) +@use_args( + {"value": fields.Int(required=True)}, validate=lambda args: args["value"] > 42 +) +async def echo_use_args_validated(request, args): + return J(args) + + +@app.route("/echo_use_kwargs", methods=["GET", "POST"]) +@use_kwargs(hello_args) +async def echo_use_kwargs(request, name): + return J({"name": name}) + + +@app.route("/echo_multi", methods=["GET", "POST"]) +async def multi(request): + parsed = await parser.parse(hello_multiple, request) + return J(parsed) + + +@app.route("/echo_many_schema", methods=["GET", "POST"]) +async def many_nested(request): + parsed = await parser.parse(hello_many_schema, request, locations=("json",)) + return J(parsed, content_type="application/json") + + +@app.route("/echo_use_args_with_path_param/") +@use_args({"value": fields.Int()}) +async def echo_use_args_with_path(request, args, name): + return J(args) + + +@app.route("/echo_use_kwargs_with_path_param/") +@use_kwargs({"value": fields.Int()}) +async def echo_use_kwargs_with_path(request, name, value): + return J({"value": value}) + + +@app.route("/error", methods=["GET", "POST"]) +async def error(request): + def always_fail(value): + raise ValidationError("something went wrong") + + args = {"text": fields.Str(validate=always_fail)} + parsed = await parser.parse(args, request) + return J(parsed) + + +@app.route("/error400", methods=["GET", "POST"]) +async def error400(request): + def always_fail(value): + raise ValidationError("something went wrong", status_code=400) + + args = {"text": fields.Str(validate=always_fail)} + parsed = await parser.parse(args, request) + + return J(parsed) + + +@app.route("/echo_headers") +async def echo_headers(request): + parsed = await parser.parse(hello_args, request, locations=("headers",)) + return J(parsed) + + +@app.route("/echo_cookie") +async def echo_cookie(request): + parsed = await parser.parse(hello_args, request, locations=("cookies",)) + return J(parsed) + + +@app.route("/echo_file", methods=["POST"]) +async def echo_file(request): + args = {"myfile": fields.Field()} + result = await parser.parse(args, request, locations=("files",)) + fp = result["myfile"] + content = fp.body.decode("utf8") + return J({"myfile": content}) + + +@app.route("/echo_view_arg/") +async def echo_view_arg(request, view_arg): + parsed = await parser.parse( + {"view_arg": fields.Int()}, request, locations=("view_args",) + ) + return J(parsed) + + +@app.route("/echo_view_arg_use_args/") +@use_args({"view_arg": fields.Int(location="view_args")}) +async def echo_view_arg_with_use_args(request, args, **kwargs): + return J(args) + + +@app.route("/echo_nested", methods=["POST"]) +async def echo_nested(request): + args = {"name": fields.Nested({"first": fields.Str(), "last": fields.Str()})} + parsed = await parser.parse(args, request) + return J(parsed) + + +@app.route("/echo_nested_many", methods=["POST"]) +async def echo_nested_many(request): + args = { + "users": fields.Nested({"id": fields.Int(), "name": fields.Str()}, many=True) + } + parsed = await parser.parse(args, request) + return J(parsed) + + +@app.route("/echo_nested_many_data_key", methods=["POST"]) +async def echo_nested_many_with_data_key(request): + data_key_kwarg = { + "load_from" if (MARSHMALLOW_VERSION_INFO[0] < 3) else "data_key": "X-Field" + } + args = {"x_field": fields.Nested({"id": fields.Int()}, many=True, **data_key_kwarg)} + parsed = await parser.parse(args, request) + return J(parsed) + + +class EchoMethodViewUseArgs(HTTPMethodView): + @use_args({"val": fields.Int()}) + async def post(self, request, args): + return J(args) + + +app.add_route(EchoMethodViewUseArgs.as_view(), "/echo_method_view_use_args") + + +class EchoMethodViewUseKwargs(HTTPMethodView): + @use_kwargs({"val": fields.Int()}) + async def post(self, request, val): + return J({"val": val}) + + +app.add_route(EchoMethodViewUseKwargs.as_view(), "/echo_method_view_use_kwargs") + + +@app.route("/echo_use_kwargs_missing", methods=["POST"]) +@use_kwargs({"username": fields.Str(), "password": fields.Str()}) +async def echo_use_kwargs_missing(request, username, password): + assert password is missing + return J({"username": username}) + + +# Return validation errors as JSON +@app.exception(HandleValidationError) +async def handle_validation_error(request, err): + assert isinstance(err.data["schema"], ma.Schema) + return J({"errors": err.exc.messages}, status=422) diff --git a/tests/test_sanicparser.py b/tests/test_sanicparser.py new file mode 100644 index 00000000..9226539d --- /dev/null +++ b/tests/test_sanicparser.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals +import json +import mock + +from sanic.exceptions import SanicException + +from webargs import fields, ValidationError, missing +from webargs.sanicparser import parser, abort +from webargs.core import MARSHMALLOW_VERSION_INFO + +from .apps.sanic_app import app +from .common import CommonTestCase +from webtest_sanic import TestApp +import asyncio +import pytest +import io + + +class TestSanicParser(CommonTestCase): + def create_app(self): + return app + + # testing of file uplaods is made through sanic.test_client + @pytest.mark.skip(reason="files location not supported for aiohttpparser") + def test_parse_files(self, testapp): + pass + + def create_testapp(self, app): + loop = asyncio.new_event_loop() + self.loop = loop + return TestApp(app, loop=self.loop) + + def after_create_app(self): + self.loop.close() + + def test_parsing_view_args(self, testapp): + res = testapp.get("/echo_view_arg/42") + assert res.json == {"view_arg": 42} + + def test_parsing_invalid_view_arg(self, testapp): + res = testapp.get("/echo_view_arg/foo", expect_errors=True) + assert res.status_code == 422 + assert res.json == {"errors": {"view_arg": ["Not a valid integer."]}} + + def test_use_args_with_view_args_parsing(self, testapp): + res = testapp.get("/echo_view_arg_use_args/42") + assert res.json == {"view_arg": 42} + + def test_use_args_on_a_method_view(self, testapp): + res = testapp.post("/echo_method_view_use_args", {"val": 42}) + assert res.json == {"val": 42} + + def test_use_kwargs_on_a_method_view(self, testapp): + res = testapp.post("/echo_method_view_use_kwargs", {"val": 42}) + assert res.json == {"val": 42} + + def test_use_kwargs_with_missing_data(self, testapp): + res = testapp.post("/echo_use_kwargs_missing", {"username": "foo"}) + assert res.json == {"username": "foo"} + + # regression test for https://github.com/sloria/webargs/issues/145 + def test_nested_many_with_data_key(self, testapp): + res = testapp.post_json("/echo_nested_many_data_key", {"x_field": [{"id": 42}]}) + # https://github.com/marshmallow-code/marshmallow/pull/714 + if MARSHMALLOW_VERSION_INFO[0] < 3: + assert res.json == {"x_field": [{"id": 42}]} + + res = testapp.post_json("/echo_nested_many_data_key", {"X-Field": [{"id": 24}]}) + assert res.json == {"x_field": [{"id": 24}]} + + res = testapp.post_json("/echo_nested_many_data_key", {}) + assert res.json == {} + + +@mock.patch("webargs.sanicparser.abort") +def test_abort_called_on_validation_error(mock_abort, loop): + app.test_client.get( + "/echo_use_args_validated", + params={"value": 41}, + headers={"content_type": "application/json"}, + ) + + mock_abort.assert_called + abort_args, abort_kwargs = mock_abort.call_args + assert abort_args[0] == 422 + expected_msg = "Invalid value." + assert abort_kwargs["messages"] == [expected_msg] + assert type(abort_kwargs["exc"]) == ValidationError + + +def test_parse_files(loop): + res = app.test_client.post( + "/echo_file", data={"myfile": io.BytesIO(b"data")}, gather_request=False + ) + assert res.json == {"myfile": "data"} + + +def test_parse_form_returns_missing_if_no_form(): + req = mock.Mock() + req.form.get.side_effect = AttributeError("no form") + assert parser.parse_form(req, "foo", fields.Field()) is missing + + +def test_abort_with_message(): + with pytest.raises(SanicException) as excinfo: + abort(400, message="custom error message") + assert excinfo.value.data["message"] == "custom error message" + + +def test_abort_has_serializable_data(): + with pytest.raises(SanicException) as excinfo: + abort(400, message="custom error message") + serialized_error = json.dumps(excinfo.value.data) + error = json.loads(serialized_error) + assert isinstance(error, dict) + assert error["message"] == "custom error message" + + with pytest.raises(SanicException) as excinfo: + abort( + 400, + message="custom error message", + exc=ValidationError("custom error message"), + ) + serialized_error = json.dumps(excinfo.value.data) + error = json.loads(serialized_error) + assert isinstance(error, dict) + assert error["message"] == "custom error message" diff --git a/webargs/sanicparser.py b/webargs/sanicparser.py new file mode 100644 index 00000000..017d29cf --- /dev/null +++ b/webargs/sanicparser.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +"""Sanic request argument parsing module. + +Example: :: + + from sanic import Sanic + + from webargs import fields + from webargs.sanicparser import use_args + + app = Sanic(__name__) + + hello_args = { + 'name': fields.Str(required=True) + } + + @app.route('/') + @use_args(hello_args) + async def index(args): + return 'Hello ' + args['name'] +""" +import sanic + +from webargs import core +from webargs.asyncparser import AsyncParser + + +@sanic.exceptions.add_status_code(422) +class HandleValidationError(sanic.exceptions.SanicException): + pass + + +def abort(http_status_code, exc=None, **kwargs): + """Raise a HTTPException for the given http_status_code. Attach any keyword + arguments to the exception for later processing. + + From Flask-Restful. See NOTICE file for license information. + """ + try: + sanic.exceptions.abort(http_status_code, exc) + except sanic.exceptions.SanicException as err: + err.data = kwargs + err.exc = exc + raise err + + +def is_json_request(req): + content_type = req.content_type + return core.is_json(content_type) + + +class SanicParser(AsyncParser): + """Sanic request argument parser.""" + + __location_map__ = dict(view_args="parse_view_args", **core.Parser.__location_map__) + + def parse_view_args(self, req, name, field): + """Pull a value from the request's ``view_args``.""" + return core.get_value(req.match_info, name, field) + + def get_request_from_view_args(self, view, args, kwargs): + """Get request object from a handler function or method. Used internally by + ``use_args`` and ``use_kwargs``. + """ + if len(args) > 1 and isinstance(args[1], sanic.request.Request): + req = args[1] + else: + req = args[0] + assert isinstance( + req, sanic.request.Request + ), "Request argument not found for handler" + return req + + def parse_json(self, req, name, field): + """Pull a json value from the request.""" + if not (req.body and is_json_request(req)): + return core.missing + json_data = req.json + if json_data is None: + return core.missing + return core.get_value(json_data, name, field, allow_many_nested=True) + + def parse_querystring(self, req, name, field): + """Pull a querystring value from the request.""" + return core.get_value(req.args, name, field) + + def parse_form(self, req, name, field): + """Pull a form value from the request.""" + try: + return core.get_value(req.form, name, field) + except AttributeError: + pass + return core.missing + + def parse_headers(self, req, name, field): + """Pull a value from the header data.""" + return core.get_value(req.headers, name, field) + + def parse_cookies(self, req, name, field): + """Pull a value from the cookiejar.""" + return core.get_value(req.cookies, name, field) + + def parse_files(self, req, name, field): + """Pull a file from the request.""" + return core.get_value(req.files, name, field) + + def handle_error(self, error, req, schema): + """Handles errors during parsing. Aborts the current HTTP request and + responds with a 422 error. + """ + + status_code = getattr(error, "status_code", self.DEFAULT_VALIDATION_STATUS) + abort(status_code, exc=error, messages=error.messages, schema=schema) + + +parser = SanicParser() +use_args = parser.use_args +use_kwargs = parser.use_kwargs