diff --git a/README.markdown b/README.markdown index b13bdcf..ada8261 100644 --- a/README.markdown +++ b/README.markdown @@ -26,6 +26,27 @@ for (token_hex, fail_time) in apns.feedback_server.items(): # do stuff with token_hex and fail_time ``` +## Send a notification in enhanced format +```python +from apns import APNs, Payload, APNResponseError +from datetime import datetime, timedelta + +apns = APNs(use_sandbox=True, cert_file='cert.pem', key_file='key.pem', enhanced=True) + +token_hex = 'b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b87' +payload = Payload(alert="Hello World!", sound="default", badge=1) +identifier = 1234 +expiry = datetime.utcnow() + timedelta(30) # undelivered notification expires after 30 seconds + +try: + apns.gateway_server.send_notification(token_hex, payload) +except APNResponseError, err: + # handle apn's error response + # just tried notification is not sent and this response doesn't belong to that notification. + # formerly sent notifications should to be looked up with err.identifier to find one which caused this error. + # when error response is received, connection to APN server is closed. +``` + For more complicated alerts including custom buttons etc, use the PayloadAlert class. Example: diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..1b62267 --- /dev/null +++ b/__init__.py @@ -0,0 +1,17 @@ +import apns +import apnserrors + +APNs = apns.APNs +Payload = apns.Payload + +PayloadTooLargeError = apnserrors.PayloadTooLargeError +APNResponseError = apnserrors.APNResponseError +ProcessingError = apnserrors.ProcessingError +MissingDeviceTokenError = apnserrors.MissingDeviceTokenError +MissingTopicError = apnserrors.MissingTopicError +MissingPayloadError = apnserrors.MissingPayloadError +InvalidTokenSizeError = apnserrors.InvalidTokenSizeError +InvalidTopicSizeError = apnserrors.InvalidTopicSizeError +InvalidPayloadSizeError = apnserrors.InvalidPayloadSizeError +InvalidTokenError = apnserrors.InvalidTokenError +UnknownError = apnserrors.UnknownError diff --git a/apns.py b/apns.py index fdcec08..17b9ce1 100644 --- a/apns.py +++ b/apns.py @@ -25,25 +25,37 @@ from binascii import a2b_hex, b2a_hex from datetime import datetime -from socket import socket, AF_INET, SOCK_STREAM +from time import mktime +from socket import socket, AF_INET, SOCK_STREAM, timeout, error as socket_error from struct import pack, unpack +import select +import errno + +support_enhanced = True + try: from ssl import wrap_socket + from ssl import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE except ImportError: from socket import ssl as wrap_socket + support_enhanced = False try: import json except ImportError: import simplejson as json +from apnserrors import * + MAX_PAYLOAD_LENGTH = 256 +TIMEOUT = 60 +ERROR_RESPONSE_LENGTH = 6 class APNs(object): """A class representing an Apple Push Notification service connection""" - def __init__(self, use_sandbox=False, cert_file=None, key_file=None): + def __init__(self, use_sandbox=False, cert_file=None, key_file=None, enhanced=False): """ Set use_sandbox to True to use the sandbox (test) APNs servers. Default is False. @@ -52,9 +64,17 @@ def __init__(self, use_sandbox=False, cert_file=None, key_file=None): self.use_sandbox = use_sandbox self.cert_file = cert_file self.key_file = key_file + self.enhanced = enhanced and support_enhanced self._feedback_connection = None self._gateway_connection = None + @staticmethod + def unpacked_uchar_big_endian(byte): + """ + Returns an unsigned char from a packed big-endian (network) byte + """ + return unpack('>B', byte)[0] + @staticmethod def packed_ushort_big_endian(num): """ @@ -100,7 +120,8 @@ def gateway_server(self): self._gateway_connection = GatewayConnection( use_sandbox = self.use_sandbox, cert_file = self.cert_file, - key_file = self.key_file + key_file = self.key_file, + enhanced = self.enhanced ) return self._gateway_connection @@ -109,10 +130,11 @@ class APNsConnection(object): """ A generic connection class for communicating with the APNs """ - def __init__(self, cert_file=None, key_file=None): + def __init__(self, cert_file=None, key_file=None, enhanced=False): super(APNsConnection, self).__init__() self.cert_file = cert_file self.key_file = key_file + self.enhanced = enhanced self._socket = None self._ssl = None @@ -123,11 +145,29 @@ def _connect(self): # Establish an SSL connection self._socket = socket(AF_INET, SOCK_STREAM) self._socket.connect((self.server, self.port)) - self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file) + + if self.enhanced: + self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file, + do_handshake_on_connect=False) + self._ssl.setblocking(0) + while True: + try: + self._ssl.do_handshake() + break + except SSLError, err: + if SSL_ERROR_WANT_READ == err.args[0]: + select.select([self._ssl], [], []) + elif SSL_ERROR_WANT_WRITE == err.args[0]: + select.select([], [self._ssl], []) + else: + raise + else: + self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file) def _disconnect(self): if self._socket: self._socket.close() + self._ssl = None def _connection(self): if not self._ssl: @@ -135,11 +175,69 @@ def _connection(self): return self._ssl def read(self, n=None): - return self._connection().read(n) + return self._connection().recv(n) + + def recvall(self, n): + data = "" + while True: + more = self._connection().recv(n - len(data)) + data += more + if len(data) >= n: + break + rlist, _, _ = select.select([self._connection()], [], [], TIMEOUT) + if not rlist: + raise timeout + return data + def write(self, string): - return self._connection().write(string) - + if self.enhanced: # nonblocking socket + rlist, _, _ = select.select([self._connection()], [], [], 0) + + if rlist: # there's error response from APNs + buff = self.recvall(ERROR_RESPONSE_LENGTH) + if len(buff) != ERROR_RESPONSE_LENGTH: + return None + + command = APNs.unpacked_uchar_big_endian(buff[0]) + + if 8 != command: + self._disconnect() + raise UnknownError(0) + + status = APNs.unpacked_uchar_big_endian(buff[1]) + identifier = APNs.unpacked_uint_big_endian(buff[2:6]) + + self._disconnect() + + raise { 1: ProcessingError, + 2: MissingDeviceTokenError, + 3: MissingTopicError, + 4: MissingPayloadError, + 5: InvalidTokenSizeError, + 6: InvalidTopicSizeError, + 7: InvalidPayloadSizeError, + 8: InvalidTokenError }.get(status, UnknownError)(identifier) + + _, wlist, _ = select.select([], [self._connection()], [], TIMEOUT) + if wlist: + return self._connection().sendall(string) + else: + self._disconnect() + raise timeout + + else: # not-enhanced format using blocking socket + try: + return self._connection().write(string) + except socket_error, err: + try: + if errno.EPIPE == err.errno: + self._disconnect() + except AttributeError: + if errno.EPIPE == err.args[0]: + self._disconnect() + finally: + raise err class PayloadAlert(object): def __init__(self, body, action_loc_key=None, loc_key=None, @@ -163,10 +261,6 @@ def dict(self): d['launch-image'] = self.launch_image return d -class PayloadTooLargeError(Exception): - def __init__(self): - super(PayloadTooLargeError, self).__init__() - class Payload(object): """A class representing an APNs message payload""" def __init__(self, alert=None, badge=None, sound=None, custom={}): @@ -285,10 +379,34 @@ def _get_notification(self, token_hex, payload): payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json)) notification = ('\0' + token_length_bin + token_bin - + payload_length_bin + payload_json) + + payload_length_bin + payload_json) return notification - def send_notification(self, token_hex, payload): - self.write(self._get_notification(token_hex, payload)) + def _get_enhanced_notification(self, token_hex, payload, identifier, expiry): + """ + Takes a token as a hex string and a payload as a Python dict and sends + the notification in the enhanced format + """ + token_bin = a2b_hex(token_hex) + token_length_bin = APNs.packed_ushort_big_endian(len(token_bin)) + payload_json = payload.json() + payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json)) + identifier_bin = APNs.packed_uint_big_endian(identifier) + + expiry_int = int(mktime(expiry.timetuple())) if isinstance(expiry, datetime) \ + else int(expiry) + + expiry_bin = APNs.packed_uint_big_endian(expiry_int) + notification = ('\1' + identifier_bin + expiry_bin + token_length_bin + token_bin + + payload_length_bin + payload_json) + + return notification + + def send_notification(self, token_hex, payload, identifier=0, expiry=0): + if self.enhanced: + self.write(self._get_enhanced_notification(token_hex, payload, identifier, + expiry)) + else: + self.write(self._get_notification(token_hex, payload)) diff --git a/apnserrors.py b/apnserrors.py new file mode 100644 index 0000000..46095dd --- /dev/null +++ b/apnserrors.py @@ -0,0 +1,50 @@ +class PayloadTooLargeError(Exception): + def __init__(self): + super(PayloadTooLargeError, self).__init__() + +class APNResponseError(Exception): + def __init__(self, status, identifier): + self.status = status + self.identifier = identifier + + def __repr__(self): + return "{}".format(self.__class__.__name__, self.identifier) + + def __str__(self): + return self.__repr__() + +class ProcessingError(APNResponseError): + def __init__(self, identifier): + super(ProcessingError, self).__init__(1, identifier) + +class MissingDeviceTokenError(APNResponseError): + def __init__(self, identifier): + super(MissingDeviceTokenError, self).__init__(2, identifier) + +class MissingTopicError(APNResponseError): + def __init__(self, identifier): + super(MissingTopicError, self).__init__(3, identifier) + +class MissingPayloadError(APNResponseError): + def __init__(self, identifier): + super(MissingPayloadError, self).__init__(4, identifier) + +class InvalidTokenSizeError(APNResponseError): + def __init__(self, identifier): + super(InvalidTokenSizeError, self).__init__(5, identifier) + +class InvalidTopicSizeError(APNResponseError): + def __init__(self, identifier): + super(InvalidTopicSizeError, self).__init__(6, identifier) + +class InvalidPayloadSizeError(APNResponseError): + def __init__(self, identifier): + super(InvalidPayloadSizeError, self).__init__(7, identifier) + +class InvalidTokenError(APNResponseError): + def __init__(self, identifier): + super(InvalidTokenError, self).__init__(8, identifier) + +class UnknownError(APNResponseError): + def __init__(self, identifier): + super(UnknownError, self).__init__(255, identifier) diff --git a/tests.py b/tests.py index ff5af97..081ee0c 100644 --- a/tests.py +++ b/tests.py @@ -3,6 +3,7 @@ from apns import * from binascii import a2b_hex from random import random +from datetime import datetime, timedelta import hashlib import os @@ -90,6 +91,37 @@ def testGatewayServer(self): self.assertEqual(len(notification), expected_length) self.assertEqual(notification[0], '\0') + def testEnhancedGatewayServer(self): + pem_file = TEST_CERTIFICATE + apns = APNs(use_sandbox=True, cert_file=pem_file, key_file=pem_file, enhanced=True) + gateway_server = apns.gateway_server + + self.assertEqual(gateway_server.cert_file, apns.cert_file) + self.assertEqual(gateway_server.key_file, apns.key_file) + + token_hex = 'b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c' + payload = Payload( + alert = "Hello World!", + sound = "default", + badge = 4 + ) + expiry = datetime.utcnow() + timedelta(30) + notification = gateway_server._get_enhanced_notification(token_hex, payload, 0, + expiry) + + expected_length = ( + 1 # leading null byte + + 4 # identifier as a packed int + + 4 # expiry as a packed int + + 2 # length of token as a packed short + + len(token_hex) / 2 # length of token as binary string + + 2 # length of payload as a packed short + + len(payload.json()) # length of JSON-formatted payload + ) + + self.assertEqual(len(notification), expected_length) + self.assertEqual(notification[0], '\1') + def testFeedbackServer(self): pem_file = TEST_CERTIFICATE apns = APNs(use_sandbox=True, cert_file=pem_file, key_file=pem_file)