From 1ac2d7aba8da7c3482eef192694efa056e90253d Mon Sep 17 00:00:00 2001 From: Ben Bangert Date: Sun, 25 Sep 2016 19:08:30 -0700 Subject: [PATCH] feat: add webpush topics Adds webpush topics with versioned sort key. Closes #643 --- autopush/db.py | 52 +++++++++++++++++++++++---- autopush/endpoint.py | 37 +++++++++++++++---- autopush/router/webpush.py | 1 + autopush/tests/test_endpoint.py | 38 ++++++++++++++++++++ autopush/tests/test_integration.py | 33 ++++++++++++++++- autopush/tests/test_web_validation.py | 39 ++++++++++++++++++++ autopush/web/base.py | 12 +++++-- autopush/web/validation.py | 17 +++++++++ autopush/web/webpush.py | 3 +- autopush/websocket.py | 21 +++++++++-- base-requirements.txt | 2 +- docs/http.rst | 1 + 12 files changed, 235 insertions(+), 21 deletions(-) diff --git a/autopush/db.py b/autopush/db.py index 55853787..e3e99bd9 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -1,4 +1,34 @@ -"""Database Interaction""" +"""Database Interaction + +WebPush Sort Keys +----------------- + +Messages for WebPush are stored using a partition key + sort key; originally +the sort key was: + + CHID : Encrypted(UAID: CHID) + +The encrypted portion was returned as the Location to the Application Server. +Decrypting it resulted in enough information to create the sort key so that +the message could be deleted and located again. + +For WebPush Topic messages, a new scheme was needed since the only way to +locate the prior message is the UAID + CHID + Topic. Using Encryption in +the sort key is therefore not useful since it would change every update. + +The sort key scheme for WebPush messages is: + + VERSION : CHID : TOPIC + +To ensure updated messages are not deleted, each message will still have an +update-id key/value in its item. + +Non-versioned messages are assumed to be original messages from before this +scheme was adopted. + +``VERSION`` is a 2-digit 0-padded number, starting at 01 for Topic messages. 
+ +""" from __future__ import absolute_import import datetime @@ -263,6 +293,17 @@ def generate_last_connect(): return int(val) +def make_webpush_sort_key(channel_id, message_id=None, topic=None): + """Create a webpush sort key""" + chid = normalize_id(channel_id) + if topic: + return "01:{chid}:{topic}".format(chid=chid, topic=topic) + elif message_id: + return "{chid}:{message_id}".format(chid=chid, message_id=message_id) + else: + raise Exception("Improper call, no message_id or topic provided.") + + class Storage(object): """Create a Storage table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): @@ -418,21 +459,20 @@ def save_channels(self, uaid, channels): @track_provisioned def store_message(self, uaid, channel_id, message_id, ttl, data=None, - headers=None, timestamp=None): + headers=None, timestamp=None, topic=None): """Stores a message in the message table for the given uaid/channel with the message id""" item = dict( uaid=hasher(uaid), - chidmessageid="%s:%s" % (normalize_id(channel_id), message_id), + chidmessageid=make_webpush_sort_key(channel_id, + message_id=message_id, + topic=topic), data=data, headers=headers, ttl=ttl, timestamp=timestamp or int(time.time()), updateid=uuid.uuid4().hex ) - if data: - item["headers"] = headers - item["data"] = data self.table.put_item(data=item, overwrite=True) return True diff --git a/autopush/endpoint.py b/autopush/endpoint.py index ab98e370..f893b265 100644 --- a/autopush/endpoint.py +++ b/autopush/endpoint.py @@ -27,9 +27,8 @@ import uuid import re -from collections import namedtuple - import cyclone.web +from attr import attrs, attrib from boto.dynamodb2.exceptions import ( ItemNotFound, ProvisionedThroughputExceededException, @@ -62,6 +61,7 @@ # Our max TTL is 60 days realistically with table rotation, so we hard-code it MAX_TTL = 60 * 60 * 24 * 60 +VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_=]*$') VALID_TTL = re.compile(r'^\d+$') AUTH_SCHEMES = ["bearer", "webpush"] 
PREF_SCHEME = "webpush" @@ -79,9 +79,15 @@ } -class Notification(namedtuple("Notification", - "version data channel_id headers ttl")): +@attrs +class Notification(object): """Parsed notification from the request""" + version = attrib() + data = attrib() + channel_id = attrib() + headers = attrib() + ttl = attrib() + topic = attrib(default=None) def parse_request_params(request): @@ -513,6 +519,7 @@ def _uaid_lookup_results(self, uaid_data): # Only simplepush uses version/data out of body/query, GCM/APNS will # use data out of the request body 'WebPush' style. use_simplepush = router_key == "simplepush" + topic = self.request.headers.get("topic") if use_simplepush: self.version, data = parse_request_params(self.request) self._client_info['message_id'] = self.version @@ -547,6 +554,20 @@ def _uaid_lookup_results(self, uaid_data): 401, 110, message="Encryption header missing 'salt' value") return + if topic: + if len(topic) > 32: + self._write_response( + 400, 113, message="Topic must be no greater than 32 " + "characters" + ) + return + + if not VALID_BASE64_URL.match(topic): + self._write_response( + 400, 113, message="Topic must be URL and Filename " + "safe Base64 alphabet" + ) + if VALID_TTL.match(self.request.headers.get("ttl", "0")): ttl = int(self.request.headers.get("ttl", "0")) # Cap the TTL to our MAX_TTL @@ -574,10 +595,11 @@ def _uaid_lookup_results(self, uaid_data): # Generate a message ID, then route the notification. 
d = deferToThread(self.ap_settings.fernet.encrypt, ':'.join([ 'm', self.uaid, self.chid]).encode('utf8')) - d.addCallback(self._route_notification, uaid_data, data, ttl) + d.addCallback(self._route_notification, uaid_data, data, ttl, topic) return d - def _route_notification(self, version, uaid_data, data, ttl=None): + def _route_notification(self, version, uaid_data, data, ttl=None, + topic=None): self.version = self._client_info['message_id'] = version warning = "" # Clean up the header values (remove padding) @@ -590,7 +612,8 @@ def _route_notification(self, version, uaid_data, data, ttl=None): notification = Notification(version=version, data=data, channel_id=self.chid, headers=self.request.headers, - ttl=ttl) + ttl=ttl, + topic=topic) d = Deferred() d.addCallback(self.router.route_notification, uaid_data) diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 816747a5..bdbd1fee 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -158,6 +158,7 @@ def _save_notification(self, uaid, notification, month_table): message_id=notification.version, ttl=notification.ttl, timestamp=int(time.time()), + topic=notification.topic, ) def amend_msg(self, msg, router_data=None): diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index d6f294c9..c2aec707 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -457,6 +457,44 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) return self.finish_deferred + def test_webpush_bad_topic_len(self): + fresult = dict(router_type="webpush") + frouter = self.settings.routers["webpush"] + frouter.route_notification.return_value = RouterResponse() + self.endpoint.chid = dummy_chid + self.request_mock.headers["topic"] = \ + "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf" + self.request_mock.body = b"" + self.endpoint._uaid_lookup_results(fresult) + + def handle_finish(value): + 
self.endpoint.set_status.assert_called_with(400, None) + self._check_error(code=400, errno=113, + message="Topic must be no greater than 32 " + "characters") + + self.finish_deferred.addCallback(handle_finish) + return self.finish_deferred + + def test_webpush_bad_topic_content(self): + fresult = dict(router_type="webpush") + frouter = self.settings.routers["webpush"] + frouter.route_notification.return_value = RouterResponse() + self.endpoint.chid = dummy_chid + self.request_mock.headers["topic"] = \ + "asdf:442;23^@*#$(O!4232" + self.request_mock.body = b"" + self.endpoint._uaid_lookup_results(fresult) + + def handle_finish(value): + self.endpoint.set_status.assert_called_with(400, None) + self._check_error(code=400, errno=113, + message="Topic must be URL and Filename " + "safe Base64 alphabet") + + self.finish_deferred.addCallback(handle_finish) + return self.finish_deferred + @patch('uuid.uuid4', return_value=uuid.UUID(dummy_request_id)) def test_init_info(self, t): d = self.endpoint._init_info() diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index d1fc031f..30c72f7f 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -187,7 +187,8 @@ def delete_notification(self, channel, message=None, status=204): def send_notification(self, channel=None, version=None, data=None, use_header=True, status=None, ttl=200, - timeout=0.2, vapid=None, endpoint=None): + timeout=0.2, vapid=None, endpoint=None, + topic=None): if not channel: channel = random.choice(self.channels.keys()) @@ -218,6 +219,8 @@ def send_notification(self, channel=None, version=None, data=None, headers.update({ 'Crypto-Key': headers.get('Crypto-Key') + ';' + ckey }) + if topic: + headers["Topic"] = topic body = data or "" method = "POST" status = status or 201 @@ -716,6 +719,34 @@ def test_basic_delivery(self): eq_(result["messageType"], "notification") yield self.shut_down(client) + @inlineCallbacks + def 
test_topic_basic_delivery(self): + data = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + result = yield client.send_notification(data=data, topic="Inbox") + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data)) + eq_(result["messageType"], "notification") + yield self.shut_down(client) + + @inlineCallbacks + def test_topic_replacement_delivery(self): + data = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + yield client.send_notification(data=data, topic="Inbox") + yield client.send_notification(data=data2, topic="Inbox") + yield client.connect() + yield client.hello() + result = yield client.get_notification() + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data2)) + eq_(result["messageType"], "notification") + result = yield client.get_notification() + eq_(result, None) + yield self.shut_down(client) + @inlineCallbacks def test_basic_delivery_v0_endpoint(self): data = str(uuid.uuid4()) diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py index 4e1e11bc..aed5cb8b 100644 --- a/autopush/tests/test_web_validation.py +++ b/autopush/tests/test_web_validation.py @@ -528,6 +528,45 @@ def test_invalid_vapid_crypto_header(self): eq_(cm.exception.status_code, 401) + def test_invalid_topic(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + + info = self._make_test_data( + headers={ + "topic": "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdf", + } + ) + + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 113) + eq_(cm.exception.message, + "Topic must be no greater 
than 32 characters") + + info = self._make_test_data( + headers={ + "topic": "asdf??asdf::;f", + } + ) + + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 113) + eq_(cm.exception.message, + "Topic must be URL and Filename safe Base64 alphabet") + class TestWebPushRequestSchemaUsingVapid(unittest.TestCase): def _make_fut(self): diff --git a/autopush/web/base.py b/autopush/web/base.py index a19118c4..e5e6957e 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -1,7 +1,7 @@ import json import time -from collections import namedtuple +from attr import attrs, attrib from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, ) @@ -27,9 +27,15 @@ "#error-codes") -class Notification(namedtuple("Notification", - "version data channel_id headers ttl")): +@attrs +class Notification(object): """Parsed notification from the request""" + version = attrib() + data = attrib() + channel_id = attrib() + headers = attrib() + ttl = attrib() + topic = attrib(default=None) class BaseWebHandler(BaseHandler): diff --git a/autopush/web/validation.py b/autopush/web/validation.py index 2756fb87..089f2068 100644 --- a/autopush/web/validation.py +++ b/autopush/web/validation.py @@ -35,6 +35,9 @@ AUTH_SCHEMES = ["bearer", "webpush"] PREF_SCHEME = "webpush" +# Base64 URL validation +VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_=]*$') + class ThreadedValidate(object): """A cyclone request validation decorator @@ -234,8 +237,22 @@ class WebPushHeaderSchema(Schema): encryption = fields.String() encryption_key = fields.String(load_from="encryption-key") ttl = fields.Integer(required=False, missing=None) + topic = fields.String(required=False, missing=None) api_ver = fields.String() + @validates('topic') + def validate_topic(self, value): + if value is None: + return True + + if len(value) > 32: + raise InvalidRequest("Topic must be no greater than 32 " + "characters", errno=113) + + if 
not VALID_BASE64_URL.match(value): + raise InvalidRequest("Topic must be URL and Filename safe Base" + "64 alphabet", errno=113) + @validates_schema def validate_cypto_headers(self, d): # Not allowed to use aesgcm128 + a crypto_key diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 8b48a8a4..adcebbc6 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -45,7 +45,8 @@ def post(self, api_ver="v1", token=None): data=self.valid_input["body"], channel_id=str(sub["chid"]), headers=self.valid_input["headers"], - ttl=self.valid_input["headers"]["ttl"] + ttl=self.valid_input["headers"]["ttl"], + topic=self.valid_input["headers"]["topic"], ) self._client_info["uaid"] = hasher(user_data.get("uaid")) self._client_info["channel_id"] = user_data.get("chid") diff --git a/autopush/websocket.py b/autopush/websocket.py index 42688411..dfa4998e 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -36,6 +36,7 @@ from functools import wraps from random import randrange +import re from autobahn.twisted.websocket import WebSocketServerProtocol from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, @@ -918,8 +919,7 @@ def finish_webpush_notifications(self, notifs): # Send out all the notifications now = int(time.time()) for notif in notifs: - # Split off the chid and message id - chid, version = notif["chidmessageid"].split(":") + chid, version = self._parse_notification_db_key(notif) # If the TTL is too old, don't deliver and fire a delete off if not notif["ttl"] or now >= (notif["ttl"]+notif["timestamp"]): @@ -944,6 +944,23 @@ def finish_webpush_notifications(self, notifs): ) self.sendJSON(msg) + def _parse_notification_db_key(self, notif): + """Parses a notification sort key into a chid, version + + chid is an informative value for the client matching the channel id + while version is an appropriate value such that the server can verify + the appropriate message is being ack'd that has not been modified. 
+ + """ + # If this is versioned, process it appropriately + if re.match(r'^\d\d:', notif["chidmessageid"]): + # We only know 01 at the moment, so we split appropriately + _, chid, version = notif["chidmessageid"].split(":") + return chid, version + else: + # Split off the chid and message id + return notif["chidmessageid"].split(":") + def _rotate_message_table(self): """Function to fire off a message table copy of channels + update the router current_month entry""" diff --git a/base-requirements.txt b/base-requirements.txt index de42435b..127f1b0c 100644 --- a/base-requirements.txt +++ b/base-requirements.txt @@ -32,7 +32,7 @@ idna==2.1 ipaddress==1.0.16 itsdangerous==0.24 jmespath==0.9.0 -marshmallow==2.9.1 +marshmallow==2.10.2 mccabe==0.5.2 pbr==1.10.0 pluggy==0.3.1 diff --git a/docs/http.rst b/docs/http.rst index 64c84308..0fa612ba 100644 --- a/docs/http.rst +++ b/docs/http.rst @@ -93,6 +93,7 @@ Unless otherwise specified, all calls return the following error codes: - Missing Crypto Headers - Include the appropriate encryption headers (`WebPush Encryption §3.2 `_ and `WebPush VAPID §4 `_) - errno 112 - Invalid TTL header value + - errno 113 - Invalid Topic header value - 401 - Bad Authorization