diff --git a/autopush/base.py b/autopush/base.py index ad494332..728dee2a 100644 --- a/autopush/base.py +++ b/autopush/base.py @@ -1,9 +1,25 @@ +import json import uuid import cyclone.web from twisted.logger import Logger from twisted.python import failure +status_codes = { + 200: "OK", + 201: "Created", + 202: "Accepted", + 400: "Bad Request", + 401: "Unauthorized", + 404: "Not Found", + 413: "Payload Too Large", + 418: "I'm a teapot", + 500: "Internal Server Error", + 503: "Service Unavailable", +} +DEFAULT_ERR_URL = ("http://autopush.readthedocs.io/en/latest/http.html" + "#error-codes") + class BaseHandler(cyclone.web.RequestHandler): """Base cyclone RequestHandler for autopush""" @@ -77,3 +93,35 @@ def authenticate_peer_cert(self): self.set_header('WWW-Authenticate', 'Transport mode="tls-client-certificate"') self.finish() + + def _write_response(self, status_code, errno=None, message=None, + error=None, headers=None, url=DEFAULT_ERR_URL): + """Writes out a full JSON error and sets the appropriate status""" + self.set_status(status_code, reason=error) + error_data = dict( + code=status_code, + error=error or status_codes.get(status_code, ""), + more_info=url, + ) + if errno: + error_data["errno"] = errno + if message: + error_data["message"] = message + self.write(json.dumps(error_data)) + self.set_header("Content-Type", "application/json") + if headers: + for header in headers.keys(): + self.set_header(header, headers.get(header)) + self.finish() + + +class DefaultHandler(BaseHandler): + """Unauthenticated catch-all handler that returns a 404 for + unknown paths. Cyclone matches handlers in order, so this handler + should be registered last.""" + + def authenticate_peer_cert(self): + pass + + def prepare(self): + self._write_response(404) diff --git a/autopush/main.py b/autopush/main.py index 4ed31183..54102a38 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -12,6 +12,7 @@ import autopush.db as db import autopush.utils as utils +from autopush.base import DefaultHandler from autopush.logging import PushLogger from autopush.settings import AutopushSettings from autopush.ssl import AutopushSSLContextFactory @@ -587,6 +588,7 @@ def endpoint_main(sysargs=None, use_files=True): (endpoint_paths['message'], MessageHandler, h_kwargs), (endpoint_paths['registration'], RegistrationHandler, h_kwargs), (endpoint_paths['logcheck'], LogCheckHandler, h_kwargs), + (r".*", DefaultHandler, h_kwargs), ], default_host=settings.hostname, debug=args.debug, log_function=skip_request_logging diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index e32fa3f7..113f1885 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -30,6 +30,7 @@ import autopush.db as db from autopush import __version__ +from autopush.base import DefaultHandler from autopush.db import ( create_rotating_message_table, get_month, @@ -397,6 +398,8 @@ def setUp(self): # GET /register/uaid => chid + endpoint (endpoint_paths['registration'], RegistrationHandler, h_kwargs), (endpoint_paths['logcheck'], LogCheckHandler, h_kwargs), + + (r".*", DefaultHandler, h_kwargs), ], default_host=settings.hostname, log_function=skip_request_logging, diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py index 51f42962..eab1331e 100644 --- a/autopush/tests/test_web_base.py +++ b/autopush/tests/test_web_base.py @@ -175,7 +175,7 @@ def test_init_info(self, t): eq_(d["authorization"], "webpush token barney") def test_write_response(self): - self.base._write_response(400, 103, message="Fail", + self.base._write_response(400, errno=103, message="Fail", headers=dict(Location="http://a.com/")) self.status_mock.assert_called_with(400, reason=None) diff --git a/autopush/web/base.py b/autopush/web/base.py index 7a06b665..c3b476e2 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -11,23 +11,9 @@ from autopush.base import BaseHandler from autopush.exceptions import InvalidRequest, RouterException -status_codes = { - 200: "OK", - 201: "Created", - 202: "Accepted", - 400: "Bad Request", - 401: "Unauthorized", - 404: "Not Found", - 413: "Payload Too Large", - 418: "I'm a teapot", - 500: "Internal Server Error", - 503: "Service Unavailable", -} # Older versions used "bearer", newer specification requires "webpush" AUTH_SCHEMES = ["bearer", "webpush"] PREF_SCHEME = "webpush" -DEFAULT_ERR_URL = ("http://autopush.readthedocs.io/en/latest/http.html" - "#error-codes") class ThreadedValidate(object): @@ -153,26 +139,6 @@ def head(self, *args, **kwargs): ############################################################# # Error Callbacks ############################################################# - def _write_response(self, status_code, errno, message=None, error=None, - headers=None, - url=DEFAULT_ERR_URL): - """Writes out a full JSON error and sets the appropriate status""" - self.set_status(status_code, reason=error) - error_data = dict( - code=status_code, - errno=errno, - error=error or status_codes.get(status_code, ""), - more_info=url, - ) - if message: - error_data["message"] = message - self.write(json.dumps(error_data)) - self.set_header("Content-Type", "application/json") - if headers: - for header in headers.keys(): - self.set_header(header, headers.get(header)) - self.finish() - def _validation_err(self, fail): """errBack for validation errors""" fail.trap(InvalidRequest) @@ -181,7 +147,7 @@ def _validation_err(self, fail): status_code=exc.status_code, errno=exc.errno, client_info=self._client_info) - self._write_response(exc.status_code, exc.errno, + self._write_response(exc.status_code, errno=exc.errno, message="Request did not validate %s" % (exc.message or ""), headers=exc.headers) @@ -197,15 +163,16 @@ def _response_err(self, fail): self.log.failure(format=fmt, failure=fail, status_code=500, errno=999, client_info=self._client_info) - self._write_response(500, 999, message="An unexpected server error" - " occurred.") + self._write_response(500, errno=999, + message="An unexpected server error" + " occurred.") def _overload_err(self, fail): """errBack for throughput provisioned exceptions""" fail.trap(ProvisionedThroughputExceededException) self.log.info(format="Throughput Exceeded", status_code=503, errno=201, client_info=self._client_info) - self._write_response(503, 201, + self._write_response(503, errno=201, message="Please slow message send rate") def _boto_err(self, fail): diff --git a/autopush/web/log_check.py b/autopush/web/log_check.py index 4576554c..731b2336 100644 --- a/autopush/web/log_check.py +++ b/autopush/web/log_check.py @@ -35,7 +35,7 @@ def get(self, err_type=None, *args, **kwargs): self.log.error(format="Test Error Message", status_code=418, errno=0, client_info=self._client_info) - self._write_response(418, 999, message="ERROR:Success", + self._write_response(418, errno=999, message="ERROR:Success", error="Test Error") if 'crit' in err_type: try: @@ -44,5 +44,6 @@ def get(self, err_type=None, *args, **kwargs): self.log.failure(format="Test Critical Message", status_code=418, errno=0, client_info=self._client_info) - self._write_response(418, 999, message="FAILURE:Success", + self._write_response(418, errno=999, + message="FAILURE:Success", error="Test Failure") diff --git a/autopush/web/registration.py b/autopush/web/registration.py index 4f59540a..a9bf00e8 100644 --- a/autopush/web/registration.py +++ b/autopush/web/registration.py @@ -257,7 +257,7 @@ def _chid_not_found_err(self, fail): self.log.info(format="CHID not found in AWS.", status_code=410, errno=106, **self._client_info) - self._write_response(410, 106, message="Invalid endpoint.") + self._write_response(410, errno=106, message="Invalid endpoint.") ############################################################# # Callbacks