diff --git a/autopush/router/apnsrouter.py b/autopush/router/apnsrouter.py index 3ebf4b42..e2fb6436 100644 --- a/autopush/router/apnsrouter.py +++ b/autopush/router/apnsrouter.py @@ -89,7 +89,7 @@ def register(self, uaid, router_data, app_id, *args, **kwargs): status_code=400, response_body="Unknown release channel") if not router_data.get("token"): - raise RouterException("No token registered", status_code=500, + raise RouterException("No token registered", status_code=400, response_body="No token registered") router_data["rel_channel"] = app_id return router_data diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index d70f972b..69b35fe0 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -879,3 +879,15 @@ def handle_finish(value): router_token="test", uaid=dummy_uaid.hex)) return self.finish_deferred + + def test_get_no_uaid(self): + self.reg.request.headers['Authorization'] = self.auth + + def handle_finish(value): + self.status_mock.assert_called_with(410, reason=None) + + self.finish_deferred.addCallback(handle_finish) + self.reg.get(self._make_req( + router_type="test", + router_token="test")) + return self.finish_deferred diff --git a/autopush/web/base.py b/autopush/web/base.py index b9eb4643..545fcfc4 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -239,7 +239,8 @@ def _boto_err(self, fail): def _router_response(self, response): for name, val in response.headers.items(): - self.set_header(name, val) + if val is not None: + self.set_header(name, val) if 200 <= response.status_code < 300: self.set_status(response.status_code, reason=None) diff --git a/autopush/web/registration.py b/autopush/web/registration.py index c1033fa5..9680738a 100644 --- a/autopush/web/registration.py +++ b/autopush/web/registration.py @@ -226,6 +226,11 @@ def _register_channel(self, router_data=None): self.app_server_key) return endpoint, router_data + def _check_uaid(self, uaid): + if not uaid or uaid == 'None': + raise ItemNotFound("UAID not found") + return uaid + @threaded_validate(RegistrationSchema) def get(self, *args, **kwargs): """HTTP GET @@ -235,11 +240,11 @@ def get(self, *args, **kwargs): """ self.uaid = self.valid_input['uaid'] self.add_header("Content-Type", "application/json") - d = deferToThread(self.ap_settings.message.all_channels, - str(self.uaid)) + d = deferToThread(self._check_uaid, str(self.uaid)) + d.addCallback(self.ap_settings.message.all_channels) d.addCallback(self._write_channels) - d.addErrback(self._response_err) d.addErrback(self._uaid_not_found_err) + d.addErrback(self._response_err) return d @threaded_validate(RegistrationSchema)