diff --git a/sanic_graphql/graphqlview.py b/sanic_graphql/graphqlview.py index 02335f1..003d485 100644 --- a/sanic_graphql/graphqlview.py +++ b/sanic_graphql/graphqlview.py @@ -1,30 +1,18 @@ -import json -import six +from functools import partial from cgi import parse_header + +from promise import Promise from sanic.response import HTTPResponse from sanic.views import HTTPMethodView -from sanic.exceptions import SanicException -from promise import Promise -from graphql import Source, execute, parse, validate -from graphql.error import format_error as format_graphql_error -from graphql.error import GraphQLError -from graphql.execution import ExecutionResult from graphql.type.schema import GraphQLSchema -from graphql.utils.get_operation_ast import get_operation_ast from graphql.execution.executors.asyncio import AsyncioExecutor +from graphql_server import run_http_query, HttpQueryError, default_format_error, load_json_body, encode_execution_results, json_encode from .render_graphiql import render_graphiql -class HttpError(Exception): - def __init__(self, response, message=None, *args, **kwargs): - self.response = response - self.message = message = message or response.args[0] - super(HttpError, self).__init__(message, *args, **kwargs) - - class GraphQLView(HTTPMethodView): schema = None executor = None @@ -44,14 +32,11 @@ class GraphQLView(HTTPMethodView): def __init__(self, **kwargs): super(GraphQLView, self).__init__() - for key, value in kwargs.items(): if hasattr(self, key): setattr(self, key, value) self._enable_async = self._enable_async and isinstance(kwargs.get('executor'), AsyncioExecutor) - - assert not all((self.graphiql, self.batch)), 'Use either graphiql or batch processing' assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' # noinspection PyUnusedLocal @@ -70,211 +55,103 @@ def get_middleware(self, request): def get_executor(self, request): return self.executor + def render_graphiql(self, params, result): + return render_graphiql( + jinja_env=self.jinja_env, + params=params, + result=result, + graphiql_version=self.graphiql_version, + graphiql_template=self.graphiql_template, + ) + + format_error = staticmethod(default_format_error) + encode = staticmethod(json_encode) + async def dispatch_request(self, request, *args, **kwargs): try: - if request.method.lower() not in ('get', 'post'): - raise HttpError(SanicException('GraphQL only supports GET and POST requests.', status_code=405)) - + request_method = request.method.lower() data = self.parse_body(request) - show_graphiql = self.graphiql and self.can_display_graphiql(request, data) - if self.batch: - responses = [] - for entry in data: - responses.append(await self.get_response(request, entry)) + show_graphiql = request_method == 'get' and self.should_display_graphiql(request) + catch = HttpQueryError if show_graphiql else None + + pretty = self.pretty or show_graphiql or request.args.get('pretty') - result = '[{}]'.format(','.join([response[0] for response in responses])) - status_code = max(responses, key=lambda response: response[1])[1] - else: - result, status_code = await self.get_response(request, data, show_graphiql) + execution_results, all_params = run_http_query( + self.schema, + request_method, + data, + query_data=request.args, + batch_enabled=self.batch, + catch=catch, + + # Execute options + return_promise=self._enable_async, + root_value=self.get_root_value(request), + context_value=self.get_context(request), + middleware=self.get_middleware(request), + executor=self.get_executor(request), + ) + awaited_execution_results = await Promise.all(execution_results) + result, status_code = encode_execution_results( + awaited_execution_results, + is_batch=isinstance(data, list), + format_error=self.format_error, + encode=partial(self.encode, pretty=pretty) + ) if show_graphiql: - query, variables, operation_name, id = self.get_graphql_params(request, data) - return await render_graphiql( - jinja_env=self.jinja_env, - graphiql_version=self.graphiql_version, - graphiql_template=self.graphiql_template, - query=query, - variables=variables, - operation_name=operation_name, + return await self.render_graphiql( + params=all_params[0], result=result ) return HTTPResponse( - status=status_code, - body=result, - content_type='application/json' + result, + status=status_code, + content_type='application/json' ) - except HttpError as e: + except HttpQueryError as e: return HTTPResponse( - self.json_encode(request, { - 'errors': [self.format_error(e)] + self.encode({ + 'errors': [default_format_error(e)] }), - status=e.response.status_code, - headers={'Allow': 'GET, POST'}, + status=e.status_code, + headers=e.headers, content_type='application/json' ) - async def get_response(self, request, data, show_graphiql=False): - query, variables, operation_name, id = self.get_graphql_params(request, data) - - execution_result = await self.execute_graphql_request( - request, - data, - query, - variables, - operation_name, - show_graphiql - ) - - status_code = 200 - if execution_result: - response = {} - - if execution_result.errors: - response['errors'] = [self.format_error(e) for e in execution_result.errors] - - if execution_result.invalid: - status_code = 400 - else: - status_code = 200 - response['data'] = execution_result.data - - if self.batch: - response = { - 'id': id, - 'payload': response, - 'status': status_code, - } - - result = self.json_encode(request, response, show_graphiql) - else: - result = None - - return result, status_code - - def json_encode(self, request, d, show_graphiql=False): - pretty = self.pretty or show_graphiql or request.args.get('pretty') - if not pretty: - return json.dumps(d, separators=(',', ':')) - - return json.dumps(d, sort_keys=True, - indent=2, separators=(',', ': ')) - # noinspection PyBroadException def parse_body(self, request): - content_type = self.get_content_type(request) + content_type = self.get_mime_type(request) if content_type == 'application/graphql': - return {'query': request.body.decode()} + return {'query': request.body.decode('utf8')} elif content_type == 'application/json': - try: - request_json = json.loads(request.body.decode('utf-8')) - if (self.batch and not isinstance(request_json, list)) or ( - not self.batch and not isinstance(request_json, dict)): - raise Exception() - except: - raise HttpError(SanicException('POST body sent invalid JSON.', status_code=400)) - return request_json - - elif content_type == 'application/x-www-form-urlencoded': - return request.form + return load_json_body(request.body.decode('utf8')) - elif content_type == 'multipart/form-data': + elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'): return request.form return {} - async def execute(self, *args, **kwargs): - result = execute(self.schema, return_promise=self._enable_async, *args, **kwargs) - if isinstance(result, Promise): - return await result - else: - return result - - async def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False): - if not query: - if show_graphiql: - return None - raise HttpError(SanicException('Must provide query string.', status_code=400)) - - try: - source = Source(query, name='GraphQL request') - ast = parse(source) - validation_errors = validate(self.schema, ast) - if validation_errors: - return ExecutionResult( - errors=validation_errors, - invalid=True, - ) - except Exception as e: - return ExecutionResult(errors=[e], invalid=True) - - if request.method.lower() == 'get': - operation_ast = get_operation_ast(ast, operation_name) - if operation_ast and operation_ast.operation != 'query': - if show_graphiql: - return None - raise HttpError(SanicException( - 'Can only perform a {} operation from a POST request.'.format(operation_ast.operation), - status_code=405, - )) - - try: - return await self.execute( - ast, - root_value=self.get_root_value(request), - variable_values=variables or {}, - operation_name=operation_name, - context_value=self.get_context(request), - middleware=self.get_middleware(request), - executor=self.get_executor(request) - ) - except Exception as e: - return ExecutionResult(errors=[e], invalid=True) - - @classmethod - def can_display_graphiql(cls, request, data): - raw = 'raw' in request.args or 'raw' in data - return not raw and cls.request_wants_html(request) - - @classmethod - def request_wants_html(cls, request): - # Ugly hack - accept = request.headers.get('accept', {}) - return 'text/html' in accept or '*/*' in accept - - @staticmethod - def get_graphql_params(request, data): - query = request.args.get('query') or data.get('query') - variables = request.args.get('variables') or data.get('variables') - id = request.args.get('id') or data.get('id') - - if variables and isinstance(variables, six.text_type): - try: - variables = json.loads(variables) - except: - raise HttpError(SanicException('Variables are invalid JSON.', status_code=400)) - - operation_name = request.args.get('operationName') or data.get('operationName') - - return query, variables, operation_name, id - @staticmethod - def format_error(error): - if isinstance(error, GraphQLError): - return format_graphql_error(error) - - return {'message': six.text_type(error)} - - @staticmethod - def get_content_type(request): + def get_mime_type(request): # We use mimetype here since we don't need the other # information provided by content_type if 'content-type' not in request.headers: - mimetype = 'text/plain' - else: - mimetype, params = parse_header(request.headers['content-type']) + return None + mimetype, _ = parse_header(request.headers['content-type']) return mimetype + + def should_display_graphiql(self, request): + if not self.graphiql or 'raw' in request.args: + return False + + return self.request_wants_html(request) + + def request_wants_html(self, request): + accept = request.headers.get('accept', {}) + return 'text/html' in accept or '*/*' in accept diff --git a/sanic_graphql/render_graphiql.py b/sanic_graphql/render_graphiql.py index 3b914a8..402fa0f 100644 --- a/sanic_graphql/render_graphiql.py +++ b/sanic_graphql/render_graphiql.py @@ -162,18 +162,24 @@ def simple_renderer(template, **values): return template -async def render_graphiql(*, jinja_env=None, graphiql_version=None, graphiql_template=None, **kwargs): +async def render_graphiql(jinja_env=None, graphiql_version=None, graphiql_template=None, params=None, result=None): graphiql_version = graphiql_version or GRAPHIQL_VERSION template = graphiql_template or TEMPLATE - kwargs['graphiql_version'] = graphiql_version + template_vars = { + 'graphiql_version': graphiql_version, + 'query': params and params.query, + 'variables': params and params.variables, + 'operation_name': params and params.operation_name, + 'result': result, + } if jinja_env: template = jinja_env.from_string(template) if jinja_env.is_async: - source = await template.render_async(**kwargs) + source = await template.render_async(**template_vars) else: - source = template.render(**kwargs) + source = template.render(**template_vars) else: - source = simple_renderer(template, **kwargs) + source = simple_renderer(template, **template_vars) return html(source) diff --git a/setup.py b/setup.py index 7b556cf..de55a0b 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,11 @@ from setuptools import setup, find_packages -required_packages = ['graphql-core>=1.0', 'sanic>=0.4.0', 'pytest-runner'] +required_packages = [ + 'graphql-core>=1.0', + 'graphql-server-core>=1.0.dev', + 'sanic>=0.4.0', + 'pytest-runner' +] setup( name='Sanic-GraphQL', diff --git a/tests/test_graphqlview.py b/tests/test_graphqlview.py index 42b102e..1a88178 100644 --- a/tests/test_graphqlview.py +++ b/tests/test_graphqlview.py @@ -11,7 +11,11 @@ except ImportError: from urllib.parse import urlencode -from aiohttp.helpers import FormData +try: + from aiohttp.helpers import FormData +except ImportError: + from aiohttp.formdata import FormData + from graphql.execution.executors.asyncio import AsyncioExecutor from graphql.execution.executors.sync import SyncExecutor @@ -204,7 +208,7 @@ def test_allows_sending_a_mutation_via_post(app): def test_allows_post_with_url_encoding(app): data = FormData() data.add_field('query', '{test}') - _, response = app.client.post(uri=url_string(), data=data('utf-8'), headers={'content-type': data.content_type}) + _, response = app.client.post(uri=url_string(), data=data(), headers={'content-type': 'application/x-www-form-urlencoded'}) assert response.status == 200 assert response_json(response) == { @@ -407,12 +411,12 @@ def test_handles_errors_caused_by_a_lack_of_query(app): @parametrize_sync_async_app_test('app') -def test_handles_invalid_json_bodies(app): +def test_handles_batch_correctly_if_is_disabled(app): _, response = app.client.post(uri=url_string(), data='[]', headers={'content-type': 'application/json'}) assert response.status == 400 assert response_json(response) == { - 'errors': [{'message': 'POST body sent invalid JSON.'}] + 'errors': [{'message': 'Batch GraphQL requests are not enabled.'}] } @@ -523,9 +527,7 @@ def test_batch_allows_post_with_json_encoding(app): assert response.status == 200 assert response_json(response) == [{ - 'id': 1, - 'payload': { 'data': {'test': "Hello World"} }, - 'status': 200, + 'data': {'test': "Hello World"}, }] @@ -543,9 +545,7 @@ def test_batch_supports_post_json_query_with_json_variables(app): assert response.status == 200 assert response_json(response) == [{ - 'id': 1, - 'payload': { 'data': {'test': "Hello Dolly"} }, - 'status': 200, + 'data': {'test': "Hello Dolly"}, }] @@ -570,14 +570,10 @@ def test_batch_allows_post_with_operation_name(app): assert response.status == 200 assert response_json(response) == [{ - 'id': 1, - 'payload': { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } - }, - 'status': 200, + 'data': { + 'test': 'Hello World', + 'shared': 'Hello Everyone' + } }] diff --git a/tox.ini b/tox.ini index 75c9fb5..2612ab4 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ python = deps = pytest>=2.7.2 graphql-core>=1.0 + graphql-server-core>=1.0.dev sanic>=0.3.1 aiohttp>=1.3.0 jinja2>=2.9.0