Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support of the enhanced notification format #21

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
# SOFTWARE.

from binascii import a2b_hex, b2a_hex
from datetime import datetime
from datetime import datetime, timedelta
from time import mktime
from random import getrandbits
from socket import socket, AF_INET, SOCK_STREAM
from struct import pack, unpack

Expand Down Expand Up @@ -142,7 +144,7 @@ def write(self, string):


class PayloadAlert(object):
def __init__(self, body, action_loc_key=None, loc_key=None,
def __init__(self, body=None, action_loc_key=None, loc_key=None,
loc_args=None, launch_image=None):
super(PayloadAlert, self).__init__()
self.body = body
Expand All @@ -152,7 +154,9 @@ def __init__(self, body, action_loc_key=None, loc_key=None,
self.launch_image = launch_image

def dict(self):
d = { 'body': self.body }
d = {}
if self.body:
d['body'] = self.body
if self.action_loc_key:
d['action-loc-key'] = self.action_loc_key
if self.loc_key:
Expand Down Expand Up @@ -263,6 +267,30 @@ def items(self):
# some more data and append to buffer
break

class UnknownResponse(Exception):
def __init__(self):
super(UnknownResponse, self).__init__()

class UnknownError(Exception):
def __init__(self):
super(UnknownError, self).__init__()

class ProcessingError(Exception):
def __init__(self):
super(ProcessingError, self).__init__()

class InvalidTokenSizeError(Exception):
def __init__(self):
super(InvalidTokenSizeError, self).__init__()

class InvalidPayloadSizeError(Exception):
def __init__(self):
super(InvalidPayloadSizeError, self).__init__()

class InvalidTokenError(Exception):
def __init__(self):
super(InvalidTokenError, self).__init__()

class GatewayConnection(APNsConnection):
"""
A class that represents a connection to the APNs gateway server
Expand All @@ -274,21 +302,44 @@ def __init__(self, use_sandbox=False, **kwargs):
'gateway.sandbox.push.apple.com')[use_sandbox]
self.port = 2195

def _get_notification(self, token_hex, payload):
def _get_notification(self, token_hex, payload, identifier, expiry):
"""
Takes a token as a hex string and a payload as a Python dict and sends
the notification
"""
identifier_bin = identifier[:4]
expiry_bin = APNs.packed_uint_big_endian(int(mktime(expiry.timetuple())))
token_bin = a2b_hex(token_hex)
token_length_bin = APNs.packed_ushort_big_endian(len(token_bin))
payload_json = payload.json()
payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json))

notification = ('\0' + token_length_bin + token_bin
notification = ('\x01' + identifier_bin + expiry_bin + token_length_bin + token_bin
+ payload_length_bin + payload_json)

return notification

def send_notification(self, token_hex, payload):
self.write(self._get_notification(token_hex, payload))
def send_notification(self, token_hex, payload, expiry=None):
if expiry is None:
expiry = datetime.now() + timedelta(30)

identifier = pack('>I', getrandbits(32))
self.write(self._get_notification(token_hex, payload, identifier, expiry))

error_response = self.read(6)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can read, Apple will not send anything back when the push message is accepted. Thus, we will block here indefinitely. Seems consistent with what I'm seeing as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. We should set the socket to non-blocking mode, or with a very short timeout.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the API should be asynchronous here. I'm not confortable with either waiting for a short timeout period or setting the socket to non-blocking mode. In either case you may end up getting an exception for the wrong notification. IMHO the appropriate solution would be to have a separate thread/process read from the socket (in blocking mode, if necessary forever) and execute a callback function when an error response is read.

if error_response != '':
command = error_response[0]
status = ord(error_response[1])
response_identifier = error_response[2:6]

if command != '\x08' or response_identifier != identifier:
raise UnknownResponse()

if status == 0:
return

raise {1: ProcessingError,
5: InvalidTokenSizeError,
7: InvalidPayloadSizeError,
8: InvalidTokenError}.get(status, UnknownError)()

11 changes: 8 additions & 3 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from apns import *
from binascii import a2b_hex
from random import random
from datetime import datetime

import hashlib
import os
Expand Down Expand Up @@ -77,18 +78,22 @@ def testGatewayServer(self):
sound = "default",
badge = 4
)
notification = gateway_server._get_notification(token_hex, payload)
identifier = 'abcd'
expiry = datetime(2000, 01, 01, 00, 00, 00)
notification = gateway_server._get_notification(token_hex, payload, identifier, expiry)

expected_length = (
1 # leading null byte
1 # leading command byte
+ 4 # Identifier as a 4 bytes buffer
+ 4 # Expiry timestamp as a packed integer
+ 2 # length of token as a packed short
+ len(token_hex) / 2 # length of token as binary string
+ 2 # length of payload as a packed short
+ len(payload.json()) # length of JSON-formatted payload
)

self.assertEqual(len(notification), expected_length)
self.assertEqual(notification[0], '\0')
self.assertEqual(notification[0], '\x01') # Enhanched format command byte

def testFeedbackServer(self):
pem_file = TEST_CERTIFICATE
Expand Down