-
Notifications
You must be signed in to change notification settings - Fork 30
feat: Add multiple cert handlers for APNs #660
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
"""APNS Router""" | ||
import time | ||
import uuid | ||
|
||
import apns | ||
from twisted.logger import Logger | ||
from twisted.internet.threads import deferToThread | ||
|
||
from autopush.router.interface import RouterException, RouterResponse | ||
|
||
|
||
|
@@ -27,86 +27,159 @@ class APNSRouter(object): | |
255: 'Unknown', | ||
} | ||
|
||
def _connect(self): | ||
"""Connect to APNS""" | ||
self.apns = apns.APNs(use_sandbox=self.config.get("sandbox", False), | ||
cert_file=self.config.get("cert_file"), | ||
key_file=self.config.get("key_file"), | ||
enhanced=True) | ||
def _connect(self, cert_info): | ||
"""Connect to APNS | ||
|
||
:param cert_info: APNS certificate configuration info | ||
:type cert_info: dict | ||
|
||
:returns: APNs to be stored under the proper release channel name. | ||
:rtype: apns.APNs | ||
|
||
""" | ||
# Do I still need to call this in _error? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess not. 😄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so, yeah, remember my cryptic comment about calling self._connect() before? ... |
||
return apns.APNs( | ||
use_sandbox=cert_info.get("sandbox", False), | ||
cert_file=cert_info.get("cert"), | ||
key_file=cert_info.get("key"), | ||
enhanced=True) | ||
|
||
def __init__(self, ap_settings, router_conf): | ||
"""Create a new APNS router and connect to APNS""" | ||
self.ap_settings = ap_settings | ||
self._base_tags = [] | ||
self.config = router_conf | ||
self.default_title = router_conf.get("default_title", "SimplePush") | ||
self.default_body = router_conf.get("default_body", "New Alert") | ||
self._connect() | ||
self.log.debug("Starting APNS router...") | ||
self.apns = dict() | ||
self.messages = dict() | ||
self._config = router_conf | ||
self._max_messages = self._config.pop('max_messages', 100) | ||
for rel_channel in self._config: | ||
self.apns[rel_channel] = self._connect(self._config[rel_channel]) | ||
self.apns[rel_channel].gateway_server.register_response_listener( | ||
self._error) | ||
self.ap_settings = ap_settings | ||
self.log.debug("Starting APNS router...") | ||
|
||
def register(self, uaid, router_data, app_id, *args, **kwargs): | ||
"""Register an endpoint for APNS, on the `app_id` release channel. | ||
|
||
def register(self, uaid, router_data, *args, **kwargs): | ||
"""Validate that an APNs instance token is in the ``router_data``""" | ||
This will validate that an APNs instance token is in the | ||
``router_data``, | ||
|
||
:param uaid: User Agent Identifier | ||
:type uaid: str | ||
:param router_data: Dict containing router specific configuration info | ||
:type router_data: dict | ||
:param app_id: The release channel identifier for cert info lookup | ||
:type app_id: str | ||
|
||
:returns: a modified router_data for the user agent record. | ||
:rtype: dict | ||
|
||
|
||
""" | ||
if app_id not in self.apns: | ||
raise RouterException("Unknown release channel specified", | ||
status_code=400, | ||
response_body="Unknown release channel") | ||
if not router_data.get("token"): | ||
raise RouterException("No token registered", status_code=500, | ||
response_body="No token registered") | ||
router_data["rel_channel"] = app_id | ||
return router_data | ||
|
||
def amend_msg(self, msg, router_data=None): | ||
"""This function is stubbed out for this router""" | ||
return msg | ||
|
||
def route_notification(self, notification, uaid_data): | ||
"""Start the APNS notification routing, returns a deferred""" | ||
"""Start the APNS notification routing, returns a deferred | ||
|
||
:param notification: Notification data to send | ||
:type notification: dict | ||
:param uaid_data: User Agent specific data | ||
:type uaid_data: dict | ||
|
||
""" | ||
router_data = uaid_data["router_data"] | ||
# Kick the entire notification routing off to a thread | ||
return deferToThread(self._route, notification, router_data) | ||
|
||
def _route(self, notification, router_data): | ||
"""Blocking APNS call to route the notification""" | ||
token = router_data["token"] | ||
"""Blocking APNS call to route the notification | ||
|
||
:param notification: Notification data to send | ||
:type notification: dict | ||
:param router_data: Pre-initialized data for this connection | ||
:type router_data: dict | ||
|
||
""" | ||
router_token = router_data["token"] | ||
rel_channel = router_data["rel_channel"] | ||
config = self._config[rel_channel] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure hope we never, ever, change valid release channel names. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we do, thumbscrews. Fortunately ops has the ability to generate more certs. They'll just have to change the app_id they're sending in to use the new cert pair. |
||
if len(self.messages) >= self._max_messages: | ||
raise RouterException("Too many messages in pending queue", | ||
status_code=503, | ||
response_body="Pending buffer full", | ||
) | ||
apns_client = self.apns[rel_channel] | ||
custom = { | ||
"Chid": notification.channel_id, | ||
"Ver": notification.version, | ||
"chid": notification.channel_id, | ||
"ver": notification.version, | ||
} | ||
if notification.data: | ||
custom["Msg"] = notification.data | ||
custom["Con"] = notification.headers["content-encoding"] | ||
custom["Enc"] = notification.headers["encryption"] | ||
custom["body"] = notification.data | ||
custom["con"] = notification.headers["content-encoding"] | ||
custom["enc"] = notification.headers["encryption"] | ||
|
||
if "crypto-key" in notification.headers: | ||
custom["Cryptokey"] = notification.headers["crypto-key"] | ||
custom["cryptokey"] = notification.headers["crypto-key"] | ||
elif "encryption-key" in notification.headers: | ||
custom["Enckey"] = notification.headers["encryption-key"] | ||
|
||
payload = apns.Payload(alert=router_data.get("title", | ||
self.default_title), | ||
content_available=1, | ||
custom=custom) | ||
now = int(time.time()) | ||
self.messages[now] = {"token": token, "payload": payload} | ||
# TODO: Add listener for error handling. | ||
self.apns.gateway_server.register_response_listener(self._error) | ||
self.ap_settings.metrics.increment( | ||
"updates.client.bridge.apns.attempted", | ||
self._base_tags) | ||
custom["enckey"] = notification.headers["encryption-key"] | ||
|
||
self.apns.gateway_server.send_notification(token, payload, now) | ||
payload = apns.Payload( | ||
alert=router_data.get("title", config.get('default_title', | ||
'Mozilla Push')), | ||
content_available=1, | ||
custom=custom) | ||
now = time.time() | ||
|
||
# cleanup sent messages | ||
if self.messages: | ||
for time_sent in self.messages.keys(): | ||
if time_sent < now - self.config.get("expry", 10): | ||
del self.messages[time_sent] | ||
self.ap_settings.metrics.increment( | ||
"updates.client.bridge.apns.succeed", | ||
self._base_tags) | ||
# "apns-id" | ||
msg_id = str(uuid.uuid4()) | ||
self.messages[msg_id] = { | ||
"time_sent": now, | ||
"rel_channel": router_data["rel_channel"], | ||
"router_token": router_token, | ||
"payload": payload} | ||
|
||
apns_client.gateway_server.send_notification(router_token, payload, | ||
msg_id) | ||
location = "%s/m/%s" % (self.ap_settings.endpoint_url, | ||
notification.version) | ||
self.ap_settings.metrics.increment( | ||
"updates.client.bridge.apns.%s.sent" % | ||
router_data["rel_channel"], | ||
self._base_tags) | ||
return RouterResponse(status_code=201, response_body="", | ||
headers={"TTL": notification.ttl, | ||
"Location": location}, | ||
logged_status=200) | ||
|
||
def _cleanup(self): | ||
"""clean up pending, but expired messages. | ||
|
||
APNs may not always respond with a status code, this will clean out | ||
pending retryable messages. | ||
|
||
""" | ||
for msg_id in self.messages.keys(): | ||
message = self.messages[msg_id] | ||
expry = self._config[message['rel_channel']].get("expry", 10) | ||
if message["time_sent"] < time.time() - expry: | ||
try: | ||
del self.messages[msg_id] | ||
except KeyError: # pragma nocover | ||
pass | ||
|
||
def _error(self, err): | ||
"""Error handler""" | ||
if err['status'] == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That shouldn't happen now that cleanup is run off the event loop. |
||
|
@@ -117,11 +190,11 @@ def _error(self, err): | |
status=self.errors[err['status']]) | ||
if err['status'] in [1, 255]: | ||
self.log.debug("Retrying...") | ||
self._connect() | ||
resend = self.messages.get(err.get('identifier')) | ||
if resend is None: | ||
return | ||
self.apns.gateway_server.send_notification(resend['token'], | ||
resend['payload'], | ||
err['identifier'], | ||
) | ||
apns_client = self.apns[resend["rel_channel"]] | ||
apns_client.gateway_server.send_notification(resend['token'], | ||
resend['payload'], | ||
err['identifier'], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to get it, a key check should be fine. "if 'apns' in ...."