From 7a5d485060a5ba5a09b62795bde740fd1b701c63 Mon Sep 17 00:00:00 2001
From: Ben Bangert
Date: Tue, 15 May 2018 14:11:20 -0700
Subject: [PATCH] feat: port migrate user to Rust, remove Python calling

Closes #1206
---
 autopush/tests/test_rs_integration.py | 273 +++++++++++++++++++++++++-
 autopush/tests/test_webpush_server.py | 226 ---------------------
 autopush/webpush_server.py            | 254 +-----------------------
 autopush_rs/__init__.py               |  45 +----
 autopush_rs/build.rs                  |  16 ++
 autopush_rs/src/call.rs               | 182 -----------------
 autopush_rs/src/client.rs             |  32 +--
 autopush_rs/src/lib.rs                |   2 -
 autopush_rs/src/queue.rs              |  77 --------
 autopush_rs/src/server/mod.rs         |  26 +--
 autopush_rs/src/util/ddb_helpers.rs   | 118 ++++++++++-
 11 files changed, 428 insertions(+), 823 deletions(-)
 delete mode 100644 autopush/tests/test_webpush_server.py
 create mode 100644 autopush_rs/build.rs
 delete mode 100644 autopush_rs/src/call.rs
 delete mode 100644 autopush_rs/src/queue.rs

diff --git a/autopush/tests/test_rs_integration.py b/autopush/tests/test_rs_integration.py
index be7b388a..7d498ef9 100644
--- a/autopush/tests/test_rs_integration.py
+++ b/autopush/tests/test_rs_integration.py
@@ -25,12 +25,18 @@ import requests
 import twisted.internet.base
 from cryptography.fernet import Fernet
-from twisted.internet.defer import inlineCallbacks, returnValue
+from twisted.internet.defer import inlineCallbacks, returnValue, Deferred
+from twisted.internet.threads import deferToThread
 from twisted.trial import unittest
 from twisted.logger import globalLogPublisher

 import autopush.tests
 from autopush.config import AutopushConfig
+from autopush.db import (
+    get_month,
+    has_connected_this_month,
+    Message,
+)
 from autopush.logging import begin_or_register
 from autopush.main import EndpointApplication, RustConnectionApplication
 from autopush.utils import base64url_encode
@@ -733,6 +739,271 @@ def test_delete_saved_notification(self):
         assert result is None
         yield self.shut_down(client)
+
+    @inlineCallbacks
+    def test_webpush_monthly_rotation(self):
+        from autopush.db import make_rotating_tablename
+        client = yield self.quick_register()
+        yield client.disconnect()
+
+        # Move the client back one month to the past
+        last_month = make_rotating_tablename(
+            prefix=self.conn.conf.message_table.tablename, delta=-1)
+        lm_message = Message(last_month, boto_resource=self.conn.db.resource)
+        yield deferToThread(
+            self.conn.db.router.update_message_month,
+            client.uaid,
+            last_month,
+        )
+
+        # Verify the move
+        c = yield deferToThread(self.conn.db.router.get_uaid,
+                                client.uaid)
+        assert c["current_month"] == last_month
+
+        # Verify last_connect is current, then move that back
+        assert has_connected_this_month(c)
+        today = get_month(delta=-1)
+        last_connect = int("%s%s020001" % (today.year,
+                                           str(today.month).zfill(2)))
+
+        yield deferToThread(
+            self.conn.db.router._update_last_connect,
+            client.uaid,
+            last_connect)
+        c = yield deferToThread(self.conn.db.router.get_uaid,
+                                client.uaid)
+        assert has_connected_this_month(c) is False
+
+        # Move the client's channels back one month
+        exists, chans = yield deferToThread(
+            self.conn.db.message.all_channels,
+            client.uaid
+        )
+        assert exists is True
+        assert len(chans) == 1
+        yield deferToThread(
+            lm_message.save_channels,
+            client.uaid,
+            chans,
+        )
+
+        # Remove the channels entry entirely from this month
+        yield deferToThread(
+            self.conn.db.message.table.delete_item,
+            Key={'uaid': client.uaid, 'chidmessageid': ' '}
+        )
+
+        # Verify the channel is gone
+        exists, chans = yield deferToThread(
+            self.conn.db.message.all_channels,
+            client.uaid,
+        )
+        assert exists is False
+        assert len(chans) == 0
+
+        # Send in a notification, verify it landed in last month's notification
+        # table
+        data = uuid.uuid4().hex
+        with self.legacy_endpoint():
+            yield client.send_notification(data=data)
+        ts, notifs = yield deferToThread(lm_message.fetch_timestamp_messages,
+                                         uuid.UUID(client.uaid),
+                                         " ")
+        assert len(notifs) == 1
+
+        # Connect the client, verify the migration
+        yield client.connect()
+        yield client.hello()
+
+        # Pull down the notification
+        result = yield client.get_notification()
+        chan = client.channels.keys()[0]
+        assert result is not None
+        assert chan == result["channelID"]
+
+        # Acknowledge the notification, which triggers the migration
+        yield client.ack(chan, result["version"])
+
+        # Wait up to 4 seconds for the table rotation to occur
+        start = time.time()
+        while time.time()-start < 4:
+            c = yield deferToThread(
+                self.conn.db.router.get_uaid,
+                client.uaid)
+            if c["current_month"] == self.conn.db.current_msg_month:
+                break
+            else:
+                yield deferToThread(time.sleep, 0.2)
+
+        # Verify the month update in the router table
+        c = yield deferToThread(
+            self.conn.db.router.get_uaid,
+            client.uaid)
+        assert c["current_month"] == self.conn.db.current_msg_month
+
+        # Verify the client moved last_connect
+        assert has_connected_this_month(c) is True
+
+        # Verify the channels were moved
+        exists, chans = yield deferToThread(
+            self.conn.db.message.all_channels,
+            client.uaid
+        )
+        assert exists is True
+        assert len(chans) == 1
+        yield self.shut_down(client)
+
+
+    @inlineCallbacks
+    def test_webpush_monthly_rotation_prior_record_exists(self):
+        from autopush.db import make_rotating_tablename
+        client = yield self.quick_register()
+        yield client.disconnect()
+
+        # Move the client back one month to the past
+        last_month = make_rotating_tablename(
+            prefix=self.conn.conf.message_table.tablename, delta=-1)
+        lm_message = Message(last_month,
+                             boto_resource=autopush.tests.boto_resource)
+        yield deferToThread(
+            self.conn.db.router.update_message_month,
+            client.uaid,
+            last_month,
+        )
+
+        # Verify the move
+        c = yield deferToThread(self.conn.db.router.get_uaid,
+                                client.uaid)
+        assert c["current_month"] == last_month
+
+        # Verify last_connect is current, then move that back
+        assert has_connected_this_month(c)
+        today = get_month(delta=-1)
+        yield deferToThread(
+            self.conn.db.router._update_last_connect,
+            client.uaid,
+            int("%s%s020001" % (today.year, str(today.month).zfill(2))),
+        )
+        c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid)
+        assert has_connected_this_month(c) is False
+
+        # Move the client's channels back one month
+        exists, chans = yield deferToThread(
+            self.conn.db.message.all_channels,
+            client.uaid,
+        )
+        assert exists is True
+        assert len(chans) == 1
+        yield deferToThread(
+            lm_message.save_channels,
+            client.uaid,
+            chans,
+        )
+
+        # Send in a notification, verify it landed in last month's notification
+        # table
+        data = uuid.uuid4().hex
+        with self.legacy_endpoint():
+            yield client.send_notification(data=data)
+        _, notifs = yield deferToThread(lm_message.fetch_timestamp_messages,
+                                        uuid.UUID(client.uaid),
+                                        " ")
+        assert len(notifs) == 1
+
+        # Connect the client, verify the migration
+        yield client.connect()
+        yield client.hello()
+
+        # Pull down the notification
+        result = yield client.get_notification()
+        chan = client.channels.keys()[0]
+        assert result is not None
+        assert chan == result["channelID"]
+
+        # Acknowledge the notification, which triggers the migration
+        yield client.ack(chan, result["version"])
+
+        # Wait up to 4 seconds for the table rotation to occur
+        start = time.time()
+        while time.time()-start < 4:
+            c = yield deferToThread(
+                self.conn.db.router.get_uaid,
+                client.uaid)
+            if c["current_month"] == self.conn.db.current_msg_month:
+                break
+            else:
+                yield deferToThread(time.sleep, 0.2)
+
+        # Verify the month update in the router table
+        c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid)
+        assert c["current_month"] == self.conn.db.current_msg_month
+
+        # Verify the client moved last_connect
+        assert has_connected_this_month(c) is True
+
+        # Verify the channels were moved
+        exists, chans = yield deferToThread(
+            self.conn.db.message.all_channels,
+            client.uaid
+        )
+        assert exists is True
+        assert len(chans) == 1
+        yield self.shut_down(client)
+
+    @inlineCallbacks
+    def test_webpush_monthly_rotation_no_channels(self):
+        from autopush.db import make_rotating_tablename
+        client = Client("ws://localhost:{}/".format(self.connection_port))
+        yield client.connect()
+        yield client.hello()
+        yield client.disconnect()
+
+        # Move the client back one month to the past
+        last_month = make_rotating_tablename(
+            prefix=self.conn.conf.message_table.tablename, delta=-1)
+        yield deferToThread(
+            self.conn.db.router.update_message_month,
+            client.uaid,
+            last_month
+        )
+
+        # Verify the move
+        c = yield deferToThread(self.conn.db.router.get_uaid,
+                                client.uaid
+                                )
+        assert c["current_month"] == last_month
+
+        # Verify there are no channels
+        exists, chans = yield deferToThread(
+            self.conn.db.message.all_channels,
+            client.uaid,
+        )
+        assert exists is False
+        assert len(chans) == 0
+
+        # Connect the client, verify the migration
+        yield client.connect()
+        yield client.hello()
+
+        # Wait up to 2 seconds for the table rotation to occur
+        start = time.time()
+        while time.time()-start < 2:
+            c = yield deferToThread(
+                self.conn.db.router.get_uaid,
+                client.uaid,
+            )
+            if c["current_month"] == self.conn.db.current_msg_month:
+                break
+            else:
+                yield deferToThread(time.sleep, 0.2)
+
+        # Verify the month update in the router table
+        c = yield deferToThread(self.conn.db.router.get_uaid,
+                                client.uaid)
+        assert c["current_month"] == self.conn.db.current_msg_month
+        yield self.shut_down(client)
+
     @inlineCallbacks
     def test_with_key(self):
         private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p)
diff --git a/autopush/tests/test_webpush_server.py b/autopush/tests/test_webpush_server.py
deleted file mode 100644
index 33715a01..00000000
--- a/autopush/tests/test_webpush_server.py
+++ /dev/null
@@ -1,226 +0,0 @@
-import random
-import time
-import unittest
-from threading import Event
-from uuid import uuid4, UUID
-
-import attr
-import factory
-from mock import Mock
-from twisted.logger import globalLogPublisher
-
-from autopush.db import (
-    DatabaseManager,
-    generate_last_connect,
-    make_rotating_tablename,
-    Message,
-)
-from autopush.metrics import SinkMetrics
-from autopush.config import AutopushConfig
-from autopush.exceptions import ItemNotFound
-from autopush.logging import begin_or_register
-from autopush.tests.support import TestingLogObserver
-from autopush.utils import WebPushNotification, ns_time
-from autopush.websocket import USER_RECORD_VERSION
-from autopush.webpush_server import (
-    MigrateUser,
-    WebPushMessage,
-)
-import autopush.tests
-
-
-class AutopushCall(object):
-    """Placeholder object for real Rust binding one"""
-    called = Event()
-    val = None
-    payload = None
-
-    def complete(self, ret):
-        self.val = ret
-
self.called.set() - - def json(self): - return self.payload - - -class UserItemFactory(factory.Factory): - class Meta: - model = dict - - uaid = factory.LazyFunction(lambda: uuid4().hex) - connected_at = factory.LazyFunction(lambda: int(time.time() * 1000)-10000) - node_id = "http://something:3242/" - router_type = "webpush" - last_connect = factory.LazyFunction(generate_last_connect) - record_version = USER_RECORD_VERSION - current_month = factory.LazyFunction( - lambda: make_rotating_tablename("message") - ) - - -def generate_random_headers(): - return dict( - encryption="aesgcm128", - encryption_key="someneatkey", - crypto_key="anotherneatkey", - ) - - -class WebPushNotificationFactory(factory.Factory): - class Meta: - model = WebPushNotification - - uaid = factory.LazyFunction(uuid4) - channel_id = factory.LazyFunction(uuid4) - ttl = 86400 - data = factory.LazyFunction( - lambda: random.randint(30, 4096) * "*" - ) - headers = factory.LazyFunction(generate_random_headers) - - -def generate_version(obj): - if obj.topic: - msg_key = ":".join(["01", obj.uaid, obj.channelID.hex, - obj.topic]) - else: - sortkey_timestamp = ns_time() - msg_key = ":".join(["02", obj.uaid, obj.channelID.hex, - str(sortkey_timestamp)]) - # Technically this should be fernet encrypted, but this is fine for - # testing here - return msg_key - - -class WebPushMessageFactory(factory.Factory): - class Meta: - model = WebPushMessage - - uaid = factory.LazyFunction(lambda: str(uuid4())) - channelID = factory.LazyFunction(uuid4) - ttl = 86400 - data = factory.LazyFunction( - lambda: random.randint(30, 4096) * "*" - ) - topic = None - timestamp = factory.LazyFunction(lambda: int(time.time() * 1000)) - headers = factory.LazyFunction(generate_random_headers) - version = factory.LazyAttribute(generate_version) - - -def webpush_messages(obj): - return [attr.asdict(WebPushMessageFactory(uaid=obj.uaid)) - for _ in range(obj.message_count)] - - -class BaseSetup(unittest.TestCase): - def setUp(self): - self.conf = AutopushConfig( - hostname="localhost", - resolve_hostname=True, - port=8080, - router_port=8081, - statsd_host=None, - env="test", - auto_ping_interval=float(300), - auto_ping_timeout=float(10), - close_handshake_timeout=10, - max_connections=2000000, - ) - - self.logs = TestingLogObserver() - begin_or_register(self.logs) - self.addCleanup(globalLogPublisher.removeObserver, self.logs) - - self.db = db = DatabaseManager.from_config( - self.conf, - resource=autopush.tests.boto_resource) - self.metrics = db.metrics = Mock(spec=SinkMetrics) - db.setup_tables() - - def _store_messages(self, uaid, topic=False, num=5): - try: - item = self.db.router.get_uaid(uaid.hex) - message_table = Message( - item["current_month"], - boto_resource=autopush.tests.boto_resource) - except ItemNotFound: - message_table = self.db.message - messages = [WebPushNotificationFactory(uaid=uaid) - for _ in range(num)] - channels = set([m.channel_id for m in messages]) - for channel in channels: - message_table.register_channel(uaid.hex, channel.hex) - for idx, notif in enumerate(messages): - if topic: - notif.topic = "something_{}".format(idx) - notif.generate_message_id(self.conf.fernet) - message_table.store_message(notif) - return messages - - -class TestWebPushServer(BaseSetup): - def _makeFUT(self): - from autopush.webpush_server import WebPushServer - return WebPushServer(self.conf, self.db, num_threads=2) - - def test_start_stop(self): - ws = self._makeFUT() - ws.start() - try: - assert len(ws.workers) == 2 - finally: - ws.stop() - - -class 
TestMigrateUserProcessor(BaseSetup): - def _makeFUT(self): - from autopush.webpush_server import MigrateUserCommand - return MigrateUserCommand(self.conf, self.db) - - def test_migrate_user(self): - migrate_command = self._makeFUT() - - # Create a user - last_month = make_rotating_tablename("message", delta=-1) - user = UserItemFactory(current_month=last_month) - uaid = user["uaid"] - self.db.router.register_user(user) - - # Store some messages so we have some channels - self._store_messages(UUID(uaid), num=3) - - # Check that it's there - item = self.db.router.get_uaid(uaid) - _, channels = Message( - last_month, - boto_resource=self.db.resource).all_channels(uaid) - assert item["current_month"] != self.db.current_msg_month - assert item is not None - assert len(channels) == 3 - - # Migrate it - migrate_command.process( - MigrateUser(uaid=uaid, message_month=last_month) - ) - - # Check that it's in the new spot - item = self.db.router.get_uaid(uaid) - _, channels = self.db.message.all_channels(uaid) - assert item["current_month"] == self.db.current_msg_month - assert item is not None - assert len(channels) == 3 - - def test_no_migrate(self): - self.conf.allow_table_rotation = False - self.conf.message_table.tablename = "message_int_test" - self.db = db = DatabaseManager.from_config( - self.conf, - resource=autopush.tests.boto_resource - ) - assert self.db.allow_table_rotation is False - db.setup_tables() - tablename = autopush.tests.boto_resource.get_latest_message_tablename( - prefix="message_int_test" - ) - assert db.message.tablename == tablename diff --git a/autopush/webpush_server.py b/autopush/webpush_server.py index 1d3eb057..edde690e 100644 --- a/autopush/webpush_server.py +++ b/autopush/webpush_server.py @@ -1,33 +1,14 @@ """WebPush Server """ -from threading import Thread -from uuid import UUID - -import attr -from attr import ( - attrs, - attrib, -) -from typing import ( # noqa - Dict, - List, - Optional, - Tuple -) from twisted.logger import Logger from autopush.db import ( # noqa DatabaseManager, - Message, ) from autopush.config import AutopushConfig # noqa -from autopush.metrics import IMetrics # noqa -from autopush.web.webpush import MAX_TTL -from autopush.types import JSONDict # noqa -from autopush.utils import WebPushNotification -from autopush_rs import AutopushServer, AutopushQueue # noqa +from autopush_rs import AutopushServer # noqa log = Logger() @@ -35,112 +16,6 @@ _STOP = object() -# Conversion functions -def uaid_from_str(input): - # type: (Optional[str]) -> Optional[UUID] - """Parse a uaid and verify the raw input matches the hex version (no - dashes)""" - try: - uuid = UUID(input) - if uuid.hex != input: - return None - return uuid - except (TypeError, ValueError): - return None - - -def dict_to_webpush_message(input): - if isinstance(input, dict): - return WebPushMessage( - uaid=input.get("uaid"), - timestamp=input["timestamp"], - channelID=input["channelID"], - ttl=input["ttl"], - topic=input.get("topic"), - version=input["version"], - sortkey_timestamp=input.get("sortkey_timestamp"), - data=input.get("data"), - headers=input.get("headers"), - ) - return input - - -@attrs(slots=True) -class WebPushMessage(object): - """Serializable version of attributes needed for message delivery""" - uaid = attrib() # type: str - timestamp = attrib() # type: int - channelID = attrib() # type: str - ttl = attrib() # type: int - topic = attrib() # type: str - version = attrib() # type: str - sortkey_timestamp = attrib(default=None) # type: Optional[str] - data = 
attrib(default=None) # type: Optional[str] - headers = attrib(default=None) # type: Optional[JSONDict] - - @classmethod - def from_WebPushNotification(cls, notif): - # type: (WebPushNotification) -> WebPushMessage - p = notif.websocket_format() - del p["messageType"] - return cls( - uaid=notif.uaid.hex, - timestamp=int(notif.timestamp), - sortkey_timestamp=notif.sortkey_timestamp, - ttl=MAX_TTL if notif.ttl is None else int(notif.ttl), - topic=notif.topic, - **p - ) - - def to_WebPushNotification(self): - # type: () -> WebPushNotification - notif = WebPushNotification( - uaid=UUID(self.uaid), - channel_id=self.channelID, - data=self.data, - headers=self.headers, - ttl=self.ttl, - topic=self.topic, - timestamp=self.timestamp, - message_id=self.version, - update_id=self.version, - sortkey_timestamp=self.sortkey_timestamp, - ) - - # If there's no sortkey_timestamp and no topic, its legacy - if notif.sortkey_timestamp is None and not notif.topic: - notif.legacy = True - - return notif - - -############################################################################### -# Input messages off the incoming queue -############################################################################### -@attrs(slots=True) -class InputCommand(object): - pass - - -@attrs(slots=True) -class MigrateUser(InputCommand): - uaid = attrib(convert=uaid_from_str) # type: UUID - message_month = attrib() # type: str - - -############################################################################### -# Output messages serialized to the outgoing queue -############################################################################### -@attrs(slots=True) -class OutputCommand(object): - pass - - -@attrs(slots=True) -class MigrateUserResponse(OutputCommand): - message_month = attrib() # type: str - - ############################################################################### # Main push server class ############################################################################### @@ -151,139 +26,14 @@ def __init__(self, conf, db, num_threads=10): self.db = db self.db.setup_tables() self.num_threads = num_threads - self.incoming = AutopushQueue() - self.workers = [] # type: List[Thread] - self.command_processor = CommandProcessor(conf, self.db) - self.rust = AutopushServer(conf, db.message_tables, self.incoming) + self.rust = AutopushServer(conf, db.message_tables) self.running = False def start(self): # type: () -> None self.running = True - for _ in range(self.num_threads): - self.workers.append( - self._create_thread_worker( - processor=self.command_processor, - input_queue=self.incoming, - ) - ) self.rust.startService() def stop(self): self.running = False self.rust.stopService() - for worker in self.workers: - worker.join() - - def _create_thread_worker(self, processor, input_queue): - # type: (CommandProcessor, AutopushQueue) -> Thread - def _thread_worker(): - while self.running: - call = input_queue.recv() - try: - if call is None: - break - command = call.json() - result = processor.process_message(command) - call.complete(result) - except Exception as exc: - # TODO: Handle traceback better - import traceback - traceback.print_exc() - log.error("Exception in worker queue thread") - call.complete(dict( - error=True, - error_msg=str(exc), - )) - return self.spawn(_thread_worker) - - def spawn(self, func, *args, **kwargs): - t = Thread(target=func, args=args, kwargs=kwargs) - t.start() - return t - - -class CommandProcessor(object): - def __init__(self, conf, db): - # type: (AutopushConfig, DatabaseManager) -> None - 
self.conf = conf - self.db = db - self.migrate_user_proocessor = MigrateUserCommand(conf, db) - self.deserialize = dict( - migrate_user=MigrateUser, - ) - self.command_dict = dict( - migrate_user=self.migrate_user_proocessor, - ) # type: Dict[str, ProcessorCommand] - - def process_message(self, input): - # type: (JSONDict) -> JSONDict - """Process incoming message from the Rust server""" - command = input.pop("command", None) # type: str - if command not in self.command_dict: - log.critical("No command present: %s" % command) - return dict( - error=True, - error_msg="Command not found", - ) - from pprint import pformat - log.debug( - 'command: {command} {input}', - command=pformat(command), - input=input - ) - command_obj = self.deserialize[command](**input) - response = attr.asdict(self.command_dict[command].process(command_obj)) - log.debug('response: {response}', response=response) - return response - - -class ProcessorCommand(object): - """Parent class for processor commands""" - def __init__(self, conf, db): - # type: (AutopushConfig, DatabaseManager) -> None - self.conf = conf - self.db = db - - @property - def metrics(self): - # type: () -> IMetrics - return self.db.metrics - - def process(self, command): - raise NotImplementedError() - - -class MigrateUserCommand(ProcessorCommand): - def process(self, command): - # type: (MigrateUser) -> MigrateUserResponse - # Get the current channels for this month - message = Message(command.message_month, - boto_resource=self.db.resource) - _, channels = message.all_channels(command.uaid.hex) - - # Get the current message month - cur_month = self.db.current_msg_month - if channels: - # Save the current channels into this months message table - msg_table = Message(cur_month, - boto_resource=self.db.resource) - msg_table.save_channels(command.uaid.hex, - channels) - - # Finally, update the route message month - self.db.router.update_message_month(command.uaid.hex, - cur_month) - return MigrateUserResponse(message_month=cur_month) - - -def _validate_chid(chid): - # type: (str) -> Tuple[bool, Optional[str]] - """Ensure valid channel id format for register/unregister""" - try: - result = UUID(chid) - except ValueError: - return False, "Invalid UUID specified" - if chid != str(result): - return False, "Bad UUID format, use lower case, dashed format" - return True, None diff --git a/autopush_rs/__init__.py b/autopush_rs/__init__.py index a78343bb..33b4b082 100644 --- a/autopush_rs/__init__.py +++ b/autopush_rs/__init__.py @@ -19,8 +19,8 @@ def free(obj, free_fn): class AutopushServer(object): - def __init__(self, conf, message_tables, queue): - # type: (AutopushConfig, List[str], AutopushQueue) -> AutopushServer + def __init__(self, conf, message_tables): + # type: (AutopushConfig, List[str]) -> AutopushServer cfg = ffi.new('AutopushServerOptions*') cfg.auto_ping_interval = conf.auto_ping_interval cfg.auto_ping_timeout = conf.auto_ping_timeout @@ -51,12 +51,10 @@ def __init__(self, conf, message_tables, queue): ptr = _call(lib.autopush_server_new, cfg) self.ffi = ffi.gc(ptr, lib.autopush_server_free) - self.queue = queue def startService(self): _call(lib.autopush_server_start, - self.ffi, - self.queue.ffi) + self.ffi) def stopService(self): if self.ffi is None: @@ -68,43 +66,6 @@ def _free_ffi(self): free(self, lib.autopush_server_free) -class AutopushCall: - def __init__(self, ptr): - self.ffi = ffi.gc(ptr, lib.autopush_python_call_free) - - def json(self): - msg_ptr = _call(lib.autopush_python_call_input_ptr, self.ffi) - msg_len = 
_call(lib.autopush_python_call_input_len, self.ffi) - 1 - buf = ffi.buffer(msg_ptr, msg_len) - return json.loads(str(buf[:])) - - def complete(self, ret): - s = json.dumps(ret) - _call(lib.autopush_python_call_complete, self.ffi, s) - self._free_ffi() - - def cancel(self): - self._free_ffi() - - def _free_ffi(self): - free(self, lib.autopush_python_call_free) - - -class AutopushQueue: - def __init__(self): - ptr = _call(lib.autopush_queue_new) - self.ffi = ffi.gc(ptr, lib.autopush_queue_free) - - def recv(self): - if self.ffi is None: - return None - ret = _call(lib.autopush_queue_recv, self.ffi) - if ffi.cast('size_t', ret) == 1: - return None - else: - return AutopushCall(ret) - - last_err = None diff --git a/autopush_rs/build.rs b/autopush_rs/build.rs new file mode 100644 index 00000000..ac06bb68 --- /dev/null +++ b/autopush_rs/build.rs @@ -0,0 +1,16 @@ +//! Generate autopush.h via cbindgen +extern crate cbindgen; + +use std::env; + +fn main() { + let crate_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR undefined"); + let pkg_name = env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME undefined"); + let target = format!("{}/target/{}.h", crate_dir, pkg_name); + cbindgen::Builder::new() + .with_crate(crate_dir) + .with_language(cbindgen::Language::C) + .generate() + .expect("cbindgen unable to generate bindings") + .write_to_file(target.as_str()); +} diff --git a/autopush_rs/src/call.rs b/autopush_rs/src/call.rs deleted file mode 100644 index 8d24486e..00000000 --- a/autopush_rs/src/call.rs +++ /dev/null @@ -1,182 +0,0 @@ -//! Implementation of calling methods/objects in python -//! -//! The main `Server` has a channel that goes back to the main python thread, -//! and that's used to send instances of `PythonCall` from the Rust thread to -//! the Python thread. Typically you won't work with `PythonCall` directly -//! though but rather the various methods on the `Server` struct, documented -//! below. Each method will return a `MyFuture` of the result, representing the -//! decoded value from Python. -//! -//! Implementation-wise what's happening here is that each function call into -//! Python creates a `futures::sync::oneshot`. The `Sender` half of this oneshot -//! is sent to Python while the `Receiver` half stays in Rust. Arguments sent to -//! Python are serialized as JSON and arguments are received from Python as JSON -//! as well, meaning that they're deserialized in Rust from JSON as well. 
- -use std::cell::RefCell; -use std::ffi::CStr; - -use futures::Future; -use futures::sync::oneshot; -use libc::c_char; -use serde::de; -use serde::ser; -use serde_json; - -use errors::*; -use rt::{self, AutopushError, UnwindGuard}; -use server::Server; - -pub struct AutopushPythonCall { - inner: UnwindGuard, -} - -struct Inner { - input: String, - done: RefCell>>, -} - -pub struct PythonCall { - input: String, - output: Box, -} - -#[no_mangle] -pub extern "C" fn autopush_python_call_input_ptr( - call: *mut AutopushPythonCall, - err: &mut AutopushError, -) -> *const u8 { - unsafe { (*call).inner.catch(err, |call| call.input.as_ptr()) } -} - -#[no_mangle] -pub extern "C" fn autopush_python_call_input_len( - call: *mut AutopushPythonCall, - err: &mut AutopushError, -) -> usize { - unsafe { (*call).inner.catch(err, |call| call.input.len()) } -} - -#[no_mangle] -pub extern "C" fn autopush_python_call_complete( - call: *mut AutopushPythonCall, - input: *const c_char, - err: &mut AutopushError, -) -> i32 { - unsafe { - (*call).inner.catch(err, |call| { - let input = CStr::from_ptr(input).to_str().unwrap(); - call.done.borrow_mut().take().unwrap().call(input); - }) - } -} - -#[no_mangle] -pub extern "C" fn autopush_python_call_free(call: *mut AutopushPythonCall) { - rt::abort_on_panic(|| unsafe { - Box::from_raw(call); - }) -} - -impl AutopushPythonCall { - pub fn new(call: PythonCall) -> AutopushPythonCall { - AutopushPythonCall { - inner: UnwindGuard::new(Inner { - input: call.input, - done: RefCell::new(Some(call.output)), - }), - } - } - - fn _new(input: String, f: F) -> AutopushPythonCall - where - F: FnOnce(&str) + Send + 'static, - { - AutopushPythonCall { - inner: UnwindGuard::new(Inner { - input: input, - done: RefCell::new(Some(Box::new(f))), - }), - } - } -} - -trait FnBox: Send { - fn call(self: Box, input: &str); -} - -impl FnBox for F { - fn call(self: Box, input: &str) { - (*self)(input) - } -} - -#[derive(Serialize)] -#[serde(tag = "command", rename_all = "snake_case")] -enum Call { - MigrateUser { - uaid: String, - message_month: String, - }, -} - -#[derive(Deserialize)] -struct PythonError { - pub error: bool, - pub error_msg: String, -} - -#[derive(Deserialize)] -pub struct MigrateUserResponse { - pub message_month: String, -} - -impl Server { - pub fn migrate_user( - &self, - uaid: String, - message_month: String, - ) -> MyFuture { - let (call, fut) = PythonCall::new(&Call::MigrateUser { - uaid, - message_month, - }); - self.send_to_python(call); - return fut; - } - - fn send_to_python(&self, call: PythonCall) { - self.tx.send(Some(call)).expect("python went away?"); - } -} - -impl PythonCall { - fn new(input: &T) -> (PythonCall, MyFuture) - where - T: ser::Serialize, - U: for<'de> de::Deserialize<'de> + 'static, - { - let (tx, rx) = oneshot::channel(); - let call = PythonCall { - input: serde_json::to_string(input).unwrap(), - output: Box::new(|json: &str| { - drop(tx.send(json_or_error(json))); - }), - }; - let rx = Box::new(rx.then(|res| match res { - Ok(Ok(s)) => Ok(serde_json::from_str(&s)?), - Ok(Err(e)) => Err(e), - Err(_) => Err("call canceled from python".into()), - })); - (call, rx) - } -} - -fn json_or_error(json: &str) -> Result { - if let Ok(err) = serde_json::from_str::(json) { - if err.error { - return Err(format!("python exception: {}", err.error_msg).into()); - } - } - Ok(json.to_string()) -} diff --git a/autopush_rs/src/client.rs b/autopush_rs/src/client.rs index eadb66e2..cc5dcfe6 100644 --- a/autopush_rs/src/client.rs +++ b/autopush_rs/src/client.rs @@ 
-21,7 +21,6 @@ use tokio_core::reactor::Timeout; use uuid::Uuid; use woothee::parser::Parser; -use call; use errors::*; use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification}; use server::Server; @@ -508,14 +507,14 @@ where for notif in notifs.iter_mut() { notif.sortkey_timestamp = Some(0); } - srv.handle.spawn(srv.ddb.store_messages( - &webpush.uaid, - &webpush.message_month, - notifs, - ).then(|_| { - debug!("Finished saving unacked direct notifications"); - Ok(()) - })); + srv.handle.spawn( + srv.ddb + .store_messages(&webpush.uaid, &webpush.message_month, notifs) + .then(|_| { + debug!("Finished saving unacked direct notifications"); + Ok(()) + }), + ); } // Log out the final stats message @@ -586,7 +585,7 @@ where #[state_machine_future(transitions(DetermineAck))] AwaitMigrateUser { - response: MyFuture, + response: MyFuture, data: AuthClientData, }, @@ -690,9 +689,12 @@ where } else if all_acked && webpush.flags.check { transition!(CheckStorage { data }); } else if all_acked && webpush.flags.rotate_message_table { - let response = data.srv.migrate_user( - webpush.uaid.simple().to_string(), - webpush.message_month.clone(), + debug!("Triggering migration"); + let response = data.srv.ddb.migrate_user( + &webpush.uaid, + &webpush.message_month, + &data.srv.opts.current_message_month, + &data.srv.opts.router_table_name, ); transition!(AwaitMigrateUser { response, data }); } else if all_acked && webpush.flags.reset_uaid { @@ -945,9 +947,7 @@ where await_migrate_user: &'a mut RentToOwn<'a, AwaitMigrateUser>, ) -> Poll, Error> { debug!("State: AwaitMigrateUser"); - let message_month = match try_ready!(await_migrate_user.response.poll()) { - call::MigrateUserResponse { message_month } => message_month, - }; + let message_month =try_ready!(await_migrate_user.response.poll()); let AwaitMigrateUser { data, .. } = await_migrate_user.take(); { let mut webpush = data.webpush.borrow_mut(); diff --git a/autopush_rs/src/lib.rs b/autopush_rs/src/lib.rs index f400a38c..5fe6c12c 100644 --- a/autopush_rs/src/lib.rs +++ b/autopush_rs/src/lib.rs @@ -123,6 +123,4 @@ mod protocol; #[macro_use] pub mod rt; -pub mod call; -pub mod queue; pub mod server; diff --git a/autopush_rs/src/queue.rs b/autopush_rs/src/queue.rs deleted file mode 100644 index 04f05a65..00000000 --- a/autopush_rs/src/queue.rs +++ /dev/null @@ -1,77 +0,0 @@ -//! Thread-safe MPMC queue for working with Python and Rust -//! -//! This is created in Python and shared amongst a number of worker threads for -//! the receiving side, and then the sending side is done by the Rust thread -//! pushing requests over to Python. A `Sender` here is saved off in the -//! `Server` for sending messages. 
- -use std::sync::Mutex; -use std::sync::mpsc; - -use call::{AutopushPythonCall, PythonCall}; -use rt::{self, AutopushError}; - -pub struct AutopushQueue { - tx: Mutex, - rx: Mutex>>>, -} - -pub type AutopushSender = mpsc::Sender>; - -fn _assert_kinds() { - fn _assert() {} - _assert::(); -} - -#[no_mangle] -pub extern "C" fn autopush_queue_new(err: &mut AutopushError) -> *mut AutopushQueue { - rt::catch(err, || { - let (tx, rx) = mpsc::channel(); - - Box::new(AutopushQueue { - tx: Mutex::new(tx), - rx: Mutex::new(Some(rx)), - }) - }) -} - -#[no_mangle] -pub extern "C" fn autopush_queue_recv( - queue: *mut AutopushQueue, - err: &mut AutopushError, -) -> *mut AutopushPythonCall { - rt::catch(err, || unsafe { - let mut rx = (*queue).rx.lock().unwrap(); - let msg = match *rx { - // this can't panic because we hold a reference to at least one - // sender, so it'll always block waiting for the next message - Some(ref rx) => rx.recv().unwrap(), - - // the "done" message was received by someone else, so we just keep - // propagating that - None => return None, - }; - match msg { - Some(msg) => Some(Box::new(AutopushPythonCall::new(msg))), - - // the senders are done, so all future calls shoudl bail out - None => { - *rx = None; - None - } - } - }) -} - -#[no_mangle] -pub extern "C" fn autopush_queue_free(queue: *mut AutopushQueue) { - rt::abort_on_panic(|| unsafe { - Box::from_raw(queue); - }) -} - -impl AutopushQueue { - pub fn tx(&self) -> AutopushSender { - self.tx.lock().unwrap().clone() - } -} diff --git a/autopush_rs/src/server/mod.rs b/autopush_rs/src/server/mod.rs index 28bc75da..7ff07293 100644 --- a/autopush_rs/src/server/mod.rs +++ b/autopush_rs/src/server/mod.rs @@ -42,7 +42,6 @@ use errors::*; use errors::{Error, Result}; use http; use protocol::{ClientMessage, Notification, ServerMessage, ServerNotification}; -use queue::{self, AutopushQueue}; use rt::{self, AutopushError, UnwindGuard}; use server::dispatch::{Dispatch, RequestType}; use server::metrics::metrics_from_opts; @@ -104,7 +103,6 @@ pub struct Server { pub ddb: DynamoStorage, open_connections: Cell, tls_acceptor: Option, - pub tx: queue::AutopushSender, pub opts: Arc, pub handle: Handle, pub metrics: StatsdClient, @@ -124,6 +122,7 @@ pub struct ServerOptions { pub max_connections: Option, pub close_handshake_timeout: Option, pub message_table_names: Vec, + pub current_message_month: String, pub router_table_name: String, pub router_url: String, pub endpoint_url: String, @@ -195,6 +194,7 @@ pub extern "C" fn autopush_server_new( .split(",") .map(|s| s.trim().to_string()) .collect(), + current_message_month: "".to_string(), router_table_name: to_s(opts.router_table_name) .map(|s| s.to_string()) .expect("router table name must be specified"), @@ -223,6 +223,10 @@ pub extern "C" fn autopush_server_new( .expect("poll interval cannot be 0"), }; opts.message_table_names.sort_unstable(); + opts.current_message_month = opts.message_table_names + .last() + .expect("No last message month found") + .to_string(); Box::new(AutopushServer { inner: UnwindGuard::new(AutopushServerInner { @@ -236,13 +240,11 @@ pub extern "C" fn autopush_server_new( #[no_mangle] pub extern "C" fn autopush_server_start( srv: *mut AutopushServer, - queue: *mut AutopushQueue, err: &mut AutopushError, ) -> i32 { unsafe { (*srv).inner.catch(err, |srv| { - let tx = (*queue).tx(); - let handles = Server::start(&srv.opts, tx).expect("failed to start server"); + let handles = Server::start(&srv.opts).expect("failed to start server"); 
srv.shutdown_handles.set(Some(handles)); }) } @@ -294,7 +296,7 @@ impl Server { /// This will spawn a new server with the `opts` specified, spinning up a /// separate thread for the tokio reactor. The returned ShutdownHandles can /// be used to interact with it (e.g. shut it down). - fn start(opts: &Arc, tx: queue::AutopushSender) -> Result> { + fn start(opts: &Arc) -> Result> { let mut shutdown_handles = vec![]; if let Some(handle) = Server::start_sentry()? { shutdown_handles.push(handle); @@ -305,7 +307,7 @@ impl Server { let opts = opts.clone(); let thread = thread::spawn(move || { - let (srv, mut core) = match Server::new(&opts, tx) { + let (srv, mut core) = match Server::new(&opts) { Ok(core) => { inittx.send(None).unwrap(); core @@ -369,7 +371,7 @@ impl Server { Ok(Some(ShutdownHandle(donetx, thread))) } - fn new(opts: &Arc, tx: queue::AutopushSender) -> Result<(Rc, Core)> { + fn new(opts: &Arc) -> Result<(Rc, Core)> { let core = Core::new()?; let broadcaster = if let Some(ref megaphone_url) = opts.megaphone_api_url { let megaphone_token = opts.megaphone_api_token @@ -387,7 +389,6 @@ impl Server { uaids: RefCell::new(HashMap::new()), open_connections: Cell::new(0), handle: core.handle(), - tx: tx, tls_acceptor: tls::configure(opts), metrics: metrics_from_opts(opts)?, }); @@ -620,13 +621,6 @@ impl Server { } } -impl Drop for Server { - fn drop(&mut self) { - // we're done sending messages, close out the queue - drop(self.tx.send(None)); - } -} - enum MegaphoneState { Waiting, Requesting(MyFuture), diff --git a/autopush_rs/src/util/ddb_helpers.rs b/autopush_rs/src/util/ddb_helpers.rs index 251741cf..9e30aecf 100644 --- a/autopush_rs/src/util/ddb_helpers.rs +++ b/autopush_rs/src/util/ddb_helpers.rs @@ -664,16 +664,78 @@ impl DynamoStorage { Box::new(response) } - fn register_channel_id( + fn update_user_message_month( + ddb: Rc>, + uaid: &Uuid, + router_table_name: &str, + message_month: &str, + ) -> MyFuture<()> { + let attr_values = hashmap! { + ":curmonth".to_string() => val!(S => message_month.to_string()), + ":lastconnect".to_string() => val!(N => generate_last_connect().to_string()), + }; + let update_item = UpdateItemInput { + key: ddb_item! { uaid: s => uaid.simple().to_string() }, + update_expression: Some("SET current_month=:curmonth, last_connect=:lastconnect".to_string()), + expression_attribute_values: Some(attr_values), + table_name: router_table_name.to_string(), + ..Default::default() + }; + let ddb_response = retry_if( + move || { + ddb.update_item(&update_item) + .and_then(|_| Box::new(future::ok(()))) + }, + |err: &UpdateItemError| { + matches!(err, &UpdateItemError::ProvisionedThroughputExceeded(_)) + }, + ).chain_err(|| "Error updating user message month"); + Box::new(ddb_response) + } + + fn all_channels( ddb: Rc>, uaid: &Uuid, - channel_id: &Uuid, message_table_name: &str, - ) -> MyFuture { - let chid = channel_id.hyphenated().to_string(); + ) -> MyFuture> { + let get_input = GetItemInput { + table_name: message_table_name.to_string(), + consistent_read: Some(true), + key: ddb_item! 
{ + uaid: s => uaid.simple().to_string(), + chidmessageid: s => " ".to_string() + }, + ..Default::default() + }; + let response = retry_if( + move || ddb.get_item(&get_input), + |err: &GetItemError| matches!(err, &GetItemError::ProvisionedThroughputExceeded(_)), + ).and_then(|get_item_output| { + let result = get_item_output.item.and_then(|item| { + let record: Option = serde_dynamodb::from_hashmap(item).ok(); + record + }); + let channels = if let Some(record) = result { + record.chids.unwrap_or_else(|| HashSet::new()) + } else { + HashSet::new() + }; + Box::new(future::ok(channels)) + }) + .or_else(|_err| Box::new(future::ok(HashSet::new()))); + Box::new(response) + } + + fn save_channels( + ddb: Rc>, + uaid: &Uuid, + channels: HashSet, + message_table_name: &str, + ) -> MyFuture<()> { + let chids: Vec = channels.into_iter().collect(); let expiry = sec_since_epoch() + 2 * MAX_EXPIRY; let attr_values = hashmap! { - ":channel_id".to_string() => val!(SS => vec![chid]), + ":chids".to_string() => val!(SS => chids), ":expiry".to_string() => val!(N => expiry), }; let update_item = UpdateItemInput { @@ -681,17 +743,20 @@ impl DynamoStorage { uaid: s => uaid.simple().to_string(), chidmessageid: s => " ".to_string() }, - update_expression: Some("ADD chids :channel_id SET expiry=:expiry".to_string()), + update_expression: Some("ADD chids :chids SET expiry=:expiry".to_string()), expression_attribute_values: Some(attr_values), table_name: message_table_name.to_string(), ..Default::default() }; let ddb_response = retry_if( - move || ddb.update_item(&update_item), + move || { + ddb.update_item(&update_item) + .and_then(|_| Box::new(future::ok(()))) + }, |err: &UpdateItemError| { matches!(err, &UpdateItemError::ProvisionedThroughputExceeded(_)) }, - ).chain_err(|| "Error registering channel"); + ).chain_err(|| "Error saving channels"); Box::new(ddb_response) } @@ -888,7 +953,9 @@ impl DynamoStorage { })) } }; - let response = DynamoStorage::register_channel_id(ddb, uaid, channel_id, message_month) + let mut chids = HashSet::new(); + chids.insert(channel_id.hyphenated().to_string()); + let response = DynamoStorage::save_channels(ddb, uaid, chids, message_month) .and_then(move |_| -> MyFuture<_> { Box::new(future::ok(RegisterResponse::Success { endpoint })) }) @@ -929,6 +996,39 @@ impl DynamoStorage { Box::new(response) } + /// Migrate a user to a new month table + pub fn migrate_user( + &self, + uaid: &Uuid, + message_month: &str, + current_message_month: &str, + router_table_name: &str, + ) -> MyFuture { + let ddb = self.ddb.clone(); + let ddb1 = self.ddb.clone(); + let ddb2 = self.ddb.clone(); + let uaid = uaid.clone(); + let cur_month = current_message_month.to_string(); + let cur_month1 = cur_month.clone(); + let cur_month2 = cur_month.clone(); + let cur_month3 = cur_month.clone(); + let router_table_name = router_table_name.to_string(); + let response = DynamoStorage::all_channels(ddb, &uaid, message_month) + .and_then(move |channels| -> MyFuture<_> { + if channels.is_empty() { + Box::new(future::ok(())) + } else { + DynamoStorage::save_channels(ddb1, &uaid, channels, &cur_month1) + } + }) + .and_then(move |_| -> MyFuture<_> { + DynamoStorage::update_user_message_month(ddb2, &uaid, &router_table_name, &cur_month2) + }) + .and_then(move |_| -> MyFuture<_> { Box::new(future::ok(cur_month3)) }) + .chain_err(|| "Unable to migrate user"); + Box::new(response) + } + /// Store a batch of messages when shutting down pub fn store_messages( &self,