From a2330ae65c7b849447b8e87cf0c3887ecf0d2a86 Mon Sep 17 00:00:00 2001 From: Josh Chung Date: Wed, 24 Oct 2012 19:01:35 +0900 Subject: [PATCH 1/5] adding enhanced format notification --- apns.py | 136 ++++++++++++++++++++++++++++++++++++++++++++------ apnserrors.py | 44 ++++++++++++++++ tests.py | 32 ++++++++++++ 3 files changed, 196 insertions(+), 16 deletions(-) create mode 100644 apnserrors.py diff --git a/apns.py b/apns.py index fdcec08..bd44a71 100644 --- a/apns.py +++ b/apns.py @@ -24,12 +24,16 @@ # SOFTWARE. from binascii import a2b_hex, b2a_hex -from datetime import datetime -from socket import socket, AF_INET, SOCK_STREAM +from datetime import datetime, timedelta +from time import mktime +from socket import socket, AF_INET, SOCK_STREAM, timeout from struct import pack, unpack +import select + try: from ssl import wrap_socket + from ssl import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE except ImportError: from socket import ssl as wrap_socket @@ -38,12 +42,16 @@ except ImportError: import simplejson as json +from apnserrors import * + MAX_PAYLOAD_LENGTH = 256 +TIMEOUT = 60 +ERROR_RESPONSE_LENGTH = 6 class APNs(object): """A class representing an Apple Push Notification service connection""" - def __init__(self, use_sandbox=False, cert_file=None, key_file=None): + def __init__(self, use_sandbox=False, cert_file=None, key_file=None, enhanced=False): """ Set use_sandbox to True to use the sandbox (test) APNs servers. Default is False. @@ -52,9 +60,17 @@ def __init__(self, use_sandbox=False, cert_file=None, key_file=None): self.use_sandbox = use_sandbox self.cert_file = cert_file self.key_file = key_file + self.enhanced = enhanced self._feedback_connection = None self._gateway_connection = None + @staticmethod + def unpacked_uchar_big_endian(byte): + """ + Returns an unsigned char from a packed big-endian (network) byte + """ + return unpack('>B', byte)[0] + @staticmethod def packed_ushort_big_endian(num): """ @@ -100,7 +116,8 @@ def gateway_server(self): self._gateway_connection = GatewayConnection( use_sandbox = self.use_sandbox, cert_file = self.cert_file, - key_file = self.key_file + key_file = self.key_file, + enhanced = self.enhanced ) return self._gateway_connection @@ -109,10 +126,11 @@ class APNsConnection(object): """ A generic connection class for communicating with the APNs """ - def __init__(self, cert_file=None, key_file=None): + def __init__(self, cert_file=None, key_file=None, enhanced=False): super(APNsConnection, self).__init__() self.cert_file = cert_file self.key_file = key_file + self.enhanced = enhanced self._socket = None self._ssl = None @@ -123,11 +141,29 @@ def _connect(self): # Establish an SSL connection self._socket = socket(AF_INET, SOCK_STREAM) self._socket.connect((self.server, self.port)) - self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file) + + if self.enhanced: + self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file, + do_handshake_on_connect=False) + self._ssl.setblocking(0) + while True: + try: + self._ssl.do_handshake() + break + except SSLError, err: + if SSL_ERROR_WANT_READ == err.args[0]: + select.select([self._ssl], [], []) + elif SSL_ERROR_WANT_WRITE == err.args[0]: + select.select([], [self._ssl], []) + else: + raise + else: + self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file) def _disconnect(self): if self._socket: self._socket.close() + self._ssl = None def _connection(self): if not self._ssl: @@ -135,11 +171,59 @@ def _connection(self): return self._ssl def read(self, n=None): - return self._connection().read(n) + return self._connection().recv(n) + + def recvall(self, n): + data = "" + while True: + more = self._connection().recv(n - len(data)) + data += more + if len(data) >= n: + break + rlist, _, _ = select.select([self._connection()], [], [], TIMEOUT) + if not rlist: + raise timeout + return data + def write(self, string): - return self._connection().write(string) - + if self.enhanced: # nonblocking socket + rlist, _, _ = select.select([self._connection()], [], [], 0) + + if rlist: # there's error response from APNs + buff = self.recvall(ERROR_RESPONSE_LENGTH) + if len(buff) != ERROR_RESPONSE_LENGTH: + return None + + command = APNs.unpacked_uchar_big_endian(buff[0]) + + if 8 != command: + self._disconnect() + raise UnknownError(0) + + status = APNs.unpacked_uchar_big_endian(buff[1]) + identifier = APNs.unpacked_uint_big_endian(buff[2:6]) + + self._disconnect() + + raise { 1: ProcessingError, + 2: MissingDeviceTokenError, + 3: MissingTopicError, + 4: MissingPayloadError, + 5: InvalidTokenSizeError, + 6: InvalidTopicSizeError, + 7: InvalidPayloadSizeError, + 8: InvalidTokenError }.get(status, UnknownError)(identifier) + + _, wlist, _ = select.select([], [self._connection()], [], TIMEOUT) + if wlist: + return self._connection().sendall(string) + else: + self._disconnect() + raise timeout + + else: # not-enhanced format using blocking socket + return self._connection().sendall(string) class PayloadAlert(object): def __init__(self, body, action_loc_key=None, loc_key=None, @@ -163,10 +247,6 @@ def dict(self): d['launch-image'] = self.launch_image return d -class PayloadTooLargeError(Exception): - def __init__(self): - super(PayloadTooLargeError, self).__init__() - class Payload(object): """A class representing an APNs message payload""" def __init__(self, alert=None, badge=None, sound=None, custom={}): @@ -285,10 +365,34 @@ def _get_notification(self, token_hex, payload): payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json)) notification = ('\0' + token_length_bin + token_bin - + payload_length_bin + payload_json) + + payload_length_bin + payload_json) return notification - def send_notification(self, token_hex, payload): - self.write(self._get_notification(token_hex, payload)) + def _get_enhanced_notification(self, token_hex, payload, identifier, expiry): + """ + Takes a token as a hex string and a payload as a Python dict and sends + the notification in the enhanced format + """ + token_bin = a2b_hex(token_hex) + token_length_bin = APNs.packed_ushort_big_endian(len(token_bin)) + payload_json = payload.json() + payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json)) + identifier_bin = APNs.packed_uint_big_endian(identifier) + expiry_bin = APNs.packed_uint_big_endian(int(mktime(expiry.timetuple()))) + + notification = ('\1' + identifier_bin + expiry_bin + token_length_bin + token_bin + + payload_length_bin + payload_json) + return notification + + def send_notification(self, token_hex, payload, identifier=None, expiry=None): + if self.enhanced: + if not expiry: # by default, undelivered notification expires after 30 seconds + expiry = datetime.utcnow() + timedelta(30) + if not identifier: + identifier = 0 + self.write(self._get_enhanced_notification(token_hex, payload, identifier, + expiry)) + else: + self.write(self._get_notification(token_hex, payload)) diff --git a/apnserrors.py b/apnserrors.py new file mode 100644 index 0000000..0d24ad9 --- /dev/null +++ b/apnserrors.py @@ -0,0 +1,44 @@ +class PayloadTooLargeError(Exception): + def __init__(self): + super(PayloadTooLargeError, self).__init__() + +class APNResponseError(Exception): + def __init__(self, status, identifier): + self.status = status + self.identifier = identifier + +class ProcessingError(APNResponseError): + def __init__(self, identifier): + super(ProcessingError, self).__init__(1, identifier) + +class MissingDeviceTokenError(APNResponseError): + def __init__(self, identifier): + super(MissingDeviceTokenError, self).__init__(2, identifier) + +class MissingTopicError(APNResponseError): + def __init__(self, identifier): + super(MissingTopicError, self).__init__(3, identifier) + +class MissingPayloadError(APNResponseError): + def __init__(self, identifier): + super(MissingPayloadError, self).__init__(4, identifier) + +class InvalidTokenSizeError(APNResponseError): + def __init__(self, identifier): + super(InvalidTokenSizeError, self).__init__(5, identifier) + +class InvalidTopicSizeError(APNResponseError): + def __init__(self, identifier): + super(InvalidTopicSizeError, self).__init__(6, identifier) + +class InvalidPayloadSizeError(APNResponseError): + def __init__(self, identifier): + super(InvalidPayloadSizeError, self).__init__(7, identifier) + +class InvalidTokenError(APNResponseError): + def __init__(self, identifier): + super(InvalidTokenError, self).__init__(8, identifier) + +class UnknownError(APNResponseError): + def __init__(self, identifier): + super(UnknownError, self).__init__(255, identifier) diff --git a/tests.py b/tests.py index ff5af97..081ee0c 100644 --- a/tests.py +++ b/tests.py @@ -3,6 +3,7 @@ from apns import * from binascii import a2b_hex from random import random +from datetime import datetime, timedelta import hashlib import os @@ -90,6 +91,37 @@ def testGatewayServer(self): self.assertEqual(len(notification), expected_length) self.assertEqual(notification[0], '\0') + def testEnhancedGatewayServer(self): + pem_file = TEST_CERTIFICATE + apns = APNs(use_sandbox=True, cert_file=pem_file, key_file=pem_file, enhanced=True) + gateway_server = apns.gateway_server + + self.assertEqual(gateway_server.cert_file, apns.cert_file) + self.assertEqual(gateway_server.key_file, apns.key_file) + + token_hex = 'b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c' + payload = Payload( + alert = "Hello World!", + sound = "default", + badge = 4 + ) + expiry = datetime.utcnow() + timedelta(30) + notification = gateway_server._get_enhanced_notification(token_hex, payload, 0, + expiry) + + expected_length = ( + 1 # leading null byte + + 4 # identifier as a packed int + + 4 # expiry as a packed int + + 2 # length of token as a packed short + + len(token_hex) / 2 # length of token as binary string + + 2 # length of payload as a packed short + + len(payload.json()) # length of JSON-formatted payload + ) + + self.assertEqual(len(notification), expected_length) + self.assertEqual(notification[0], '\1') + def testFeedbackServer(self): pem_file = TEST_CERTIFICATE apns = APNs(use_sandbox=True, cert_file=pem_file, key_file=pem_file) From a30d29360135c3b166127362d71d3628bd5e3b82 Mon Sep 17 00:00:00 2001 From: Josh Chung Date: Sun, 28 Oct 2012 13:34:18 +0900 Subject: [PATCH 2/5] added __repr__ and __str__ to APNResponseError --- apnserrors.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/apnserrors.py b/apnserrors.py index 0d24ad9..46095dd 100644 --- a/apnserrors.py +++ b/apnserrors.py @@ -6,6 +6,12 @@ class APNResponseError(Exception): def __init__(self, status, identifier): self.status = status self.identifier = identifier + + def __repr__(self): + return "{}".format(self.__class__.__name__, self.identifier) + + def __str__(self): + return self.__repr__() class ProcessingError(APNResponseError): def __init__(self, identifier): From 0157eef39a626ffca19cd4e48fd57854181fcffb Mon Sep 17 00:00:00 2001 From: Josh Chung Date: Sun, 28 Oct 2012 14:23:25 +0900 Subject: [PATCH 3/5] support enhanced format after python 2.5 and disconnect if broken pipe occurs while sending notification in old format --- apns.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/apns.py b/apns.py index bd44a71..c2ed129 100644 --- a/apns.py +++ b/apns.py @@ -26,16 +26,20 @@ from binascii import a2b_hex, b2a_hex from datetime import datetime, timedelta from time import mktime -from socket import socket, AF_INET, SOCK_STREAM, timeout +from socket import socket, AF_INET, SOCK_STREAM, timeout, error as socket_error from struct import pack, unpack import select +import errno + +support_enhanced = True try: from ssl import wrap_socket from ssl import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE except ImportError: from socket import ssl as wrap_socket + support_enhanced = False try: import json @@ -60,7 +64,7 @@ def __init__(self, use_sandbox=False, cert_file=None, key_file=None, enhanced=Fa self.use_sandbox = use_sandbox self.cert_file = cert_file self.key_file = key_file - self.enhanced = enhanced + self.enhanced = enhanced and support_enhanced self._feedback_connection = None self._gateway_connection = None @@ -223,7 +227,12 @@ def write(self, string): raise timeout else: # not-enhanced format using blocking socket - return self._connection().sendall(string) + try: + return self._connection().write(string) + except socket_error, err: + if errno.EPIPE == err.errno: + self._disconnect() + raise class PayloadAlert(object): def __init__(self, body, action_loc_key=None, loc_key=None, From 373469b3eba6792dc97ac209d70dbe7308dbe2b3 Mon Sep 17 00:00:00 2001 From: Josh Chung Date: Sun, 28 Oct 2012 14:23:54 +0900 Subject: [PATCH 4/5] added __init__.py --- __init__.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 __init__.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..fe96c92 --- /dev/null +++ b/__init__.py @@ -0,0 +1,7 @@ +import apns + +APNs = apns.APNs +Payload = apns.Payload + +PayloadTooLargeError = apns.PayloadTooLargeError +APNResponseError = apns.APNResponseError From 737c2ee0096b87292841e2946e28c01474179a66 Mon Sep 17 00:00:00 2001 From: Josh Chung Date: Sun, 28 Oct 2012 19:02:57 +0900 Subject: [PATCH 5/5] a bit enhancements --- README.markdown | 21 +++++++++++++++++++++ __init__.py | 14 ++++++++++++-- apns.py | 25 +++++++++++++++---------- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/README.markdown b/README.markdown index b13bdcf..ada8261 100644 --- a/README.markdown +++ b/README.markdown @@ -26,6 +26,27 @@ for (token_hex, fail_time) in apns.feedback_server.items(): # do stuff with token_hex and fail_time ``` +## Send a notification in enhanced format +```python +from apns import APNs, Payload, APNResponseError +from datetime import datetime, timedelta + +apns = APNs(use_sandbox=True, cert_file='cert.pem', key_file='key.pem', enhanced=True) + +token_hex = 'b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b87' +payload = Payload(alert="Hello World!", sound="default", badge=1) +identifier = 1234 +expiry = datetime.utcnow() + timedelta(30) # undelivered notification expires after 30 seconds + +try: + apns.gateway_server.send_notification(token_hex, payload) +except APNResponseError, err: + # handle apn's error response + # just tried notification is not sent and this response doesn't belong to that notification. + # formerly sent notifications should to be looked up with err.identifier to find one which caused this error. + # when error response is received, connection to APN server is closed. +``` + For more complicated alerts including custom buttons etc, use the PayloadAlert class. Example: diff --git a/__init__.py b/__init__.py index fe96c92..1b62267 100644 --- a/__init__.py +++ b/__init__.py @@ -1,7 +1,17 @@ import apns +import apnserrors APNs = apns.APNs Payload = apns.Payload -PayloadTooLargeError = apns.PayloadTooLargeError -APNResponseError = apns.APNResponseError +PayloadTooLargeError = apnserrors.PayloadTooLargeError +APNResponseError = apnserrors.APNResponseError +ProcessingError = apnserrors.ProcessingError +MissingDeviceTokenError = apnserrors.MissingDeviceTokenError +MissingTopicError = apnserrors.MissingTopicError +MissingPayloadError = apnserrors.MissingPayloadError +InvalidTokenSizeError = apnserrors.InvalidTokenSizeError +InvalidTopicSizeError = apnserrors.InvalidTopicSizeError +InvalidPayloadSizeError = apnserrors.InvalidPayloadSizeError +InvalidTokenError = apnserrors.InvalidTokenError +UnknownError = apnserrors.UnknownError diff --git a/apns.py b/apns.py index c2ed129..17b9ce1 100644 --- a/apns.py +++ b/apns.py @@ -24,7 +24,7 @@ # SOFTWARE. from binascii import a2b_hex, b2a_hex -from datetime import datetime, timedelta +from datetime import datetime from time import mktime from socket import socket, AF_INET, SOCK_STREAM, timeout, error as socket_error from struct import pack, unpack @@ -230,9 +230,14 @@ def write(self, string): try: return self._connection().write(string) except socket_error, err: - if errno.EPIPE == err.errno: - self._disconnect() - raise + try: + if errno.EPIPE == err.errno: + self._disconnect() + except AttributeError: + if errno.EPIPE == err.args[0]: + self._disconnect() + finally: + raise err class PayloadAlert(object): def __init__(self, body, action_loc_key=None, loc_key=None, @@ -388,19 +393,19 @@ def _get_enhanced_notification(self, token_hex, payload, identifier, expiry): payload_json = payload.json() payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json)) identifier_bin = APNs.packed_uint_big_endian(identifier) - expiry_bin = APNs.packed_uint_big_endian(int(mktime(expiry.timetuple()))) + + expiry_int = int(mktime(expiry.timetuple())) if isinstance(expiry, datetime) \ + else int(expiry) + + expiry_bin = APNs.packed_uint_big_endian(expiry_int) notification = ('\1' + identifier_bin + expiry_bin + token_length_bin + token_bin + payload_length_bin + payload_json) return notification - def send_notification(self, token_hex, payload, identifier=None, expiry=None): + def send_notification(self, token_hex, payload, identifier=0, expiry=0): if self.enhanced: - if not expiry: # by default, undelivered notification expires after 30 seconds - expiry = datetime.utcnow() + timedelta(30) - if not identifier: - identifier = 0 self.write(self._get_enhanced_notification(token_hex, payload, identifier, expiry)) else: