diff --git a/autopush/db.py b/autopush/db.py index f31b7d22..b8506519 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -679,7 +679,7 @@ def get_uaid(self, uaid): @track_provisioned def register_user(self, data): - # type: (ItemLike) -> Tuple[bool, Dict[str, Any], Dict[str, Any]] + # type: (ItemLike) -> Tuple[bool, Dict[str, Any]] """Register this user If a record exists with a newer ``connected_at``, then the user will @@ -728,9 +728,9 @@ def register_user(self, data): # this not work r[key] = value result = r - return (True, result, data) + return (True, result) except ConditionalCheckFailedException: - return (False, {}, data) + return (False, {}) @track_provisioned def drop_user(self, uaid): diff --git a/autopush/router/apnsrouter.py b/autopush/router/apnsrouter.py index e2fb6436..2415a663 100644 --- a/autopush/router/apnsrouter.py +++ b/autopush/router/apnsrouter.py @@ -1,5 +1,6 @@ """APNS Router""" import uuid +from typing import Any, Dict # noqa from hyper.http20.exceptions import ConnectionError, HTTP20Error from twisted.internet.threads import deferToThread @@ -67,21 +68,15 @@ def __init__(self, ap_settings, router_conf, load_connections=True): self.log.debug("Starting APNS router...") def register(self, uaid, router_data, app_id, *args, **kwargs): + # type: (str, Dict[str, Any], str, *Any, **Any) -> None """Register an endpoint for APNS, on the `app_id` release channel. This will validate that an APNs instance token is in the `router_data`, :param uaid: User Agent Identifier - :type uaid: str :param router_data: Dict containing router specific configuration info - :type router_data: dict :param app_id: The release channel identifier for cert info lookup - :type app_id: str - - :returns: a modified router_data for the user agent record. - :rtype: dict - """ if app_id not in self.apns: @@ -92,11 +87,9 @@ def register(self, uaid, router_data, app_id, *args, **kwargs): raise RouterException("No token registered", status_code=400, response_body="No token registered") router_data["rel_channel"] = app_id - return router_data - def amend_msg(self, msg, router_data=None): + def amend_endpoint_response(self, response, router_data): """This function is stubbed out for this router""" - return msg def route_notification(self, notification, uaid_data): """Start the APNS notification routing, returns a deferred diff --git a/autopush/router/fcm.py b/autopush/router/fcm.py index c8b98c55..7bcaf8b8 100644 --- a/autopush/router/fcm.py +++ b/autopush/router/fcm.py @@ -1,4 +1,5 @@ """FCM Router""" +from typing import Any, Dict # noqa import pyfcm from requests.exceptions import ConnectionError @@ -117,12 +118,11 @@ def __init__(self, ap_settings, router_conf): self.log.debug("Starting FCM router...") self.ap_settings = ap_settings - def amend_msg(self, msg, data=None): - if data is not None: - msg["senderid"] = data.get('creds', {}).get('senderID') - return msg + def amend_endpoint_response(self, response, router_data): + response["senderid"] = router_data.get('creds', {}).get('senderID') def register(self, uaid, router_data, app_id, *args, **kwargs): + # type: (str, Dict[str, Any], str, *Any, **Any) -> None """Validate that the FCM Instance Token is in the ``router_data``""" senderid = app_id # "token" is the GCM registration id token generated by the client. @@ -142,7 +142,6 @@ def register(self, uaid, router_data, app_id, *args, **kwargs): # Assign a senderid router_data["creds"] = {"senderID": self.senderID, "auth": self.auth} - return router_data def route_notification(self, notification, uaid_data): """Start the FCM notification routing, returns a deferred""" diff --git a/autopush/router/gcm.py b/autopush/router/gcm.py index 26f1b402..036edf3c 100644 --- a/autopush/router/gcm.py +++ b/autopush/router/gcm.py @@ -1,4 +1,5 @@ """GCM Router""" +from typing import Any, Dict # noqa import gcmclient from requests.exceptions import ConnectionError @@ -40,12 +41,11 @@ def __init__(self, ap_settings, router_conf): self.log.debug("Starting GCM router...") self.ap_settings = ap_settings - def amend_msg(self, msg, data=None): - if data is not None: - msg["senderid"] = data.get('creds', {}).get('senderID') - return msg + def amend_endpoint_response(self, response, router_data): + response["senderid"] = router_data.get('creds', {}).get('senderID') def register(self, uaid, router_data, app_id, *args, **kwargs): + # type: (str, Dict[str, Any], str, *Any, **Any) -> None """Validate that the GCM Instance Token is in the ``router_data``""" # "token" is the GCM registration id token generated by the client. if "token" not in router_data: @@ -65,7 +65,6 @@ def register(self, uaid, router_data, app_id, *args, **kwargs): # Assign a senderid router_data["creds"] = {"senderID": senderid, "auth": self.senderIDs[senderid]} - return router_data def route_notification(self, notification, uaid_data): """Start the GCM notification routing, returns a deferred""" diff --git a/autopush/router/interface.py b/autopush/router/interface.py index 5cfbac8d..4a995b7b 100644 --- a/autopush/router/interface.py +++ b/autopush/router/interface.py @@ -1,5 +1,5 @@ """Router interface""" -from typing import Any, Optional # noqa +from typing import Any, Dict, Optional # noqa class RouterResponse(object): @@ -28,10 +28,9 @@ def __init__(self, settings, router_conf): raise NotImplementedError("__init__ must be implemented") def register(self, uaid, router_data, app_id, *args, **kwargs): - # type: (str, dict, str, *Any, **Any) -> dict + # type: (str, Dict[str, Any], str, *Any, **Any) -> None """Register the uaid with the router_data dict however is preferred - and return a dict that will be stored as router_data for this - user in the future. + prior to storing the router_data for this user in the future. :param uaid: User Agent Identifier :param router_data: Route specific configuration info @@ -43,18 +42,20 @@ def register(self, uaid, router_data, app_id, *args, **kwargs): """ raise NotImplementedError("register must be implemented") - def amend_msg(self, msg, router_data=None): - # type: (dict, Optional[dict]) -> dict - """Modify an outbound response message to include router info + def amend_endpoint_response(self, response, router_data): + # type: (Dict[str, Any], Dict[str, Any]) -> None + """Modify an outbound Endpoint registration response message to + include router info. - :param msg: A dict of the response data to be sent to the client - :param router_data: a dictionary of router data - :returns: A potentially modified dict to return to the client + Some routers require additional info to be returned to + clients. - Some routers may require additional info to be returned to clients. + :param response: The response data to be sent to the client + :param router_data: Route specific configuration info """ - raise NotImplementedError("amend_msg must be implemented") + raise NotImplementedError( + "amend_endpoint_response must be implemented") def route_notification(self, notification, uaid_data): """Route a notification diff --git a/autopush/router/simple.py b/autopush/router/simple.py index 4070a11b..99505423 100644 --- a/autopush/router/simple.py +++ b/autopush/router/simple.py @@ -8,6 +8,7 @@ import json from urllib import urlencode from StringIO import StringIO +from typing import Any # noqa import requests from boto.dynamodb2.exceptions import ( @@ -47,12 +48,12 @@ def __init__(self, ap_settings, router_conf): self.conf = router_conf self.waker = None - def register(self, uaid, *args, **kwargs): - """Return no additional routing data""" - return {} + def register(self, *args, **kwargs): + # type: (*Any, **Any) -> None + """No additional routing data""" - def amend_msg(self, msg, router_data=None): - return msg + def amend_endpoint_response(self, response, router_data): + """This function is stubbed out for this router""" def stored_response(self, notification): return RouterResponse(202, "Notification Stored") diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 74d4e126..045d25de 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -92,6 +92,3 @@ def _save_notification(self, uaid_data, notification): self.ap_settings.message_tables[month_table].store_message, notification=notification, ) - - def amend_msg(self, msg, router_data=None): - return msg diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index a8f0f389..3204caf9 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -501,7 +501,7 @@ def raise_condition(*args, **kwargs): router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="simplepush") result = router.register_user(router_data) - eq_(result, (False, {}, router_data)) + eq_(result, (False, {})) def test_node_clear(self): r = get_router_table() diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 69b35fe0..2e75fc46 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -23,6 +23,7 @@ create_rotating_message_table, has_connected_this_month, ) +from autopush.exceptions import RouterException from autopush.settings import AutopushSettings from autopush.router.interface import IRouter from autopush.tests.test_db import make_webpush_notification @@ -434,6 +435,28 @@ def restore(*args, **kwargs): self.reg.post(self._make_req()) return self.finish_deferred + def test_post_bad_router_register(self, *args): + frouter = Mock(spec=IRouter) + self.reg.ap_settings.routers["simplepush"] = frouter + rexc = RouterException("invalid", status_code=402, errno=107) + frouter.register = Mock(side_effect=rexc) + + self.reg.request.body = json.dumps(dict( + type="simplepush", + channelID=str(dummy_chid), + data={}, + )) + self.reg.request.uri = "/v1/xxx/yyy/register" + self.reg.request.headers["Authorization"] = self.auth + + def handle_finish(value): + self._check_error(rexc.status_code, rexc.errno, "") + + self.finish_deferred.addBoth(handle_finish) + self.reg.post(self._make_req("simplepush", "", + body=self.reg.request.body)) + return self.finish_deferred + def test_post_existing_uaid(self, *args): self.reg.request.body = json.dumps(dict( channelID=str(dummy_chid), @@ -703,8 +726,8 @@ def handle_finish(value): self.reg.write.assert_called_with({}) frouter.register.assert_called_with( dummy_uaid.hex, - router_data=data, - app_id='', + data, + '', uri=self.reg.request.uri ) @@ -757,6 +780,19 @@ def restore(*args, **kwargs): self.reg.put(self._make_req(uaid=dummy_uaid.hex)) return self.finish_deferred + def test_put_bad_router_register(self): + frouter = self.reg.ap_settings.routers["test"] + rexc = RouterException("invalid", status_code=402, errno=107) + frouter.register = Mock(side_effect=rexc) + + def handle_finish(value): + self._check_error(rexc.status_code, rexc.errno, "") + + self.finish_deferred.addCallback(handle_finish) + self.reg.request.headers["Authorization"] = self.auth + self.reg.put(self._make_req(router_type='test', uaid=dummy_uaid.hex)) + return self.finish_deferred + def test_delete_bad_chid_value(self): notif = make_webpush_notification(dummy_uaid.hex, str(dummy_chid)) messages = self.reg.ap_settings.message @@ -865,6 +901,8 @@ def test_get(self): chids = [str(dummy_chid), str(dummy_uaid)] def handle_finish(value): + self.settings.message.all_channels.assert_called_with( + str(dummy_uaid)) call_args = json.loads( self.reg.write.call_args[0][0] ) diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index a6253b3a..347497bc 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -62,7 +62,7 @@ def init(self, settings, router_conf): ir = IRouter(None, None) assert_raises(NotImplementedError, ir.register, "uaid", {}, "") assert_raises(NotImplementedError, ir.route_notification, "uaid", {}) - assert_raises(NotImplementedError, ir.amend_msg, {}) + assert_raises(NotImplementedError, ir.amend_endpoint_response, {}, {}) # FOR LEGACY REASONS, CHANNELID MUST BE IN HEX FORMAT FOR BRIDGE PUBLICATION @@ -129,10 +129,9 @@ def setUp(self, mt, mc): rel_channel="firefox")) def test_register(self): - result = self.router.register("uaid", - router_data={"token": "connect_data"}, - app_id="firefox") - eq_(result, {"rel_channel": "firefox", "token": "connect_data"}) + router_data = {"token": "connect_data"} + self.router.register("uaid", router_data=router_data, app_id="firefox") + eq_(router_data, {"rel_channel": "firefox", "token": "connect_data"}) def test_register_bad(self): with assert_raises(RouterException): @@ -247,7 +246,9 @@ def test_too_many_connections(self): def test_amend(self): resp = {"key": "value"} - eq_(resp, self.router.amend_msg(resp)) + expected = resp.copy() + self.router.amend_endpoint_response(resp, {}) + eq_(resp, expected) def test_route_crypto_key(self): headers = {"content-encoding": "aesgcm", @@ -331,13 +332,12 @@ def test_init(self): GCMRouter(settings, {"senderIDs": {}}) def test_register(self): - result = self.router.register("uaid", - router_data={"token": "test123"}, - app_id="test123") + router_data = {"token": "test123"} + self.router.register("uaid", router_data=router_data, app_id="test123") # Check the information that will be recorded for this user - eq_(result, {"token": "test123", - "creds": {"senderID": "test123", - "auth": "12345678abcdefg"}}) + eq_(router_data, {"token": "test123", + "creds": {"senderID": "test123", + "auth": "12345678abcdefg"}}) def test_register_bad(self): with assert_raises(RouterException): @@ -571,14 +571,12 @@ def check_results(fail): return d def test_amend(self): - self.router.register("uaid", - router_data={"token": "test123"}, - app_id="test123") + router_data = {"token": "test123"} + self.router.register("uaid", router_data=router_data, app_id="test123") resp = {"key": "value"} - result = self.router.amend_msg(resp, - self.router_data.get('router_data')) - eq_({"key": "value", "senderid": "test123"}, - result) + self.router.amend_endpoint_response( + resp, self.router_data.get('router_data')) + eq_({"key": "value", "senderid": "test123"}, resp) def test_register_invalid_token(self): with assert_raises(RouterException): @@ -649,17 +647,16 @@ def throw_auth(*args, **kwargs): FCMRouter(settings, {}) def test_register(self): - result = self.router.register("uaid", - router_data={"token": "test123"}, - app_id="test123") + router_data = {"token": "test123"} + self.router.register("uaid", router_data=router_data, app_id="test123") # Check the information that will be recorded for this user - eq_(result, {"token": "test123", - "creds": {"senderID": "test123", - "auth": "12345678abcdefg"}}) + eq_(router_data, {"token": "test123", + "creds": {"senderID": "test123", + "auth": "12345678abcdefg"}}) def test_register_bad(self): with assert_raises(RouterException): - self.router.register("uaid", router_data={}) + self.router.register("uaid", router_data={}, app_id="invalid123") def test_route_notification(self): self.router.fcm = self.fcm @@ -862,10 +859,9 @@ def test_amend(self): router_data={"token": "test123"}, app_id="test123") resp = {"key": "value"} - result = self.router.amend_msg(resp, - self.router_data.get('router_data')) - eq_({"key": "value", "senderid": "test123"}, - result) + self.router.amend_endpoint_response( + resp, self.router_data.get('router_data')) + eq_({"key": "value", "senderid": "test123"}, resp) def test_register_invalid_token(self): with assert_raises(RouterException): @@ -910,8 +906,9 @@ def _raise_item_error(self): raise ItemNotFound() def test_register(self): - r = self.router.register("uaid", router_data={}) - eq_(r, {}) + router_data = {} + self.router.register("uaid", router_data=router_data) + eq_(router_data, {}) def test_route_to_connected(self): self.agent_mock.request.return_value = response_mock = Mock() @@ -1109,9 +1106,11 @@ def check_deliver(result): eq_(self.router.udp, udp_data) return d - def test_ammend(self): + def test_amend(self): resp = {"key": "value"} - eq_(resp, self.router.amend_msg(resp)) + expected = resp.copy() + self.router.amend_endpoint_response(resp, {}) + eq_(resp, expected) class WebPushRouterTestCase(unittest.TestCase): @@ -1210,6 +1209,8 @@ def verify_deliver(fail): d.addBoth(verify_deliver) return d - def test_ammend(self): + def test_amend(self): resp = {"key": "value"} - eq_(resp, self.router.amend_msg(resp)) + expected = resp.copy() + self.router.amend_endpoint_response(resp, {}) + eq_(resp, expected) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index cfe607cd..82cd2215 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -554,7 +554,7 @@ def test_hello_old(self): )) def fake_msg(data): - return (True, msg_data, data) + return (True, msg_data) mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = [] @@ -616,7 +616,7 @@ def test_hello_tomorrow(self): } def fake_msg(data): - return (True, msg_data, data) + return (True, msg_data) mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = "01;", [] @@ -690,7 +690,7 @@ def test_hello_tomorrow_provision_error(self): } def fake_msg(data): - return (True, msg_data, data) + return (True, msg_data) mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = "01;", [] @@ -960,7 +960,7 @@ def test_hello_check_fail(self): # Fail out the register_user call self.proto.ap_settings.router.register_user = \ - Mock(return_value=(False, {}, {})) + Mock(return_value=(False, {})) self._send_message(dict(messageType="hello", channelIDs=[])) @@ -1386,7 +1386,7 @@ def test_register_kill_others(self): self.proto.ps.uaid = uaid connected = int(time.time()) res = dict(node_id=node_id, connected_at=connected, uaid=uaid) - self.proto._check_other_nodes((True, res, None)) + self.proto._check_other_nodes((True, res)) mock_agent.request.assert_called_with( "DELETE", "%s/notif/%s/%s" % (node_id, uaid, connected)) @@ -1401,7 +1401,7 @@ def test_register_kill_others_fail(self): self.proto.ps.uaid = uaid connected = int(time.time()) res = dict(node_id=node_id, connected_at=connected, uaid=uaid) - self.proto._check_other_nodes((True, res, None)) + self.proto._check_other_nodes((True, res)) d.errback(ConnectError()) return d @@ -1448,7 +1448,7 @@ def test_check_kill_self(self): self.proto.sendClose = Mock() self.proto.ps.uaid = uaid res = dict(node_id=node_id, connected_at=connected, uaid=uaid) - self.proto._check_other_nodes((True, res, None)) + self.proto._check_other_nodes((True, res)) # the current one should be dropped. eq_(ff.sendClose.call_count, 0) eq_(self.proto.sendClose.call_count, 1) @@ -1468,7 +1468,7 @@ def test_check_kill_existing(self): self.proto.sendClose = Mock() self.proto.ps.uaid = uaid res = dict(node_id=node_id, connected_at=connected, uaid=uaid) - self.proto._check_other_nodes((True, res, None)) + self.proto._check_other_nodes((True, res)) # the existing one should be dropped. eq_(ff.sendClose.call_count, 1) eq_(self.proto.sendClose.call_count, 0) diff --git a/autopush/web/registration.py b/autopush/web/registration.py index 9d8ad09f..57760b7e 100644 --- a/autopush/web/registration.py +++ b/autopush/web/registration.py @@ -2,6 +2,13 @@ import re import time import uuid +from typing import ( # noqa + Any, + Dict, + Optional, + Set, + Tuple +) from boto.dynamodb2.exceptions import ItemNotFound from cryptography.hazmat.primitives import constant_time @@ -11,7 +18,7 @@ pre_load, validates_schema ) -from twisted.internet.defer import Deferred +from twisted.internet import defer from twisted.internet.threads import deferToThread from autopush.db import generate_last_connect, hasher @@ -71,12 +78,12 @@ def extract_data(self, req): status_code=410, errno=106) return dict( - auth=req.get('headers', {}).get("Authorization"), - router_data=router_data, - router_type=req['path_kwargs'].get('router_type'), - router_token=req['path_kwargs'].get('router_token'), uaid=uaid, chid=chid, + router_type=req['path_kwargs'].get('router_type'), + router_token=req['path_kwargs'].get('router_token'), + router_data=router_data, + auth=req.get('headers', {}).get("Authorization"), ) @validates_schema(skip_on_field_errors=True) @@ -144,44 +151,42 @@ def post(self, *args, **kwargs): """ self.add_header("Content-Type", "application/json") + + uaid = self.valid_input['uaid'] + router = self.valid_input["router"] + router_type = self.valid_input["router_type"] + router_token = self.valid_input.get("router_token") router_data = self.valid_input['router_data'] + # If the client didn't provide a CHID, make one up. # Note, valid_input may explicitly set "chid" to None # THIS VALUE MUST MATCH WHAT'S SPECIFIED IN THE BRIDGE CONNECTIONS. # currently hex formatted. - self.chid = router_data["channelID"] = (self.valid_input["chid"] or - uuid.uuid4().hex) + chid = router_data["channelID"] = (self.valid_input["chid"] or + uuid.uuid4().hex) self.ap_settings.metrics.increment("updates.client.register", tags=self.base_tags()) - # If there's a UAID, ensure its valid, otherwise we ensure the hash - # matches up - new_uaid = False - - # normalize the path vars into parameters - router = self.ap_settings.routers[self.valid_input['router_type']] - - if not self.valid_input['uaid']: - self.valid_input['uaid'] = uuid.uuid4() - new_uaid = True - self.uaid = self.valid_input['uaid'] - self.app_server_key = router_data.get("key") - if new_uaid: - d = Deferred() - d.addCallback(router.register, - router_data=router_data, - app_id=self.valid_input.get("router_token"), - uri=self.request.uri) - d.addCallback(self._save_router_data, - self.valid_input["router_type"]) - d.addCallback(self._create_endpoint) - d.addCallback(self._return_endpoint, new_uaid, router) + + if not uaid: + uaid = uuid.uuid4() + d = defer.execute( + router.register, + uaid.hex, router_data, router_token, uri=self.request.uri) + d.addCallback( + lambda _: + deferToThread(self._register_user_and_channel, + uaid, chid, router, router_type, router_data) + ) + d.addCallback(self._write_endpoint, + uaid, chid, router, router_data) d.addErrback(self._router_fail_err) d.addErrback(self._response_err) - d.callback(self.valid_input['uaid'].hex) else: - d = self._create_endpoint() - d.addCallback(self._return_endpoint, new_uaid) + d = deferToThread(self._register_channel, + uaid, chid, router_data.get("key")) + d.addCallback(self._write_endpoint, uaid, chid) d.addErrback(self._response_err) + return d @threaded_validate(RegistrationSchema) def put(self, *args, **kwargs): @@ -190,19 +195,23 @@ def put(self, *args, **kwargs): Update router type/data for a UAID. """ - self.uaid = self.valid_input['uaid'] + uaid = self.valid_input['uaid'] router = self.valid_input['router'] + router_type = self.valid_input['router_type'] + router_token = self.valid_input['router_token'] + router_data = self.valid_input['router_data'] self.add_header("Content-Type", "application/json") - d = Deferred() - d.addCallback(router.register, - router_data=self.valid_input['router_data'], - app_id=self.valid_input['router_token'], - uri=self.request.uri) - d.addCallback(self._save_router_data, self.valid_input['router_type']) + d = defer.execute( + router.register, + uaid.hex, router_data, router_token, uri=self.request.uri) + d.addCallback( + lambda _: + deferToThread(self._register_user, uaid, router_data, router_type) + ) d.addCallback(self._success) d.addErrback(self._router_fail_err) d.addErrback(self._response_err) - d.callback(self.valid_input['uaid'].hex) + return d def _delete_channel(self, uaid, chid): message = self.ap_settings.message @@ -215,18 +224,9 @@ def _delete_uaid(self, uaid, router): if not router.drop_user(uaid.hex): raise ItemNotFound("UAID not found") - def _register_channel(self, router_data=None): - self.ap_settings.message.register_channel(self.uaid.hex, - self.chid) - endpoint = self.ap_settings.make_endpoint(self.uaid.hex, - self.chid, - self.app_server_key) - return endpoint, router_data - def _check_uaid(self, uaid): - if not uaid or uaid == 'None': + if not uaid: raise ItemNotFound("UAID not found") - return uaid @threaded_validate(RegistrationSchema) def get(self, *args, **kwargs): @@ -235,11 +235,14 @@ def get(self, *args, **kwargs): Return a list of known channelIDs for a given UAID """ - self.uaid = self.valid_input['uaid'] + uaid = self.valid_input['uaid'] self.add_header("Content-Type", "application/json") - d = deferToThread(self._check_uaid, str(self.uaid)) - d.addCallback(self.ap_settings.message.all_channels) - d.addCallback(self._write_channels) + d = defer.execute(self._check_uaid, uaid) + d.addCallback( + lambda _: + deferToThread(self.ap_settings.message.all_channels, str(uaid)) + ) + d.addCallback(self._write_channels, uaid) d.addErrback(self._uaid_not_found_err) d.addErrback(self._response_err) return d @@ -291,58 +294,66 @@ def _chid_not_found_err(self, fail): ############################################################# # Callbacks ############################################################# - def _save_router_data(self, router_data, router_type): - """Called when new data needs to be saved to a user-record""" - user_item = dict( - uaid=self.uaid.hex, + def _register_user_and_channel(self, + uaid, # type: uuid.UUID + chid, # type: str + router, # type: Any + router_type, # type: str + router_data # type: Dict[str, Any] + ): + # type: (...) -> str + """Register a new user/channel, return its endpoint""" + self._register_user(uaid, router_type, router_data) + return self._register_channel(uaid, chid, router_data.get("key")) + + def _register_user(self, uaid, router_type, router_data): + # type: (uuid.UUID, str, Dict[str, Any]) -> None + """Save a new user record""" + self.ap_settings.router.register_user(dict( + uaid=uaid.hex, router_type=router_type, router_data=router_data, connected_at=ms_time(), last_connect=generate_last_connect(), - ) - return deferToThread(self.ap_settings.router.register_user, user_item) + )) - def _create_endpoint(self, result=None): - """Called to register a new channel and create its endpoint.""" - router_data = None - try: - router_data = result[2] - except (IndexError, TypeError): - pass - return deferToThread(self._register_channel, router_data) - - def _return_endpoint(self, endpoint_data, new_uaid, router=None): - """Called after the endpoint was made and should be returned to the - requestor""" - hashed = None - if new_uaid: + def _register_channel(self, uaid, chid, app_server_key): + # type(uuid.UUID, str, str) -> str + """Register a new channel and create/return its endpoint""" + self.ap_settings.message.register_channel(uaid.hex, chid) + return self.ap_settings.make_endpoint(uaid.hex, chid, app_server_key) + + def _write_endpoint(self, + endpoint, # type: str + uaid, # type: uuid.UUID + chid, # type: str + router=None, # type: Optional[Any] + router_data=None # type: Optional[Dict[str, Any]] + ): + # type: (...) -> None + """Write the JSON response of the created endpoint""" + response = dict(channelID=chid, endpoint=endpoint) + if router_data is not None: + # a new uaid + secret = None if self.ap_settings.bear_hash_key: - hashed = generate_hash(self.ap_settings.bear_hash_key[0], - self.uaid.hex) - msg = dict( - uaid=self.uaid.hex, - secret=hashed, - channelID=self.chid, - endpoint=endpoint_data[0], - ) + secret = generate_hash( + self.ap_settings.bear_hash_key[0], uaid.hex) + response.update(uaid=uaid.hex, secret=secret) # Apply any router specific fixes to the outbound response. - if router is not None: - msg = router.amend_msg(msg, - endpoint_data[1].get('router_data')) - else: - msg = dict(channelID=self.chid, endpoint=endpoint_data[0]) - self.write(json.dumps(msg)) - self.log.debug(format="Endpoint registered via HTTP", + router.amend_endpoint_response(response, router_data) + self.write(json.dumps(response)) + self.log.debug("Endpoint registered via HTTP", client_info=self._client_info) self.finish() - def _write_channels(self, channel_info, *args, **kwargs): - # channel_info is a tuple containing a flag and the list of channels - dashed = [str(uuid.UUID(x)) for x in channel_info[1]] - self.write(json.dumps( - {"uaid": self.uaid.hex, - "channelIDs": dashed} - )) + def _write_channels(self, channel_info, uaid): + # type: (Tuple[bool, Set[str]], uuid.UUID) -> None + response = dict( + uaid=uaid.hex, + channelIDs=[str(uuid.UUID(x)) for x in channel_info[1]] + ) + self.write(json.dumps(response)) self.finish() def _success(self, result): diff --git a/autopush/websocket.py b/autopush/websocket.py index 81d0da3f..93be35dc 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -817,7 +817,7 @@ def _check_other_nodes(self, result, url=DEFAULT_WS_ERR): needed""" self.transport.resumeProducing() - registered, previous, _ = result + registered, previous = result if not registered: # Registration failed msg = {"messageType": "hello", "reason": "already_connected",