Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: add webpush topics
Browse files Browse the repository at this point in the history
Add's webpush topics with versioned sort key.

Closes #643
  • Loading branch information
bbangert committed Sep 26, 2016
1 parent 50d3143 commit 1ac2d7a
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 21 deletions.
52 changes: 46 additions & 6 deletions autopush/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,34 @@
"""Database Interaction"""
"""Database Interaction
WebPush Sort Keys
-----------------
Messages for WebPush are stored using a partition key + sort key, originally
the sort key was:
CHID : Encrypted(UAID: CHID)
The encrypted portion was returned as the Location to the Application Server.
Decrypting it resulted in enough information to create the sort key so that
the message could be deleted and located again.
For WebPush Topic messages, a new scheme was needed since the only way to
locate the prior message is the UAID + CHID + Topic. Using Encryption in
the sort key is therefore not useful since it would change every update.
The sort key scheme for WebPush messages is:
VERSION : CHID : TOPIC
To ensure updated messages are not deleted, each message will still have an
update-id key/value in its item.
Non-versioned messages are assumed to be original messages from before this
scheme was adopted.
``VERSION`` is a 2-digit 0-padded number, starting at 01 for Topic messages.
"""
from __future__ import absolute_import

import datetime
Expand Down Expand Up @@ -263,6 +293,17 @@ def generate_last_connect():
return int(val)


def make_webpush_sort_key(channel_id, message_id=None, topic=None):
"""Create a webpush sort key"""
chid = normalize_id(channel_id)
if topic:
return "01:{chid}:{topic}".format(chid=chid, topic=topic)
elif message_id:
return "{chid}:{message_id}".format(chid=chid, message_id=message_id)
else:
raise Exception("Improper call, no message_id or topic provided.")


class Storage(object):
"""Create a Storage table abstraction on top of a DynamoDB Table object"""
def __init__(self, table, metrics):
Expand Down Expand Up @@ -418,21 +459,20 @@ def save_channels(self, uaid, channels):

@track_provisioned
def store_message(self, uaid, channel_id, message_id, ttl, data=None,
headers=None, timestamp=None):
headers=None, timestamp=None, topic=None):
"""Stores a message in the message table for the given uaid/channel
with the message id"""
item = dict(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (normalize_id(channel_id), message_id),
chidmessageid=make_webpush_sort_key(channel_id,
message_id=message_id,
topic=topic),
data=data,
headers=headers,
ttl=ttl,
timestamp=timestamp or int(time.time()),
updateid=uuid.uuid4().hex
)
if data:
item["headers"] = headers
item["data"] = data
self.table.put_item(data=item, overwrite=True)
return True

Expand Down
37 changes: 30 additions & 7 deletions autopush/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
import uuid
import re

from collections import namedtuple

import cyclone.web
from attr import attrs, attrib
from boto.dynamodb2.exceptions import (
ItemNotFound,
ProvisionedThroughputExceededException,
Expand Down Expand Up @@ -62,6 +61,7 @@

# Our max TTL is 60 days realistically with table rotation, so we hard-code it
MAX_TTL = 60 * 60 * 24 * 60
VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_=]*$')
VALID_TTL = re.compile(r'^\d+$')
AUTH_SCHEMES = ["bearer", "webpush"]
PREF_SCHEME = "webpush"
Expand All @@ -79,9 +79,15 @@
}


class Notification(namedtuple("Notification",
"version data channel_id headers ttl")):
@attrs
class Notification(object):
"""Parsed notification from the request"""
version = attrib()
data = attrib()
channel_id = attrib()
headers = attrib()
ttl = attrib()
topic = attrib(default=None)


def parse_request_params(request):
Expand Down Expand Up @@ -513,6 +519,7 @@ def _uaid_lookup_results(self, uaid_data):
# Only simplepush uses version/data out of body/query, GCM/APNS will
# use data out of the request body 'WebPush' style.
use_simplepush = router_key == "simplepush"
topic = self.request.headers.get("topic")
if use_simplepush:
self.version, data = parse_request_params(self.request)
self._client_info['message_id'] = self.version
Expand Down Expand Up @@ -547,6 +554,20 @@ def _uaid_lookup_results(self, uaid_data):
401, 110, message="Encryption header missing 'salt' value")
return

if topic:
if len(topic) > 32:
self._write_response(
400, 113, message="Topic must be no greater than 32 "
"characters"
)
return

if not VALID_BASE64_URL.match(topic):
self._write_response(
400, 113, message="Topic must be URL and Filename "
"safe Base64 alphabet"
)

if VALID_TTL.match(self.request.headers.get("ttl", "0")):
ttl = int(self.request.headers.get("ttl", "0"))
# Cap the TTL to our MAX_TTL
Expand Down Expand Up @@ -574,10 +595,11 @@ def _uaid_lookup_results(self, uaid_data):
# Generate a message ID, then route the notification.
d = deferToThread(self.ap_settings.fernet.encrypt, ':'.join([
'm', self.uaid, self.chid]).encode('utf8'))
d.addCallback(self._route_notification, uaid_data, data, ttl)
d.addCallback(self._route_notification, uaid_data, data, ttl, topic)
return d

def _route_notification(self, version, uaid_data, data, ttl=None):
def _route_notification(self, version, uaid_data, data, ttl=None,
topic=None):
self.version = self._client_info['message_id'] = version
warning = ""
# Clean up the header values (remove padding)
Expand All @@ -590,7 +612,8 @@ def _route_notification(self, version, uaid_data, data, ttl=None):
notification = Notification(version=version, data=data,
channel_id=self.chid,
headers=self.request.headers,
ttl=ttl)
ttl=ttl,
topic=topic)

d = Deferred()
d.addCallback(self.router.route_notification, uaid_data)
Expand Down
1 change: 1 addition & 0 deletions autopush/router/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _save_notification(self, uaid, notification, month_table):
message_id=notification.version,
ttl=notification.ttl,
timestamp=int(time.time()),
topic=notification.topic,
)

def amend_msg(self, msg, router_data=None):
Expand Down
38 changes: 38 additions & 0 deletions autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,44 @@ def handle_finish(value):
self.finish_deferred.addCallback(handle_finish)
return self.finish_deferred

def test_webpush_bad_topic_len(self):
fresult = dict(router_type="webpush")
frouter = self.settings.routers["webpush"]
frouter.route_notification.return_value = RouterResponse()
self.endpoint.chid = dummy_chid
self.request_mock.headers["topic"] = \
"asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf"
self.request_mock.body = b""
self.endpoint._uaid_lookup_results(fresult)

def handle_finish(value):
self.endpoint.set_status.assert_called_with(400, None)
self._check_error(code=400, errno=113,
message="Topic must be no greater than 32 "
"characters")

self.finish_deferred.addCallback(handle_finish)
return self.finish_deferred

def test_webpush_bad_topic_content(self):
fresult = dict(router_type="webpush")
frouter = self.settings.routers["webpush"]
frouter.route_notification.return_value = RouterResponse()
self.endpoint.chid = dummy_chid
self.request_mock.headers["topic"] = \
"asdf:442;23^@*#$(O!4232"
self.request_mock.body = b""
self.endpoint._uaid_lookup_results(fresult)

def handle_finish(value):
self.endpoint.set_status.assert_called_with(400, None)
self._check_error(code=400, errno=113,
message="Topic must be URL and Filename "
"safe Base64 alphabet")

self.finish_deferred.addCallback(handle_finish)
return self.finish_deferred

@patch('uuid.uuid4', return_value=uuid.UUID(dummy_request_id))
def test_init_info(self, t):
d = self.endpoint._init_info()
Expand Down
33 changes: 32 additions & 1 deletion autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def delete_notification(self, channel, message=None, status=204):

def send_notification(self, channel=None, version=None, data=None,
use_header=True, status=None, ttl=200,
timeout=0.2, vapid=None, endpoint=None):
timeout=0.2, vapid=None, endpoint=None,
topic=None):
if not channel:
channel = random.choice(self.channels.keys())

Expand Down Expand Up @@ -218,6 +219,8 @@ def send_notification(self, channel=None, version=None, data=None,
headers.update({
'Crypto-Key': headers.get('Crypto-Key') + ';' + ckey
})
if topic:
headers["Topic"] = topic
body = data or ""
method = "POST"
status = status or 201
Expand Down Expand Up @@ -716,6 +719,34 @@ def test_basic_delivery(self):
eq_(result["messageType"], "notification")
yield self.shut_down(client)

@inlineCallbacks
def test_topic_basic_delivery(self):
data = str(uuid.uuid4())
client = yield self.quick_register(use_webpush=True)
result = yield client.send_notification(data=data, topic="Inbox")
eq_(result["headers"]["encryption"], client._crypto_key)
eq_(result["data"], base64url_encode(data))
eq_(result["messageType"], "notification")
yield self.shut_down(client)

@inlineCallbacks
def test_topic_replacement_delivery(self):
data = str(uuid.uuid4())
data2 = str(uuid.uuid4())
client = yield self.quick_register(use_webpush=True)
yield client.disconnect()
yield client.send_notification(data=data, topic="Inbox")
yield client.send_notification(data=data2, topic="Inbox")
yield client.connect()
yield client.hello()
result = yield client.get_notification()
eq_(result["headers"]["encryption"], client._crypto_key)
eq_(result["data"], base64url_encode(data2))
eq_(result["messageType"], "notification")
result = yield client.get_notification()
eq_(result, None)
yield self.shut_down(client)

@inlineCallbacks
def test_basic_delivery_v0_endpoint(self):
data = str(uuid.uuid4())
Expand Down
39 changes: 39 additions & 0 deletions autopush/tests/test_web_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,45 @@ def test_invalid_vapid_crypto_header(self):

eq_(cm.exception.status_code, 401)

def test_invalid_topic(self):
schema = self._make_fut()
schema.context["settings"].parse_endpoint.return_value = dict(
uaid=dummy_uaid,
chid=dummy_chid,
public_key="",
)
schema.context["settings"].router.get_uaid.return_value = dict(
router_type="webpush",
)

info = self._make_test_data(
headers={
"topic": "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdf",
}
)

with assert_raises(InvalidRequest) as cm:
schema.load(info)

eq_(cm.exception.status_code, 400)
eq_(cm.exception.errno, 113)
eq_(cm.exception.message,
"Topic must be no greater than 32 characters")

info = self._make_test_data(
headers={
"topic": "asdf??asdf::;f",
}
)

with assert_raises(InvalidRequest) as cm:
schema.load(info)

eq_(cm.exception.status_code, 400)
eq_(cm.exception.errno, 113)
eq_(cm.exception.message,
"Topic must be URL and Filename safe Base64 alphabet")


class TestWebPushRequestSchemaUsingVapid(unittest.TestCase):
def _make_fut(self):
Expand Down
12 changes: 9 additions & 3 deletions autopush/web/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import time
from collections import namedtuple

from attr import attrs, attrib
from boto.dynamodb2.exceptions import (
ProvisionedThroughputExceededException,
)
Expand All @@ -27,9 +27,15 @@
"#error-codes")


class Notification(namedtuple("Notification",
"version data channel_id headers ttl")):
@attrs
class Notification(object):
"""Parsed notification from the request"""
version = attrib()
data = attrib()
channel_id = attrib()
headers = attrib()
ttl = attrib()
topic = attrib(default=None)


class BaseWebHandler(BaseHandler):
Expand Down
17 changes: 17 additions & 0 deletions autopush/web/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
AUTH_SCHEMES = ["bearer", "webpush"]
PREF_SCHEME = "webpush"

# Base64 URL validation
VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_=]*$')


class ThreadedValidate(object):
"""A cyclone request validation decorator
Expand Down Expand Up @@ -234,8 +237,22 @@ class WebPushHeaderSchema(Schema):
encryption = fields.String()
encryption_key = fields.String(load_from="encryption-key")
ttl = fields.Integer(required=False, missing=None)
topic = fields.String(required=False, missing=None)
api_ver = fields.String()

@validates('topic')
def validate_topic(self, value):
if value is None:
return True

if len(value) > 32:
raise InvalidRequest("Topic must be no greater than 32 "
"characters", errno=113)

if not VALID_BASE64_URL.match(value):
raise InvalidRequest("Topic must be URL and Filename safe Base"
"64 alphabet", errno=113)

@validates_schema
def validate_cypto_headers(self, d):
# Not allowed to use aesgcm128 + a crypto_key
Expand Down
3 changes: 2 additions & 1 deletion autopush/web/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def post(self, api_ver="v1", token=None):
data=self.valid_input["body"],
channel_id=str(sub["chid"]),
headers=self.valid_input["headers"],
ttl=self.valid_input["headers"]["ttl"]
ttl=self.valid_input["headers"]["ttl"],
topic=self.valid_input["headers"]["topic"],
)
self._client_info["uaid"] = hasher(user_data.get("uaid"))
self._client_info["channel_id"] = user_data.get("chid")
Expand Down
Loading

0 comments on commit 1ac2d7a

Please sign in to comment.