diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index fa22844d..3806b403 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -212,7 +212,7 @@ def send_notification(self, channel=None, version=None, data=None, if use_header: headers.update({ "Content-Type": "application/octet-stream", - "Content-Encoding": "aesgcm-128", + "Content-Encoding": "aesgcm", "Encryption": self._crypto_key, "Crypto-Key": 'keyid="a1"; dh="JcqK-OLkJZlJ3sJJWstJCA"', }) diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py index ee204e76..eac50a78 100644 --- a/autopush/tests/test_web_validation.py +++ b/autopush/tests/test_web_validation.py @@ -150,7 +150,7 @@ def check_result(result): class TestSimplePushRequestSchema(unittest.TestCase): def _make_fut(self): - from autopush.web.push_validation import SimplePushRequestSchema + from autopush.web.simplepush import SimplePushRequestSchema schema = SimplePushRequestSchema() schema.context["settings"] = Mock() schema.context["log"] = Mock() @@ -283,7 +283,7 @@ def test_invalid_data_size(self): class TestWebPushRequestSchema(unittest.TestCase): def _make_fut(self): - from autopush.web.push_validation import WebPushRequestSchema + from autopush.web.webpush import WebPushRequestSchema schema = WebPushRequestSchema() schema.context["settings"] = Mock() schema.context["log"] = Mock() @@ -324,10 +324,14 @@ def test_no_headers(self): schema.context["settings"].router.get_uaid.return_value = dict( router_type="webpush", ) - data = self._make_test_data(body="asdfasdf", - headers={"ttl": "invalid"}) - result, errors = schema.load(data) - eq_(errors, {'headers': {'ttl': [u'Not a valid integer.']}}) + data = self._make_test_data(body="asdfasdf") + + with assert_raises(InvalidRequest) as cm: + schema.load(data) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 110) + eq_(cm.exception.message, "Unknown Content-Encoding") def test_invalid_simplepush_user(self): schema = self._make_fut() @@ -419,23 +423,136 @@ def test_invalid_header_combo(self): info = self._make_test_data( headers={ "content-encoding": "aesgcm128", - "crypto-key": "asdfjialsjdfiasjld", - } + "crypto-key": "dh=asdfjialsjdfiasjld", + "encryption-key": "dh=asdfjasidlfjaislf", + }, + body="asdfasdf", ) with assert_raises(InvalidRequest) as cm: schema.load(info) eq_(cm.exception.errno, 110) + def test_invalid_header_combo_04(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) info = self._make_test_data( headers={ - "encryption-key": "aesgcm128", - "crypto-key": "asdfjialsjdfiasjld", - } + "content-encoding": "aesgcm", + "encryption": "salt=ajisldjfi", + "crypto-key": "dh=asdfjialsjdfiasjld", + "encryption-key": "dh=asdfjasidlfjaislf", + }, + body="asdfasdf", ) with assert_raises(InvalidRequest) as cm: schema.load(info) + eq_(cm.exception.message, "Encryption-Key header not valid for 02 " + "or later webpush-encryption") + eq_(cm.exception.errno, 110) + + def test_missing_encryption_salt(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + info = self._make_test_data( + headers={ + "content-encoding": "aesgcm128", + "encryption": "dh=asdfjasidlfjaislf", + "encryption-key": "dh=jilajsidfljasildjf", + }, + body="asdfasdf", + ) + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 110) + + def test_missing_encryption_salt_04(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + info = self._make_test_data( + headers={ + "content-encoding": "aesgcm", + "encryption": "dh=asdfjasidlfjaislf", + "crypto-key": "dh=jilajsidfljasildjf", + }, + body="asdfasdf", + ) + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 110) + + def test_missing_encryption_key_dh(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + info = self._make_test_data( + headers={ + "content-encoding": "aesgcm128", + "encryption": "salt=asdfjasidlfjaislf", + "encryption-key": "keyid=jialsjdifjlasd", + }, + body="asdfasdf", + ) + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 110) + + def test_missing_crypto_key_dh(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + info = self._make_test_data( + headers={ + "content-encoding": "aesgcm", + "encryption": "salt=asdfjasidlfjaislf", + "crypto-key": "p256ecdsa=BA1Hxzyi1RUM1b5wjxsn7nGxAs", + }, + body="asdfasdf", + ) + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) eq_(cm.exception.errno, 110) def test_invalid_data_size(self): @@ -451,7 +568,12 @@ def test_invalid_data_size(self): schema.context["settings"].max_data = 1 with assert_raises(InvalidRequest) as cm: - schema.load(self._make_test_data(body="asdfasdfasdfasdfasd")) + schema.load(self._make_test_data( + headers={ + "content-encoding": "aesgcm", + "crypto-key": "dh=asdfjialsjdfiasjld", + }, + body="asdfasdfasdfasdfasd")) eq_(cm.exception.errno, 104) @@ -489,7 +611,8 @@ def test_valid_data_crypto_padding_stripped(self): headers={ "authorization": "not vapid", "content-encoding": "aesgcm128", - "encryption": "salt=" + padded_value + "encryption": "salt=" + padded_value, + "encryption-key": "dh=asdfasdfasdf", } ) @@ -497,6 +620,36 @@ def test_valid_data_crypto_padding_stripped(self): eq_(errors, {}) eq_(result["headers"]["encryption"], "salt=asdfjiasljdf") + def test_invalid_dh_value_for_01_crypto(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + + padded_value = "asdfjiasljdf===" + + info = self._make_test_data( + body="asdfasdfasdfasdf", + headers={ + "authorization": "not vapid", + "content-encoding": "aesgcm128", + "encryption": "salt=" + padded_value, + "crypto-key": "dh=asdfasdfasdf" + } + ) + + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.message, "dh value in Crypto-Key header not valid " + "for 01 or earlier webpush-encryption") + def test_invalid_vapid_crypto_header(self): schema = self._make_fut() schema.context["settings"].parse_endpoint.return_value = dict( @@ -511,7 +664,7 @@ def test_invalid_vapid_crypto_header(self): info = self._make_test_data( body="asdfasdfasdfasdf", headers={ - "content-encoding": "text", + "content-encoding": "aesgcm", "encryption": "salt=ignored", "authorization": "invalid", "crypto-key": "dh=crap", @@ -565,7 +718,7 @@ def test_invalid_topic(self): class TestWebPushRequestSchemaUsingVapid(unittest.TestCase): def _make_fut(self): - from autopush.web.push_validation import WebPushRequestSchema + from autopush.web.webpush import WebPushRequestSchema from autopush.settings import AutopushSettings schema = WebPushRequestSchema() schema.context["log"] = Mock() @@ -617,7 +770,7 @@ def test_valid_vapid_crypto_header(self): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "authorization": auth, "crypto-key": ckey @@ -648,7 +801,7 @@ def test_valid_vapid_crypto_header_webpush(self): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "authorization": auth, "crypto-key": ckey @@ -659,7 +812,7 @@ def test_valid_vapid_crypto_header_webpush(self): eq_(errors, {}) ok_("jwt" in result) - @patch("autopush.web.push_validation.extract_jwt") + @patch("autopush.web.webpush.extract_jwt") def test_invalid_vapid_crypto_header(self, mock_jwt): schema = self._make_fut() mock_jwt.side_effect = ValueError("Unknown public key " @@ -682,7 +835,7 @@ def test_invalid_vapid_crypto_header(self, mock_jwt): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "authorization": auth, "crypto-key": ckey @@ -695,7 +848,7 @@ def test_invalid_vapid_crypto_header(self, mock_jwt): eq_(cm.exception.status_code, 401) eq_(cm.exception.errno, 109) - @patch("autopush.web.push_validation.extract_jwt") + @patch("autopush.web.webpush.extract_jwt") def test_invalid_encryption_header(self, mock_jwt): schema = self._make_fut() mock_jwt.side_effect = ValueError("Unknown public key " @@ -718,8 +871,8 @@ def test_invalid_encryption_header(self, mock_jwt): token="asdfasdf", ), headers={ - "content-encoding": "aes128", - "encryption": "foo=stuff", + "content-encoding": "aesgcm", + "encryption": "salt=stuff", "authorization": auth, "crypto-key": ckey } @@ -729,9 +882,9 @@ def test_invalid_encryption_header(self, mock_jwt): schema.load(info) eq_(cm.exception.status_code, 401) - eq_(cm.exception.errno, 110) + eq_(cm.exception.errno, 109) - @patch("autopush.web.push_validation.extract_jwt") + @patch("autopush.web.webpush.extract_jwt") def test_invalid_encryption_jwt(self, mock_jwt): schema = self._make_fut() # use a deeply superclassed error to make sure that it gets picked up. @@ -754,7 +907,7 @@ def test_invalid_encryption_jwt(self, mock_jwt): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "authorization": auth, "crypto-key": ckey @@ -767,7 +920,7 @@ def test_invalid_encryption_jwt(self, mock_jwt): eq_(cm.exception.status_code, 401) eq_(cm.exception.errno, 109) - @patch("autopush.web.push_validation.extract_jwt") + @patch("autopush.web.webpush.extract_jwt") def test_invalid_crypto_key_header_content(self, mock_jwt): schema = self._make_fut() mock_jwt.side_effect = ValueError("Unknown public key " @@ -800,7 +953,7 @@ def test_invalid_crypto_key_header_content(self, mock_jwt): with assert_raises(InvalidRequest) as cm: schema.load(info) - eq_(cm.exception.status_code, 401) + eq_(cm.exception.status_code, 400) eq_(cm.exception.errno, 110) def test_expired_vapid_header(self): @@ -823,7 +976,7 @@ def test_expired_vapid_header(self): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "authorization": auth, "crypto-key": ckey @@ -857,7 +1010,7 @@ def test_missing_vapid_header(self): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "crypto-key": ckey } @@ -890,7 +1043,7 @@ def test_bogus_vapid_header(self): token="asdfasdf", ), headers={ - "content-encoding": "aes128", + "content-encoding": "aesgcm", "encryption": "salt=stuff", "crypto-key": ckey, "authorization": "bogus crap" diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 8b4e26ca..fead90bb 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -732,13 +732,25 @@ def test_hello_webpush_uses_one_db_call(self): self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) - def check_result(msg): + d = Deferred() + + def check_result(msg, duration=0): + if len(db.DB_CALLS) < 3: # pragma: nocover + if duration > 3.0: # pragma: nocover + raise Exception("db calls isn't 3 yet") + else: + reactor.callLater(0.1, check_result, msg, duration+0.1) + return + eq_(db.DB_CALLS, ['register_user', 'fetch_messages', 'fetch_timestamp_messages']) eq_(msg["status"], 200) db.DB_CALLS = [] db.TRACK_DB_CALLS = False - return self._check_response(check_result) + d.callback(True) + f = self._check_response(check_result) + f.addErrback(lambda x: d.callback(True)) + return d def test_hello_with_webpush(self): self._connect() diff --git a/autopush/web/push_validation.py b/autopush/web/push_validation.py deleted file mode 100644 index 0a942e24..00000000 --- a/autopush/web/push_validation.py +++ /dev/null @@ -1,295 +0,0 @@ -"""Validation handler and Schemas""" -import re -import time -import urlparse - -from boto.dynamodb2.exceptions import ItemNotFound -from cryptography.fernet import InvalidToken -from jose import JOSEError -from marshmallow import ( - Schema, - fields, - pre_load, - post_load, - validates, - validates_schema, -) - -from autopush.web.base import AUTH_SCHEMES, PREF_SCHEME -from autopush.exceptions import ( - InvalidRequest, - InvalidTokenException, - VapidAuthException, -) -from autopush.utils import ( - base64url_encode, - extract_jwt, - WebPushNotification -) - -MAX_TTL = 60 * 60 * 24 * 60 - -# Base64 URL validation -VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_]+=*$') - - -class SimplePushSubscriptionSchema(Schema): - uaid = fields.UUID(required=True) - chid = fields.UUID(required=True) - - @pre_load - def extract_subscription(self, d): - try: - result = self.context["settings"].parse_endpoint( - token=d["token"], - version=d["api_ver"], - ) - except InvalidTokenException: - raise InvalidRequest("invalid token", errno=102) - return result - - @validates_schema - def validate_uaid_chid(self, d): - try: - result = self.context["settings"].router.get_uaid(d["uaid"].hex) - except ItemNotFound: - raise InvalidRequest("UAID not found", status_code=410, errno=103) - - if result.get("router_type") != "simplepush": - raise InvalidRequest("Wrong URL for user", errno=108) - - # Propagate the looked up user data back out - d["user_data"] = result - - -class SimplePushRequestSchema(Schema): - subscription = fields.Nested(SimplePushSubscriptionSchema, - load_from="token_info") - version = fields.Integer(missing=time.time) - data = fields.String(missing=None) - - @validates('data') - def validate_data(self, value): - max_data = self.context["settings"].max_data - if value and len(value) > max_data: - raise InvalidRequest( - "Data payload must be smaller than {}".format(max_data), - errno=104, - ) - - @pre_load - def token_prep(self, d): - d["token_info"] = dict( - api_ver=d["path_kwargs"].get("api_ver"), - token=d["path_kwargs"].get("token"), - ) - return d - - @pre_load - def extract_fields(self, d): - body_string = d["body"] - if len(body_string) > 0: - body_args = urlparse.parse_qs(body_string, keep_blank_values=True) - version = body_args.get("version") - data = body_args.get("data") - else: - version = d["arguments"].get("version") - data = d["arguments"].get("data") - version = version[0] if version is not None else version - data = data[0] if data is not None else data - if version and version >= "1": - d["version"] = version - if data: - d["data"] = data - return d - - -class WebPushSubscriptionSchema(Schema): - uaid = fields.UUID(required=True) - chid = fields.UUID(required=True) - public_key = fields.Raw(missing=None) - - @pre_load - def extract_subscription(self, d): - try: - result = self.context["settings"].parse_endpoint( - token=d["token"], - version=d["api_ver"], - ckey_header=d["ckey_header"], - auth_header=d["auth_header"], - ) - except (VapidAuthException): - raise InvalidRequest("missing authorization header", - status_code=401, errno=109) - except (InvalidTokenException, InvalidToken): - raise InvalidRequest("invalid token", status_code=404, errno=102) - return result - - @validates_schema(skip_on_field_errors=True) - def validate_uaid(self, d): - try: - result = self.context["settings"].router.get_uaid(d["uaid"].hex) - except ItemNotFound: - raise InvalidRequest("UAID not found", status_code=410, errno=103) - - if result.get("router_type") not in ["webpush", "gcm", "apns", "fcm"]: - raise InvalidRequest("Wrong URL for user", errno=108) - - if result.get("critical_failure"): - raise InvalidRequest("Critical Failure: %s" % - result.get("critical_failure"), - status_code=410, - errno=105) - - # Propagate the looked up user data back out - d["user_data"] = result - - -class WebPushHeaderSchema(Schema): - authorization = fields.String() - crypto_key = fields.String(load_from="crypto-key") - content_encoding = fields.String(load_from="content-encoding") - encryption = fields.String() - encryption_key = fields.String(load_from="encryption-key") - ttl = fields.Integer(required=False, missing=None) - topic = fields.String(required=False, missing=None) - api_ver = fields.String() - - @validates('topic') - def validate_topic(self, value): - if value is None: - return True - - if len(value) > 32: - raise InvalidRequest("Topic must be no greater than 32 " - "characters", errno=113) - - if not VALID_BASE64_URL.match(value): - raise InvalidRequest("Topic must be URL and Filename safe Base" - "64 alphabet", errno=113) - - @validates_schema - def validate_cypto_headers(self, d): - # Not allowed to use aesgcm128 + a crypto_key - if (d.get("content_encoding", "").lower() == "aesgcm128" and - d.get("crypto_key")): - wpe_url = ("https://developers.google.com/web/updates/2016/03/" - "web-push-encryption") - raise InvalidRequest( - message="You're using outdated encryption; " - "Please update to the format described in " + wpe_url, - errno=110, - ) - - # These both can't be present - if "encryption_key" in d and "crypto_key" in d: - raise InvalidRequest("Invalid crypto headers", errno=110) - - # Cap TTL - if 'ttl' in d: - d["ttl"] = min(d["ttl"], MAX_TTL) - - @post_load - def fixup_headers(self, d): - return {k.replace("_", "-"): v for k, v in d.items()} - - -class WebPushRequestSchema(Schema): - subscription = fields.Nested(WebPushSubscriptionSchema, - load_from="token_info") - headers = fields.Nested(WebPushHeaderSchema) - body = fields.Raw() - token_info = fields.Raw() - - @validates('body') - def validate_data(self, value): - max_data = self.context["settings"].max_data - if value and len(value) > max_data: - raise InvalidRequest( - "Data payload must be smaller than {}".format(max_data), - errno=104, - ) - - @validates_schema(skip_on_field_errors=True) - def ensure_encoding_with_data(self, d): - # This runs before nested schemas, so we use the - separated - # field name - req_fields = ["content-encoding", "encryption"] - if d.get("body"): - if not all([x in d["headers"] for x in req_fields]): - raise InvalidRequest("Client error", status_code=400, - errno=110) - if (d["headers"].get("crypto-key") and - "dh=" not in d["headers"]["crypto-key"]): - raise InvalidRequest( - "Crypto-Key header missing public-key 'dh' value", - status_code=401, - errno=110) - if (d["headers"].get("encryption") and - "salt=" not in d["headers"]["encryption"]): - raise InvalidRequest( - "Encryption header missing 'salt' value", - status_code=401, - errno=110) - - @pre_load - def token_prep(self, d): - d["token_info"] = dict( - api_ver=d["path_kwargs"].get("api_ver"), - token=d["path_kwargs"].get("token"), - ckey_header=d["headers"].get("crypto-key", ""), - auth_header=d["headers"].get("authorization", ""), - ) - return d - - def validate_auth(self, d): - auth = d["headers"].get("authorization") - needs_auth = d["token_info"]["api_ver"] == "v2" - if not auth and not needs_auth: - return - - public_key = d["subscription"].get("public_key") - try: - auth_type, token = auth.split(' ', 1) - except ValueError: - raise InvalidRequest("Invalid Authorization Header", - status_code=401, errno=109, - headers={"www-authenticate": PREF_SCHEME}) - - # If its not a bearer token containing what may be JWT, stop - if auth_type.lower() not in AUTH_SCHEMES or '.' not in token: - if needs_auth: - raise InvalidRequest("Missing Authorization Header", - status_code=401, errno=109) - return - - try: - jwt = extract_jwt(token, public_key) - except (AssertionError, ValueError, JOSEError): - raise InvalidRequest("Invalid Authorization Header", - status_code=401, errno=109, - headers={"www-authenticate": PREF_SCHEME}) - if jwt.get('exp', 0) < time.time(): - raise InvalidRequest("Invalid bearer token: Auth expired", - status_code=401, errno=109, - headers={"www-authenticate": PREF_SCHEME}) - jwt_crypto_key = base64url_encode(public_key) - d["jwt"] = dict(jwt_crypto_key=jwt_crypto_key, jwt_data=jwt) - - @post_load - def fixup_output(self, d): - # Verify authorization - # Note: This has to be done here, since schema validation takes place - # before nested schemas, and in this case we need all the nested - # schema logic to run first. - self.validate_auth(d) - - # Base64-encode data for Web Push - d["body"] = base64url_encode(d["body"]) - - # Set the notification based on the validated request schema data - d["notification"] = WebPushNotification.from_webpush_request_schema( - data=d, fernet=self.context["settings"].fernet, - legacy=self.context["settings"]._notification_legacy, - ) - return d diff --git a/autopush/web/simplepush.py b/autopush/web/simplepush.py index 382e89f5..8ad3aabf 100644 --- a/autopush/web/simplepush.py +++ b/autopush/web/simplepush.py @@ -1,14 +1,98 @@ import time +import urlparse +from boto.dynamodb2.exceptions import ItemNotFound +from marshmallow import ( + Schema, + fields, + pre_load, + validates, + validates_schema, +) from twisted.internet.defer import Deferred +from autopush.exceptions import ( + InvalidRequest, + InvalidTokenException, +) + from autopush.db import hasher from autopush.web.base import ( threaded_validate, Notification, BaseWebHandler, ) -from autopush.web.push_validation import SimplePushRequestSchema + + +class SimplePushSubscriptionSchema(Schema): + uaid = fields.UUID(required=True) + chid = fields.UUID(required=True) + + @pre_load + def extract_subscription(self, d): + try: + result = self.context["settings"].parse_endpoint( + token=d["token"], + version=d["api_ver"], + ) + except InvalidTokenException: + raise InvalidRequest("invalid token", errno=102) + return result + + @validates_schema + def validate_uaid_chid(self, d): + try: + result = self.context["settings"].router.get_uaid(d["uaid"].hex) + except ItemNotFound: + raise InvalidRequest("UAID not found", status_code=410, errno=103) + + if result.get("router_type") != "simplepush": + raise InvalidRequest("Wrong URL for user", errno=108) + + # Propagate the looked up user data back out + d["user_data"] = result + + +class SimplePushRequestSchema(Schema): + subscription = fields.Nested(SimplePushSubscriptionSchema, + load_from="token_info") + version = fields.Integer(missing=time.time) + data = fields.String(missing=None) + + @validates('data') + def validate_data(self, value): + max_data = self.context["settings"].max_data + if value and len(value) > max_data: + raise InvalidRequest( + "Data payload must be smaller than {}".format(max_data), + errno=104, + ) + + @pre_load + def token_prep(self, d): + d["token_info"] = dict( + api_ver=d["path_kwargs"].get("api_ver"), + token=d["path_kwargs"].get("token"), + ) + return d + + @pre_load + def extract_fields(self, d): + body_string = d["body"] + if len(body_string) > 0: + body_args = urlparse.parse_qs(body_string, keep_blank_values=True) + version = body_args.get("version") + data = body_args.get("data") + else: + version = d["arguments"].get("version") + data = d["arguments"].get("data") + version = version[0] if version is not None else version + data = data[0] if data is not None else data + if version and version >= "1": + d["version"] = version + if data: + d["data"] = data + return d class SimplePushHandler(BaseWebHandler): diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 3ad0cd63..bd225000 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -1,12 +1,329 @@ +import re import time +from boto.dynamodb2.exceptions import ItemNotFound +from cryptography.fernet import InvalidToken +from jose import JOSEError +from marshmallow import ( + Schema, + fields, + pre_load, + post_load, + validates, + validates_schema, +) +from marshmallow_polyfield import PolyField +from marshmallow.validate import OneOf from twisted.internet.defer import Deferred from twisted.internet.threads import deferToThread +from autopush.crypto_key import CryptoKey from autopush.db import dump_uaid, hasher -from autopush.utils import ms_time -from autopush.web.base import threaded_validate, BaseWebHandler -from autopush.web.push_validation import WebPushRequestSchema +from autopush.exceptions import ( + InvalidRequest, + InvalidTokenException, + VapidAuthException, +) +from autopush.utils import ( + base64url_encode, + extract_jwt, + ms_time, + WebPushNotification, +) +from autopush.web.base import ( + AUTH_SCHEMES, + threaded_validate, + BaseWebHandler, + PREF_SCHEME, +) + +MAX_TTL = 60 * 60 * 24 * 60 + +# Base64 URL validation +VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_]+=*$') + + +class WebPushSubscriptionSchema(Schema): + uaid = fields.UUID(required=True) + chid = fields.UUID(required=True) + public_key = fields.Raw(missing=None) + + @pre_load + def extract_subscription(self, d): + try: + result = self.context["settings"].parse_endpoint( + token=d["token"], + version=d["api_ver"], + ckey_header=d["ckey_header"], + auth_header=d["auth_header"], + ) + except (VapidAuthException): + raise InvalidRequest("missing authorization header", + status_code=401, errno=109) + except (InvalidTokenException, InvalidToken): + raise InvalidRequest("invalid token", status_code=404, errno=102) + return result + + @validates_schema(skip_on_field_errors=True) + def validate_uaid(self, d): + try: + result = self.context["settings"].router.get_uaid(d["uaid"].hex) + except ItemNotFound: + raise InvalidRequest("UAID not found", status_code=410, errno=103) + + if result.get("router_type") not in ["webpush", "gcm", "apns", "fcm"]: + raise InvalidRequest("Wrong URL for user", errno=108) + + if result.get("critical_failure"): + raise InvalidRequest("Critical Failure: %s" % + result.get("critical_failure"), + status_code=410, + errno=105) + + # Propagate the looked up user data back out + d["user_data"] = result + + +class WebPushBasicHeaderSchema(Schema): + authorization = fields.String() + ttl = fields.Integer(required=False, missing=None) + topic = fields.String(required=False, missing=None) + api_ver = fields.String() + + @validates('topic') + def validate_topic(self, value): + if value is None: + return True + + if len(value) > 32: + raise InvalidRequest("Topic must be no greater than 32 " + "characters", errno=113) + + if not VALID_BASE64_URL.match(value): + raise InvalidRequest("Topic must be URL and Filename safe Base" + "64 alphabet", errno=113) + + @post_load + def cap_ttl(self, d): + if 'ttl' in d: + d["ttl"] = min(d["ttl"], MAX_TTL) + + +class WebPushCrypto01HeaderSchema(Schema): + """Validates WebPush Message Encryption + + Uses draft-ietf-webpush-encryption-01 rules for validation. + + """ + content_encoding = fields.String( + required=True, + load_from="content-encoding", + validate=OneOf(["aesgcm128"]) + ) + encryption = fields.String(required=True) + encryption_key = fields.String( + required=True, + load_from="encryption-key" + ) + crypto_key = fields.String(load_from="crypto-key") + + @validates("encryption") + def validate_encryption(self, value): + """Must contain a salt value""" + ck = CryptoKey(value) + salt = ck.get_label("salt") + if not salt or not VALID_BASE64_URL.match(salt): + raise InvalidRequest("Invalid salt value in Encryption header", + status_code=400, + errno=110) + + @validates("crypto_key") + def validate_crypto_key(self, value): + """Must not contain a dh value""" + ck = CryptoKey(value) + dh = ck.get_label("dh") + if dh: + raise InvalidRequest( + "dh value in Crypto-Key header not valid for 01 or earlier " + "webpush-encryption", + status_code=400, + errno=110, + ) + + @validates("encryption_key") + def validate_encryption_key(self, value): + """Must contain a dh value""" + ck = CryptoKey(value) + dh = ck.get_label("dh") + if not dh or not VALID_BASE64_URL.match("dh"): + raise InvalidRequest("Invalid dh value in Encryption-Key header", + status_code=400, + errno=110) + + +class WebPushCrypto04HeaderSchema(Schema): + """Validates WebPush Message Encryption + + Uses draft-ietf-webpush-encryption-04 rules for validation. + + """ + content_encoding = fields.String( + required=True, + load_from="content-encoding", + validate=OneOf(["aesgcm"]) + ) + encryption = fields.String(required=True) + crypto_key = fields.String( + required=True, + load_from="crypto-key", + ) + + @validates("encryption") + def validate_encryption(self, value): + """Must contain a salt value""" + ck = CryptoKey(value) + salt = ck.get_label("salt") + if not salt or not VALID_BASE64_URL.match(salt): + raise InvalidRequest("Invalid salt value in Encryption header", + status_code=400, + errno=110) + + @validates("crypto_key") + def validate_crypto_key(self, value): + """Must contain a dh value""" + ck = CryptoKey(value) + dh = ck.get_label("dh") + if not dh or not VALID_BASE64_URL.match("dh"): + raise InvalidRequest("Invalid dh value in Encryption-Key header", + status_code=400, + errno=110) + + @validates_schema(pass_original=True) + def check_unknown_fields(self, data, original_data): + if "encryption-key" in original_data: + raise InvalidRequest( + "Encryption-Key header not valid for 02 or later " + "webpush-encryption", + status_code=400, + errno=110, + ) + + +class WebPushInvalidContentEncodingSchema(Schema): + """Returned to raise an Invalid Content-encoding error""" + @validates_schema + def invalid_content_encoding(self, d): + raise InvalidRequest( + "Unknown Content-Encoding", + status_code=400, + errno=110 + ) + + +def conditional_crypto_deserialize(object_dict, parent_object_dict): + """Return the WebPush Crypto Schema if there's a data payload""" + if parent_object_dict.get("body"): + encoding = object_dict.get("content-encoding") + # Validate the crypto headers appropriately + if encoding == "aesgcm128": + return WebPushCrypto01HeaderSchema() + elif encoding == "aesgcm": + return WebPushCrypto04HeaderSchema() + else: + return WebPushInvalidContentEncodingSchema() + else: + return Schema() + + +class WebPushRequestSchema(Schema): + subscription = fields.Nested(WebPushSubscriptionSchema, + load_from="token_info") + headers = fields.Nested(WebPushBasicHeaderSchema) + crypto_headers = PolyField( + load_from="headers", + deserialization_schema_selector=conditional_crypto_deserialize, + ) + body = fields.Raw() + token_info = fields.Raw() + + @validates('body') + def validate_data(self, value): + max_data = self.context["settings"].max_data + if value and len(value) > max_data: + raise InvalidRequest( + "Data payload must be smaller than {}".format(max_data), + errno=104, + ) + + @pre_load + def token_prep(self, d): + d["token_info"] = dict( + api_ver=d["path_kwargs"].get("api_ver"), + token=d["path_kwargs"].get("token"), + ckey_header=d["headers"].get("crypto-key", ""), + auth_header=d["headers"].get("authorization", ""), + ) + return d + + def validate_auth(self, d): + auth = d["headers"].get("authorization") + needs_auth = d["token_info"]["api_ver"] == "v2" + if not auth and not needs_auth: + return + + public_key = d["subscription"].get("public_key") + try: + auth_type, token = auth.split(' ', 1) + except ValueError: + raise InvalidRequest("Invalid Authorization Header", + status_code=401, errno=109, + headers={"www-authenticate": PREF_SCHEME}) + + # If its not a bearer token containing what may be JWT, stop + if auth_type.lower() not in AUTH_SCHEMES or '.' not in token: + if needs_auth: + raise InvalidRequest("Missing Authorization Header", + status_code=401, errno=109) + return + + try: + jwt = extract_jwt(token, public_key) + except (AssertionError, ValueError, JOSEError): + raise InvalidRequest("Invalid Authorization Header", + status_code=401, errno=109, + headers={"www-authenticate": PREF_SCHEME}) + if jwt.get('exp', 0) < time.time(): + raise InvalidRequest("Invalid bearer token: Auth expired", + status_code=401, errno=109, + headers={"www-authenticate": PREF_SCHEME}) + jwt_crypto_key = base64url_encode(public_key) + d["jwt"] = dict(jwt_crypto_key=jwt_crypto_key, jwt_data=jwt) + + @post_load + def fixup_output(self, d): + # Verify authorization + # Note: This has to be done here, since schema validation takes place + # before nested schemas, and in this case we need all the nested + # schema logic to run first. + self.validate_auth(d) + + # Merge crypto headers back in + if d["crypto_headers"]: + d["headers"].update( + {k.replace("_", "-"): v for k, v in + d["crypto_headers"].items()} + ) + + # Base64-encode data for Web Push + d["body"] = base64url_encode(d["body"]) + + # Set the notification based on the validated request schema data + d["notification"] = WebPushNotification.from_webpush_request_schema( + data=d, fernet=self.context["settings"].fernet, + legacy=self.context["settings"]._notification_legacy, + ) + + return d class WebPushHandler(BaseWebHandler): diff --git a/base-requirements.txt b/base-requirements.txt index b18a60b9..133ed610 100644 --- a/base-requirements.txt +++ b/base-requirements.txt @@ -23,6 +23,7 @@ idna==2.1 ipaddress==1.0.16 jmespath==0.9.0 marshmallow==2.10.2 +marshmallow_polyfield==3.1 pyOpenSSL==16.1.0 pyasn1==0.1.9 pyasn1-modules==0.0.8