diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index a6f1e7594e80..13b69d8dee5b 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -26,8 +26,8 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -36,7 +36,7 @@ from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -120,30 +120,23 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - appservice_handler = self.get_application_service_handler() - - @defer.inlineCallbacks - def replicate(results): - stream = results.get("events") - if stream: - max_stream_id = stream["position"] - yield appservice_handler.notify_interested_services(max_stream_id) - - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - replicate(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return ASReplicationHandler(self) + + +class ASReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(ASReplicationHandler, self).__init__(hs.get_datastore()) + self.appservice_handler = hs.get_application_service_handler() + + def on_rdata(self, stream_name, token, rows): + super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + + if stream_name == "events": + max_stream_id = self.store.get_room_max_stream_ordering() + self.appservice_handler.notify_interested_services(max_stream_id) def start(config_options): @@ -199,7 +192,6 @@ def run(): reactor.run() def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index e4ea3ab933e6..9b72c649ac9a 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -30,11 +30,11 @@ from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -45,7 +45,7 @@ from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -145,21 +145,10 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -209,7 +198,6 @@ def run(): def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index e52b0f240d62..eb392e1c9d29 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -27,9 +27,9 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -42,7 +42,7 @@ from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -134,21 +134,10 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -198,7 +187,6 @@ def run(): def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 76c4cc54d171..8994891aeb13 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -31,9 +31,10 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState -from synapse.util.async import sleep +from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -59,7 +60,23 @@ class FederationSenderSlaveStore( SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedRegistrationStore, SlavedDeviceStore, ): - pass + def __init__(self, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) + + def _get_federation_out_pos(self, db_conn): + sql = ( + "SELECT stream_id FROM federation_stream_position" + " WHERE type = ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, ("federation",)) + rows = txn.fetchall() + txn.close() + + return rows[0][0] if rows else -1 class FederationSenderServer(HomeServer): @@ -127,26 +144,29 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - send_handler = FederationSenderHandler(self) - - send_handler.on_start() - - while True: - try: - args = store.stream_positions() - args.update((yield send_handler.stream_positions())) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - yield send_handler.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return FederationSenderReplicationHandler(self) + + +class FederationSenderReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) + self.send_handler = FederationSenderHandler(hs) + + def on_rdata(self, stream_name, token, rows): + super(FederationSenderReplicationHandler, self).on_rdata( + stream_name, token, rows + ) + self.send_handler.process_replication_rows(stream_name, token, rows) + if stream_name == "federation": + self.send_federation_ack(token) + + def get_streams_to_replicate(self): + args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args.update(self.send_handler.stream_positions()) + return args def start(config_options): @@ -205,7 +225,6 @@ def run(): reactor.run() def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() @@ -233,6 +252,9 @@ def __init__(self, hs): self.store = hs.get_datastore() self.federation_sender = hs.get_federation_sender() + self.federation_position = self.store.federation_out_pos_startup + self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") + self._room_serials = {} self._room_typing = {} @@ -243,25 +265,13 @@ def on_start(self): self.store.get_room_max_stream_ordering() ) - @defer.inlineCallbacks def stream_positions(self): - stream_id = yield self.store.get_federation_out_pos("federation") - defer.returnValue({ - "federation": stream_id, - - # Ack stuff we've "processed", this should only be called from - # one process. - "federation_ack": stream_id, - }) + return {"federation": self.federation_position} - @defer.inlineCallbacks - def process_replication(self, result): + def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. - fed_stream = result.get("federation") - if fed_stream: - latest_id = int(fed_stream["position"]) - + if stream_name == "federation": # The federation stream containis a bunch of different types of # rows that need to be handled differently. We parse the rows, put # them into the appropriate collection and then send them off. @@ -272,8 +282,9 @@ def process_replication(self, result): device_destinations = set() # Parse the rows in the stream - for row in fed_stream["rows"]: - position, typ, content_js = row + for row in rows: + typ = row.type + content_js = row.data content = json.loads(content_js) if typ == send_queue.PRESENCE_TYPE: @@ -325,16 +336,19 @@ def process_replication(self, result): for destination in device_destinations: self.federation_sender.send_device_messages(destination) - # Record where we are in the stream. - yield self.store.update_federation_out_pos( - "federation", latest_id - ) + self.update_token(token) # We also need to poke the federation sender when new events happen - event_stream = result.get("events") - if event_stream: - latest_pos = event_stream["position"] - self.federation_sender.notify_new_events(latest_pos) + elif stream_name == "events": + self.federation_sender.notify_new_events(token) + + @defer.inlineCallbacks + def update_token(self, token): + self.federation_position = token + with (yield self._fed_position_linearizer.queue(None)): + yield self.store.update_federation_out_pos( + "federation", self.federation_position + ) if __name__ == '__main__': diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2cdd2d39ffa9..990eb477e5cb 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -56,6 +56,7 @@ from synapse.metrics import register_memory_metrics, get_metrics_for from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.federation.transport.server import TransportLayerServer from synapse.util.rlimit import change_resource_limit @@ -222,6 +223,16 @@ def start_listening(self): ), interface=address ) + elif listener["type"] == "replication": + bind_addresses = listener["bind_addresses"] + for address in bind_addresses: + factory = ReplicationStreamProtocolFactory(self) + server_listener = reactor.listenTCP( + listener["port"], factory, interface=address + ) + reactor.addSystemEventTrigger( + "before", "shutdown", server_listener.stopListening, + ) else: logger.warn("Unrecognized listener type: %s", listener["type"]) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 1444e69a4291..26c441695605 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -25,13 +25,13 @@ from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine from synapse.storage.media_repository import MediaRepositoryStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -45,7 +45,7 @@ from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -142,21 +142,10 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url + self.get_tcp_replication().start_replication(self) - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) + def build_tcp_replication(self): + return ReplicationClientHandler(self.get_datastore()) def start(config_options): @@ -206,7 +195,6 @@ def run(): def start(): ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() - ss.replicate() reactor.callWhenRunning(start) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index ab682e52ecf1..cb76f058b0cc 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -27,9 +27,9 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine from synapse.storage import DataStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, preserve_fn, \ PreserveLoggingContext @@ -39,7 +39,7 @@ from synapse import events -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.web.resource import Resource from daemonize import Daemonize @@ -89,7 +89,6 @@ class PusherSlaveStore( class PusherServer(HomeServer): - def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. @@ -109,16 +108,7 @@ def setup(self): logger.info("Finished setting up.") def remove_pusher(self, app_id, push_key, user_id): - http_client = self.get_simple_http_client() - replication_url = self.config.worker_replication_url - url = replication_url + "/remove_pushers" - return http_client.post_json_get_json(url, { - "remove": [{ - "app_id": app_id, - "push_key": push_key, - "user_id": user_id, - }] - }) + self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) def _listen_http(self, listener_config): port = listener_config["port"] @@ -166,73 +156,51 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - pusher_pool = self.get_pusherpool() - - def stop_pusher(user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - pushers_for_user = pusher_pool.pushers.get(user_id, {}) - pusher = pushers_for_user.pop(key, None) - if pusher is None: - return - logger.info("Stopping pusher %r / %r", user_id, key) - pusher.on_stop() - - def start_pusher(user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - logger.info("Starting pusher %r / %r", user_id, key) - return pusher_pool._refresh_pusher(app_id, pushkey, user_id) - - @defer.inlineCallbacks - def poke_pushers(results): - pushers_rows = set( - map(tuple, results.get("pushers", {}).get("rows", [])) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return PusherReplicationHandler(self) + + +class PusherReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(PusherReplicationHandler, self).__init__(hs.get_datastore()) + + self.pusher_pool = hs.get_pusherpool() + + def on_rdata(self, stream_name, token, rows): + super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + self.poke_pushers(stream_name, token, rows) + + def poke_pushers(self, stream_name, token, rows): + if stream_name == "pushers": + for row in rows: + if row.deleted: + self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + self.start_pusher(row.user_id, row.app_id, row.pushkey) + elif stream_name == "events": + preserve_fn(self.pusher_pool.on_new_notifications)( + token, token, ) - deleted_pushers_rows = set( - map(tuple, results.get("deleted_pushers", {}).get("rows", [])) + elif stream_name == "receipts": + preserve_fn(self.pusher_pool.on_new_receipts)( + token, token, set(row.room_id for row in rows) ) - for row in sorted(pushers_rows | deleted_pushers_rows): - if row in deleted_pushers_rows: - user_id, app_id, pushkey = row[1:4] - stop_pusher(user_id, app_id, pushkey) - elif row in pushers_rows: - user_id = row[1] - app_id = row[5] - pushkey = row[8] - yield start_pusher(user_id, app_id, pushkey) - - stream = results.get("events") - if stream and stream["rows"]: - min_stream_id = stream["rows"][0][0] - max_stream_id = stream["position"] - preserve_fn(pusher_pool.on_new_notifications)( - min_stream_id, max_stream_id - ) - - stream = results.get("receipts") - if stream and stream["rows"]: - rows = stream["rows"] - affected_room_ids = set(row[1] for row in rows) - min_stream_id = rows[0][0] - max_stream_id = stream["position"] - preserve_fn(pusher_pool.on_new_receipts)( - min_stream_id, max_stream_id, affected_room_ids - ) - - while True: - try: - args = store.stream_positions() - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - poke_pushers(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + + def stop_pusher(self, user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + pushers_for_user = self.pusher_pool.pushers.get(user_id, {}) + pusher = pushers_for_user.pop(key, None) + if pusher is None: + return + logger.info("Stopping pusher %r / %r", user_id, key) + pusher.on_stop() + + def start_pusher(self, user_id, app_id, pushkey): + key = "%s:%s" % (app_id, pushkey) + logger.info("Starting pusher %r / %r", user_id, key) + return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id) def start(config_options): @@ -288,7 +256,6 @@ def run(): reactor.run() def start(): - ps.replicate() ps.get_pusherpool().start() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 34e34e55805b..92cc6cb67a98 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -16,7 +16,7 @@ import synapse -from synapse.api.constants import EventTypes, PresenceState +from synapse.api.constants import EventTypes from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging @@ -40,15 +40,14 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine from synapse.storage.presence import PresenceStore, UserPresenceState from synapse.storage.roommember import RoomMemberStore -from synapse.util.async import sleep from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn, \ - PreserveLoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole from synapse.util.rlimit import change_resource_limit from synapse.util.stringutils import random_string @@ -111,7 +110,6 @@ def __init__(self, hs): self.http_client = hs.get_simple_http_client() self.store = hs.get_datastore() self.user_to_num_current_syncs = {} - self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -124,14 +122,7 @@ def __init__(self, hs): self.process_id = random_string(16) logger.info("Presence process_id is %r", self.process_id) - self._sending_sync = False - self._need_to_send_sync = False - self.clock.looping_call( - self._send_syncing_users_regularly, - UPDATE_SYNCING_USERS_MS, - ) - - reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + self.sync_callback = None def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? @@ -142,15 +133,14 @@ def set_state(self, user, state, ignore_status_msg=False): _get_interested_parties = PresenceHandler._get_interested_parties.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__ - @defer.inlineCallbacks def user_syncing(self, user_id, affect_presence): if affect_presence: curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_states = yield self.current_state_for_users([user_id]) - if prev_states[user_id].state == PresenceState.OFFLINE: - # TODO: Don't block the sync request on this HTTP hit. - yield self._send_syncing_users_now() + + if self.sync_callback: + if self.user_to_num_current_syncs[user_id] == 1: + self.sync_callback(user_id, True) def _end(): # We check that the user_id is in user_to_num_current_syncs because @@ -159,6 +149,10 @@ def _end(): if affect_presence and user_id in self.user_to_num_current_syncs: self.user_to_num_current_syncs[user_id] -= 1 + if self.sync_callback: + if self.user_to_num_current_syncs[user_id] == 0: + self.sync_callback(user_id, False) + @contextlib.contextmanager def _user_syncing(): try: @@ -166,49 +160,7 @@ def _user_syncing(): finally: _end() - defer.returnValue(_user_syncing()) - - @defer.inlineCallbacks - def _on_shutdown(self): - # When the synchrotron is shutdown tell the master to clear the in - # progress syncs for this process - self.user_to_num_current_syncs.clear() - yield self._send_syncing_users_now() - - def _send_syncing_users_regularly(self): - # Only send an update if we aren't in the middle of sending one. - if not self._sending_sync: - preserve_fn(self._send_syncing_users_now)() - - @defer.inlineCallbacks - def _send_syncing_users_now(self): - if self._sending_sync: - # We don't want to race with sending another update. - # Instead we wait for that update to finish and send another - # update afterwards. - self._need_to_send_sync = True - return - - # Flag that we are sending an update. - self._sending_sync = True - - yield self.http_client.post_json_get_json(self.syncing_users_url, { - "process_id": self.process_id, - "syncing_users": [ - user_id for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ], - }) - - # Unset the flag as we are no longer sending an update. - self._sending_sync = False - if self._need_to_send_sync: - # If something happened while we were sending the update then - # we might need to send another update. - # TODO: Check if the update that was sent matches the current state - # as we only need to send an update if they are different. - self._need_to_send_sync = False - yield self._send_syncing_users_now() + return defer.succeed(_user_syncing()) @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): @@ -223,26 +175,24 @@ def notify_from_replication(self, states, stream_id): ) @defer.inlineCallbacks - def process_replication(self, result): - stream = result.get("presence", {"rows": []}) - states = [] - for row in stream["rows"]: - ( - position, user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) = row - state = UserPresenceState( - user_id, state, last_active_ts, - last_federation_update_ts, last_user_sync_ts, status_msg, - currently_active - ) - self.user_to_current_state[user_id] = state - states.append(state) + def process_replication_rows(self, token, rows): + states = [UserPresenceState( + row.user_id, row.state, row.last_active_ts, + row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, + row.currently_active + ) for row in rows] - if states and "position" in stream: - stream_id = int(stream["position"]) - yield self.notify_from_replication(states, stream_id) + for state in states: + self.user_to_current_state[row.user_id] = state + + stream_id = token + yield self.notify_from_replication(states, stream_id) + + def get_currently_syncing_users(self): + return [ + user_id for user_id, count in self.user_to_num_current_syncs.iteritems() + if count > 0 + ] class SynchrotronTyping(object): @@ -257,16 +207,13 @@ def stream_positions(self): # value which we *must* use for the next replication request. return {"typing": self._latest_room_serial} - def process_replication(self, result): - stream = result.get("typing") - if stream: - self._latest_room_serial = int(stream["position"]) + def process_replication_rows(self, token, rows): + self._latest_room_serial = token - for row in stream["rows"]: - position, room_id, typing_json = row - typing = json.loads(typing_json) - self._room_serials[room_id] = position - self._room_typing[room_id] = typing + for row in rows: + typing = json.loads(row.user_ids) + self._room_serials[row.room_id] = token + self._room_typing[row.room_id] = typing class SynchrotronApplicationService(object): @@ -351,124 +298,90 @@ def start_listening(self, listeners): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - notifier = self.get_notifier() - presence_handler = self.get_presence_handler() - typing_handler = self.get_typing_handler() - - def notify_from_stream( - result, stream_name, stream_key, room=None, user=None - ): - stream = result.get(stream_name) - if stream: - position_index = stream["field_names"].index("position") - if room: - room_index = stream["field_names"].index(room) - if user: - user_index = stream["field_names"].index(user) - - users = () - rooms = () - for row in stream["rows"]: - position = row[position_index] - - if user: - users = (row[user_index],) - - if room: - rooms = (row[room_index],) - - notifier.on_new_event( - stream_key, position, users=users, rooms=rooms - ) + self.get_tcp_replication().start_replication(self) - @defer.inlineCallbacks - def notify_device_list_update(result): - stream = result.get("device_lists") - if not stream: - return + def build_tcp_replication(self): + return SyncReplicationHandler(self) - position_index = stream["field_names"].index("position") - user_index = stream["field_names"].index("user_id") + def build_presence_handler(self): + return SynchrotronPresence(self) - for row in stream["rows"]: - position = row[position_index] - user_id = row[user_index] + def build_typing_handler(self): + return SynchrotronTyping(self) - room_ids = yield store.get_rooms_for_user(user_id) - notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) +class SyncReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(SyncReplicationHandler, self).__init__(hs.get_datastore()) - @defer.inlineCallbacks - def notify(result): - stream = result.get("events") - if stream: - max_position = stream["position"] - - event_map = yield store.get_events([row[1] for row in stream["rows"]]) - - for row in stream["rows"]: - position = row[0] - event_id = row[1] - event = event_map.get(event_id, None) - if not event: - continue - - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - notifier.on_new_room_event( - event, position, max_position, extra_users - ) + self.store = hs.get_datastore() + self.typing_handler = hs.get_typing_handler() + self.presence_handler = hs.get_presence_handler() + self.notifier = hs.get_notifier() - notify_from_stream( - result, "push_rules", "push_rules_key", user="user_id" - ) - notify_from_stream( - result, "user_account_data", "account_data_key", user="user_id" - ) - notify_from_stream( - result, "room_account_data", "account_data_key", user="user_id" + self.presence_handler.sync_callback = self.send_user_sync + + def on_rdata(self, stream_name, token, rows): + super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + + if stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + elif stream_name == "presence": + self.presence_handler.process_replication_rows(token, rows) + self.notify(stream_name, token, rows) + + def get_streams_to_replicate(self): + args = super(SyncReplicationHandler, self).get_streams_to_replicate() + args.update(self.typing_handler.stream_positions()) + return args + + def get_currently_syncing_users(self): + return self.presence_handler.get_currently_syncing_users() + + @defer.inlineCallbacks + def notify(self, stream_name, token, rows): + if stream_name == "events": + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + event = yield self.store.get_event(row.event_id) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows], ) - notify_from_stream( - result, "tag_account_data", "account_data_key", user="user_id" + elif stream_name in ("account_data", "tag_account_data",): + self.notifier.on_new_event( + "account_data_key", token, users=[row.user_id for row in rows], ) - notify_from_stream( - result, "receipts", "receipt_key", room="room_id" + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows], ) - notify_from_stream( - result, "typing", "typing_key", room="room_id" + elif stream_name == "typing": + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows], ) - notify_from_stream( - result, "to_device", "to_device_key", user="user_id" + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event( + "to_device_key", token, users=entities, + ) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = yield self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event( + "device_list_key", token, rooms=all_room_ids, ) - yield notify_device_list_update(result) - - while True: - try: - args = store.stream_positions() - args.update(typing_handler.stream_positions()) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - typing_handler.process_replication(result) - yield presence_handler.process_replication(result) - yield notify(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(5) - - def build_presence_handler(self): - return SynchrotronPresence(self) - - def build_typing_handler(self): - return SynchrotronTyping(self) def start(config_options): @@ -514,7 +427,6 @@ def run(): def start(): ss.get_datastore().start_profiling() - ss.replicate() ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index b165c67ee721..ad06ba969126 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -29,6 +29,9 @@ def read_config(self, config): self.worker_log_file = config.get("worker_log_file") self.worker_log_config = config.get("worker_log_config") self.worker_replication_url = config.get("worker_replication_url") + self.worker_replication_host = config.get("worker_replication_host", None) + self.worker_replication_port = config.get("worker_replication_port", None) + self.worker_name = config.get("worker_name", self.worker_app) if self.worker_listeners: for listener in self.worker_listeners: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 5c9f7a86f0aa..332c6666bc08 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object): def __init__(self, hs): self.server_name = hs.hostname self.clock = hs.get_clock() + self.notifier = hs.get_notifier() self.presence_map = {} self.presence_changed = sorteddict() @@ -186,6 +187,8 @@ def send_edu(self, destination, edu_type, content, key=None): else: self.edus[pos] = edu + self.notifier.on_new_replication_data() + def send_presence(self, destination, states): """As per TransactionQueue""" pos = self._next_pos() @@ -199,21 +202,28 @@ def send_presence(self, destination, states): (destination, state.user_id) for state in states ] + self.notifier.on_new_replication_data() + def send_failure(self, failure, destination): """As per TransactionQueue""" pos = self._next_pos() self.failures[pos] = (destination, str(failure)) + self.notifier.on_new_replication_data() def send_device_messages(self, destination): """As per TransactionQueue""" pos = self._next_pos() self.device_messages[pos] = destination + self.notifier.on_new_replication_data() def get_current_token(self): return self.pos - 1 - def get_replication_rows(self, token, limit, federation_ack=None): + def federation_ack(self, token): + self._clear_queue_before_pos(token) + + def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): """ Args: token (int) @@ -225,8 +235,8 @@ def get_replication_rows(self, token, limit, federation_ack=None): # TODO: Handle limit. # To handle restarts where we wrap around - if token > self.pos: - token = -1 + if from_token > self.pos: + from_token = -1 rows = [] @@ -237,10 +247,11 @@ def get_replication_rows(self, token, limit, federation_ack=None): # Fetch changed presence keys = self.presence_changed.keys() - i = keys.bisect_right(token) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 dest_user_ids = set( (pos, dest_user_id) - for pos in keys[i:] + for pos in keys[i:j] for dest_user_id in self.presence_changed[pos] ) @@ -252,8 +263,9 @@ def get_replication_rows(self, token, limit, federation_ack=None): # Fetch changes keyed edus keys = self.keyed_edu_changed.keys() - i = keys.bisect_right(token) - keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:j]) for (pos, (destination, edu_key)) in keyed_edus: rows.append( @@ -265,16 +277,18 @@ def get_replication_rows(self, token, limit, federation_ack=None): # Fetch changed edus keys = self.edus.keys() - i = keys.bisect_right(token) - edus = set((k, self.edus[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + edus = set((k, self.edus[k]) for k in keys[i:j]) for (pos, edu) in edus: rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) # Fetch changed failures keys = self.failures.keys() - i = keys.bisect_right(token) - failures = set((k, self.failures[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + failures = set((k, self.failures[k]) for k in keys[i:j]) for (pos, (destination, failure)) in failures: rows.append((pos, FAILURE_TYPE, ujson.dumps({ @@ -284,8 +298,9 @@ def get_replication_rows(self, token, limit, federation_ack=None): # Fetch changed device messages keys = self.device_messages.keys() - i = keys.bisect_right(token) - device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + i = keys.bisect_right(from_token) + j = keys.bisect_right(to_token) + 1 + device_messages = set((k, self.device_messages[k]) for k in keys[i:j]) for (pos, destination) in device_messages: rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1ede117c7959..9f9257029e45 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -30,6 +30,7 @@ from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.async import Linearizer from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -187,6 +188,7 @@ def __init__(self, hs): # process_id to millisecond timestamp last updated. self.external_process_to_current_syncs = {} self.external_process_last_updated_ms = {} + self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to @@ -508,6 +510,70 @@ def update_external_syncs(self, process_id, syncing_user_ids): self.external_process_last_updated_ms[process_id] = self.clock.time_msec() self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks + def update_external_syncs_row(self, process_id, user_id, is_syncing): + """Update the syncing users for an external process as a delta. + + Args: + process_id (str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + user_id (str): The user who has started or stopped syncing + is_syncing (bool): Whether or not the user is now syncing + """ + with (yield self.external_sync_linearizer.queue(process_id)): + prev_state = yield self.current_state_for_user(user_id) + + process_presence = self.external_process_to_current_syncs.setdefault( + process_id, set() + ) + time_now_ms = self.clock.time_msec() + + updates = [] + if is_syncing and user_id not in process_presence: + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=time_now_ms, + last_user_sync_ts=time_now_ms, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + process_presence.add(user_id) + elif user_id in process_presence: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + if updates: + yield self._update_states(updates) + + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + + @defer.inlineCallbacks + def update_external_syncs_clear(self, process_id): + """Marks all users that had been marked as syncing by a given process + as offline. + + Used when the process has stopped/disappeared. + """ + with (yield self.external_sync_linearizer.queue(process_id)): + process_presence = self.external_process_to_current_syncs.pop( + process_id, set() + ) + prev_states = yield self.current_state_for_users(process_presence) + time_now_ms = self.clock.time_msec() + + yield self._update_states([ + prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + ) + for prev_state in prev_states.itervalues() + ]) + self.external_process_last_updated_ms.pop(process_id, None) + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0eea7f8f9c29..d6809862e0de 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -293,6 +293,9 @@ def get_all_typing_updates(self, last_id, current_id): rows.sort() return rows + def get_current_token(self): + return self._latest_room_serial + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/notifier.py b/synapse/notifier.py index 7eeba6d28e57..358ec86004a4 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -550,3 +550,6 @@ def wait_for_replication(self, callback, timeout): break defer.returnValue(result) + + def wait_once_for_replication(self): + return self.replication_deferred.observe() diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 03930fe958d5..abd3fe766530 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -489,7 +489,7 @@ def federation(self, writer, current_token, limit, request_streams, federation_a if federation is not None and federation != current_position: federation_rows = self.federation_sender.get_replication_rows( - federation, limit, federation_ack=federation_ack, + federation, current_position, limit, federation_ack=federation_ack, ) upto_token = _position_from_rows(federation_rows, current_position) writer.write_header_and_rows("federation", federation_rows, ( @@ -504,7 +504,7 @@ def device_lists(self, writer, current_token, limit, request_streams): if device_lists is not None and device_lists != current_position: changes = yield self.store.get_all_device_list_changes_for_remotes( - device_lists, + device_lists, current_position, ) writer.write_header_and_rows("device_lists", changes, ( "position", "user_id", "destination", diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index ab133db872fa..b962641166dc 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -15,7 +15,6 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from twisted.internet import defer from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +33,7 @@ def __init__(self, db_conn, hs): else: self._cache_id_gen = None - self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache" - self.http_client = hs.get_simple_http_client() + self.hs = hs def stream_positions(self): pos = {} @@ -43,35 +41,20 @@ def stream_positions(self): pos["caches"] = self._cache_id_gen.get_current_token() return pos - def process_replication(self, result): - stream = result.get("caches") - if stream: - for row in stream["rows"]: - ( - position, cache_func, keys, invalidation_ts, - ) = row - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "caches": + self._cache_id_gen.advance(token) + for row in rows: try: - getattr(self, cache_func).invalidate(tuple(keys)) + getattr(self, row.cache_func).invalidate(tuple(row.keys)) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. pass - self._cache_id_gen.advance(int(stream["position"])) - return defer.succeed(None) def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) txn.call_after(self._send_invalidation_poke, cache_func, keys) - @defer.inlineCallbacks def _send_invalidation_poke(self, cache_func, keys): - try: - yield self.http_client.post_json_get_json(self.expire_cache_url, { - "invalidate": [{ - "name": cache_func.__name__, - "keys": list(keys), - }] - }) - except: - logger.exception("Failed to poke on expire_cache") + self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 77c64722c788..efbd87918ec6 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -69,38 +69,25 @@ def stream_positions(self): result["tag_account_data"] = position return result - def process_replication(self, result): - stream = result.get("user_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id, data_type = row[:3] - self.get_global_account_data_by_type_for_user.invalidate( - (data_type, user_id,) - ) - self.get_account_data_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "tag_account_data": + self._account_data_id_gen.advance(token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("room_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_account_data_for_user.invalidate((user_id,)) + elif stream_name == "account_data": + self._account_data_id_gen.advance(token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id,) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - stream = result.get("tag_account_data") - if stream: - self._account_data_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] - self.get_tags_for_user.invalidate((user_id,)) - self._account_data_stream_cache.entity_has_changed( - user_id, position - ) - - return super(SlavedAccountDataStore, self).process_replication(result) + return super(SlavedAccountDataStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index f9102e0d8913..6f3fb64770e7 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -53,21 +53,18 @@ def stream_positions(self): result["to_device"] = self._device_inbox_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("to_device") - if stream: - self._device_inbox_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - entity = row[1] - - if entity.startswith("@"): + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "to_device": + self._device_inbox_id_gen.advance(token) + for row in rows: + if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) else: self._device_federation_outbox_stream_cache.entity_has_changed( - entity, stream_id + row.entity, token ) - - return super(SlavedDeviceInboxStore, self).process_replication(result) + return super(SlavedDeviceInboxStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ca46aa17b621..4d4a435471b1 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -51,22 +51,18 @@ def stream_positions(self): result["device_lists"] = self._device_list_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("device_lists") - if stream: - self._device_list_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - stream_id = row[0] - user_id = row[1] - destination = row[2] - + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "device_lists": + self._device_list_id_gen.advance(token) + for row in rows: self._device_list_stream_cache.entity_has_changed( - user_id, stream_id + row.user_id, token ) - if destination: + if row.destination: self._device_list_federation_stream_cache.entity_has_changed( - destination, stream_id + row.destination, token ) - - return super(SlavedDeviceStore, self).process_replication(result) + return super(SlavedDeviceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d4db1e452e91..5fd47706efc0 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -201,48 +201,25 @@ def stream_positions(self): result["backfill"] = -self._backfill_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("events") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - - if stream["rows"]: - logger.info("Got %d event rows", len(stream["rows"])) - - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=False, + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_event( + token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=False, ) - - stream = result.get("backfill") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - self._process_replication_row( - row, backfilled=True, + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + for row in rows: + self.invalidate_caches_for_event( + -token, row.event_id, row.room_id, row.type, row.state_key, + row.redacts, + backfilled=True, ) - - stream = result.get("forward_ex_outliers") - if stream: - self._stream_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - stream = result.get("backward_ex_outliers") - if stream: - self._backfill_id_gen.advance(-int(stream["position"])) - for row in stream["rows"]: - event_id = row[1] - self._invalidate_get_event_cache(event_id) - - return super(SlavedEventStore, self).process_replication(result) - - def _process_replication_row(self, row, backfilled): - stream_ordering = row[0] if not backfilled else -row[0] - self.invalidate_caches_for_event( - stream_ordering, row[1], row[2], row[3], row[4], row[5], - backfilled=backfilled, + return super(SlavedEventStore, self).process_replication_rows( + stream_name, token, rows ) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index e4a2414d78a8..dffc80adc378 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -48,15 +48,14 @@ def stream_positions(self): result["presence"] = position return result - def process_replication(self, result): - stream = result.get("presence") - if stream: - self._presence_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, user_id = row[:2] + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "presence": + self._presence_id_gen.advance(token) + for row in rows: self.presence_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - self._get_presence_for_user.invalidate((user_id,)) - - return super(SlavedPresenceStore, self).process_replication(result) + self._get_presence_for_user.invalidate((row.user_id,)) + return super(SlavedPresenceStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 21ceb0213ae5..83e880fdd2a2 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -50,18 +50,15 @@ def stream_positions(self): result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("push_rules") - if stream: - for row in stream["rows"]: - position = row[0] - user_id = row[2] - self.get_push_rules_for_user.invalidate((user_id,)) - self.get_push_rules_enabled_for_user.invalidate((user_id,)) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "push_rules": + self._push_rules_stream_id_gen.advance(token) + for row in rows: + self.get_push_rules_for_user.invalidate((row.user_id,)) + self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed( - user_id, position + row.user_id, token ) - - self._push_rules_stream_id_gen.advance(int(stream["position"])) - - return super(SlavedPushRuleStore, self).process_replication(result) + return super(SlavedPushRuleStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index d88206b3bba0..4e8d68ece9dc 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -40,13 +40,9 @@ def stream_positions(self): result["pushers"] = self._pushers_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - stream = result.get("deleted_pushers") - if stream: - self._pushers_id_gen.advance(int(stream["position"])) - - return super(SlavedPusherStore, self).process_replication(result) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "pushers": + self._pushers_id_gen.advance(token) + return super(SlavedPusherStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ac9662d399d9..b371574ece56 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -65,20 +65,22 @@ def stream_positions(self): result["receipts"] = self._receipts_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("receipts") - if stream: - self._receipts_id_gen.advance(int(stream["position"])) - for row in stream["rows"]: - position, room_id, receipt_type, user_id = row[:4] - self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) - self._receipts_stream_cache.entity_has_changed(room_id, position) - - return super(SlavedReceiptsStore, self).process_replication(result) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_linearized_receipts_for_room.invalidate_many((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) + + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "receipts": + self._receipts_id_gen.advance(token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) + + return super(SlavedReceiptsStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 6df9a25ef311..f5103840333f 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -46,9 +46,10 @@ def stream_positions(self): result["public_rooms"] = self._public_room_id_gen.get_current_token() return result - def process_replication(self, result): - stream = result.get("public_rooms") - if stream: - self._public_room_id_gen.advance(int(stream["position"])) + def process_replication_rows(self, stream_name, token, rows): + if stream_name == "public_rooms": + self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication(result) + return super(RoomStore, self).process_replication_rows( + stream_name, token, rows + ) diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py new file mode 100644 index 000000000000..432f1e5b1c90 --- /dev/null +++ b/synapse/replication/tcp/__init__.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module implements the TCP replication protocol used by synapse to +communicate between the master process and its workers (when they're enabled). + +The protocol is based on fire and forget, line based commands. An example flow +would be (where '>' indicates master->worker and '<' worker->master flows):: + + > SERVER example.com + < REPLICATE events 53 + > RDATA events 54 ["$foo1:bar.com", ...] + > RDATA events 55 ["$foo4:bar.com", ...] + +The example shows the server accepting a new connection and sending its identity +with the `SERVER` command, followed by the client asking to subscribe to the +`events` stream from the token `53`. The server then periodically sends `RDATA` +commands which have the format `RDATA `, where the +format of `` is defined by the individual streams. + +Error reporting happens by either the client or server sending an `ERROR` +command, and usually the connection will be closed. + + +Structure of the module: + * client.py - the client classes used for workers to connect to master + * command.py - the definitions of all the valid commands + * protocol.py - contains bot the client and server protocol implementations, + these should not be used directly + * resource.py - the server classes that accepts and handle client connections + * streams.py - the definitons of all the valid streams + +Further detail about the wire protocol can be found in protocol.py and the +meaning of the various commands in command.py. + + +Since the protocol is a simple line based, its possible to manually connect to +the server using a tool like netcat. A few things should be noted when manually +using the protocol: + * When subscribing to a stream using `REPLICATE`, the special token `NOW` can + be used to get all future updates. The special stream name `ALL` can be used + with `NOW` to subscribe to all available streams. + * The federation stream is only available if federation sending has been + disabled on the main process. + * The server will only time connections out that have sent a `PING` command. + If a ping is sent then the connection will be closed if no further commands + are receieved within 15s. Both the client and server protocol implementations + will send an initial PING on connection and ensure at least one command every + 5s is sent (not necessarily `PING`). + * `RDATA` commands *usually* include a numeric token, however if the stream + has multiple rows to replicate per token the server will send multiple + `RDATA` commands, with all but the last having a token of `batch`. See + the documentation on `commands.RdataCommand` for further details. +""" diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py new file mode 100644 index 000000000000..dbaf6ac5afdd --- /dev/null +++ b/synapse/replication/tcp/client.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A replication client for use by synapse workers. +""" + +from twisted.internet import reactor, defer +from twisted.internet.protocol import ReconnectingClientFactory + +from .commands import ( + FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, +) +from .protocol import ClientReplicationStreamProtocol + +import logging + +logger = logging.getLogger(__name__) + + +class ReplicationClientFactory(ReconnectingClientFactory): + """Factory for building connections to the master. Will reconnect if the + connection is lost. + + Accepts a handler that will be called when new data is available or data + is required. + """ + maxDelay = 5 # Try at least once every N seconds + + def __init__(self, hs, client_name, handler): + self.client_name = client_name + self.handler = handler + self.server_name = hs.config.server_name + self._clock = hs.get_clock() # As self.clock is defined in super class + + reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + + def startedConnecting(self, connector): + logger.info("Connecting to replication: %r", connector.getDestination()) + + def buildProtocol(self, addr): + logger.info("Connected to replication: %r", addr) + self.resetDelay() + return ClientReplicationStreamProtocol( + self.client_name, self.server_name, self._clock, self.handler + ) + + def clientConnectionLost(self, connector, reason): + logger.error("Lost replication conn: %r", reason) + ReconnectingClientFactory.clientConnectionLost(self, connector, reason) + + def clientConnectionFailed(self, connector, reason): + logger.error("Failed to connect to replication: %r", reason) + ReconnectingClientFactory.clientConnectionFailed( + self, connector, reason + ) + + +class ReplicationClientHandler(object): + """A base handler that can be passed to the ReplicationClientFactory. + + By default proxies incoming replication data to the SlaveStore. + """ + def __init__(self, store): + self.store = store + + # The current connection. None if we are currently (re)connecting + self.connection = None + + # Any pending commands to be sent once a new connection has been + # established + self.pending_commands = [] + + # Map from string -> deferred, to wake up when receiveing a SYNC with + # the given string. + # Used for tests. + self.awaiting_syncs = {} + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + reactor.connectTCP(host, port, factory) + + def on_rdata(self, stream_name, token, rows): + """Called when we get new replication data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + logger.info("Received rdata %s -> %s", stream_name, token) + self.store.process_replication_rows(stream_name, token, rows) + + def on_position(self, stream_name, token): + """Called when we get new position data. By default this just pokes + the slave store. + + Can be overriden in subclasses to handle more. + """ + self.store.process_replication_rows(stream_name, token, []) + + def on_sync(self, data): + """When we received a SYNC we wake up any deferreds that were waiting + for the sync with the given data. + + Used by tests. + """ + d = self.awaiting_syncs.pop(data, None) + if d: + d.callback(data) + + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns a dictionary of stream name to token. + """ + args = self.store.stream_positions() + user_account_data = args.pop("user_account_data", None) + room_account_data = args.pop("room_account_data", None) + if user_account_data: + args["account_data"] = user_account_data + elif room_account_data: + args["account_data"] = room_account_data + return args + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. (Overriden by the synchrotron's only) + """ + return [] + + def send_command(self, cmd): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self.connection: + self.connection.send_command(cmd) + else: + logger.warn("Queuing command as not connected: %r", cmd.NAME) + self.pending_commands.append(cmd) + + def send_federation_ack(self, token): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync(self, user_id, is_syncing): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command(UserSyncCommand(user_id, is_syncing)) + + def send_remove_pusher(self, app_id, push_key, user_id): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func, keys): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func, keys) + self.send_command(cmd) + + def await_sync(self, data): + """Returns a deferred that is resolved when we receive a SYNC command + with given data. + + Used by tests. + """ + return self.awaiting_syncs.setdefault(data, defer.Deferred()) + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self.connection = connection + if connection: + for cmd in self.pending_commands: + connection.send_command(cmd) + self.pending_commands = [] diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py new file mode 100644 index 000000000000..457c7279fe63 --- /dev/null +++ b/synapse/replication/tcp/commands.py @@ -0,0 +1,341 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the various valid commands + +The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are +allowed to be sent by which side. +""" + +import logging +import ujson as json + + +logger = logging.getLogger(__name__) + + +class Command(object): + """The base command class. + + All subclasses must set the NAME variable which equates to the name of the + command on the wire. + + A full command line on the wire is constructed from `NAME + " " + to_line()` + + The default implementation creates a command of form ` ` + """ + NAME = None + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + """Deserialises a line from the wire into this command. `line` does not + include the command. + """ + return cls(line) + + def to_line(self): + """Serialises the comamnd for the wire. Does not include the command + prefix. + """ + return self.data + + +class ServerCommand(Command): + """Sent by the server on new connection and includes the server_name. + + Format:: + + SERVER + """ + NAME = "SERVER" + + +class RdataCommand(Command): + """Sent by server when a subscribed stream has an update. + + Format:: + + RDATA + + The `` may either be a numeric stream id OR "batch". The latter case + is used to support sending multiple updates with the same stream ID. This + is done by sending an RDATA for each row, with all but the last RDATA having + a token of "batch" and the last having the final stream ID. + + The client should batch all incoming RDATA with a token of "batch" (per + stream_name) until it sees an RDATA with a numeric stream ID. + + `` of "batch" maps to the instance variable `token` being None. + + An example of a batched series of RDATA:: + + RDATA presence batch ["@foo:example.com", "online", ...] + RDATA presence batch ["@bar:example.com", "online", ...] + RDATA presence 59 ["@baz:example.com", "online", ...] + """ + NAME = "RDATA" + + def __init__(self, stream_name, token, row): + self.stream_name = stream_name + self.token = token + self.row = row + + @classmethod + def from_line(cls, line): + stream_name, token, row_json = line.split(" ", 2) + return cls( + stream_name, + None if token == "batch" else int(token), + json.loads(row_json) + ) + + def to_line(self): + return " ".join(( + self.stream_name, + str(self.token) if self.token is not None else "batch", + json.dumps(self.row), + )) + + +class PositionCommand(Command): + """Sent by the client to tell the client the stream postition without + needing to send an RDATA. + """ + NAME = "POSITION" + + def __init__(self, stream_name, token): + self.stream_name = stream_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, token = line.split(" ", 1) + return cls(stream_name, int(token)) + + def to_line(self): + return " ".join((self.stream_name, str(self.token),)) + + +class ErrorCommand(Command): + """Sent by either side if there was an ERROR. The data is a string describing + the error. + """ + NAME = "ERROR" + + +class PingCommand(Command): + """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """ + NAME = "PING" + + +class NameCommand(Command): + """Sent by client to inform the server of the client's identity. The data + is the name + """ + NAME = "NAME" + + +class ReplicateCommand(Command): + """Sent by the client to subsribe to the stream. + + Format:: + + REPLICATE + + Where may be either: + * a numeric stream_id to stream updates from + * "NOW" to stream all subsequent updates. + + The can be "ALL" to subscribe to all known streams, in which + case the must be set to "NOW", i.e.:: + + REPLICATE ALL NOW + """ + NAME = "REPLICATE" + + def __init__(self, stream_name, token): + self.stream_name = stream_name + self.token = token + + @classmethod + def from_line(cls, line): + stream_name, token = line.split(" ", 1) + if token in ("NOW", "now"): + token = "NOW" + else: + token = int(token) + return cls(stream_name, token) + + def to_line(self): + return " ".join((self.stream_name, str(self.token),)) + + +class UserSyncCommand(Command): + """Sent by the client to inform the server that a user has started or + stopped syncing. Used to calculate presence on the master. + + Format:: + + USER_SYNC + + Where is either "start" or "stop" + """ + NAME = "USER_SYNC" + + def __init__(self, user_id, is_syncing): + self.user_id = user_id + self.is_syncing = is_syncing + + @classmethod + def from_line(cls, line): + user_id, state = line.split(" ", 1) + + if state not in ("start", "end"): + raise Exception("Invalid USER_SYNC state %r" % (state,)) + + return cls(user_id, state == "start") + + def to_line(self): + return " ".join((self.user_id, "start" if self.is_syncing else "end")) + + +class FederationAckCommand(Command): + """Sent by the client when its processed upto a given point in the + federation stream. This allows the master to drop in memory caches of the + federation stream. + + This must only be sent from one worker (i.e. the one sending federation) + + Format:: + + FEDERATION_ACK + """ + NAME = "FEDERATION_ACK" + + def __init__(self, token): + self.token = token + + @classmethod + def from_line(cls, line): + return cls(int(line)) + + def to_line(self): + return str(self.token) + + +class SyncCommand(Command): + """Used for testing. The client protocol implementation allows waiting + on a SYNC command with a specified data. + """ + NAME = "SYNC" + + +class RemovePusherCommand(Command): + """Sent by the client to request the master remove the given pusher. + + Format:: + + REMOVE_PUSHER + """ + NAME = "REMOVE_PUSHER" + + def __init__(self, app_id, push_key, user_id): + self.user_id = user_id + self.app_id = app_id + self.push_key = push_key + + @classmethod + def from_line(cls, line): + app_id, push_key, user_id = line.split(" ", 2) + + return cls(app_id, push_key, user_id) + + def to_line(self): + return " ".join((self.app_id, self.push_key, self.user_id)) + + +class InvalidateCacheCommand(Command): + """Sent by the client to invalidate an upstream cache. + + THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE + NOT DISASTROUS IF WE DROP ON THE FLOOR. + + Mainly used to invalidate destination retry timing caches. + + Format:: + + INVALIDATE_CACHE + + Where is a json list. + """ + NAME = "INVALIDATE_CACHE" + + def __init__(self, cache_func, keys): + self.cache_func = cache_func + self.keys = keys + + @classmethod + def from_line(cls, line): + cache_func, keys_json = line.split(" ", 1) + + return cls(cache_func, json.loads(keys_json)) + + def to_line(self): + return " ".join((self.cache_func, json.dumps(self.keys))) + + +# Map of command name to command type. +COMMAND_MAP = { + cmd.NAME: cmd + for cmd in ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + SyncCommand, + RemovePusherCommand, + InvalidateCacheCommand, + ) +} + +# The commands the server is allowed to send +VALID_SERVER_COMMANDS = ( + ServerCommand.NAME, + RdataCommand.NAME, + PositionCommand.NAME, + ErrorCommand.NAME, + PingCommand.NAME, + SyncCommand.NAME, +) + +# The commands the client is allowed to send +VALID_CLIENT_COMMANDS = ( + NameCommand.NAME, + ReplicateCommand.NAME, + PingCommand.NAME, + UserSyncCommand.NAME, + FederationAckCommand.NAME, + RemovePusherCommand.NAME, + InvalidateCacheCommand.NAME, + ErrorCommand.NAME, +) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py new file mode 100644 index 000000000000..6525e3dd3336 --- /dev/null +++ b/synapse/replication/tcp/protocol.py @@ -0,0 +1,671 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains the implementation of both the client and server +protocols. + +The basic structure of the protocol is line based, where the initial word of +each line specifies the command. The rest of the line is parsed based on the +command. For example, the `RDATA` command is defined as:: + + RDATA + +(Note that `` may contains spaces, but cannot contain newlines.) + +Blank lines are ignored. + + +# Keep alives + +Both sides are expected to send at least one command every 5s (`PING_TIME`), and +should send a `PING` command if necessary. If either side do not receive a +command within e.g. 15s then the connection should be closed. + +Because the server may be connected to manually using e.g. netcat, the timeouts +aren't enabled until an initial `PING` command is seen. Both the client and +server implementations below send a `PING` command immediately on connection to +ensure the timeouts are enabled. + +This ensures that both sides can quickly realize if the tcp connection has gone +and handle the situation appropriately. + + +# Start up + +When a new connection is made, the server: + * Sends a `SERVER` command, which includes the identity of the server, allowing + the client to detect if its connected to the expected server + * Sends a `PING` command as above, to enable the client to time out connections + promptly. + +The client: + * Sends a `NAME` command, allowing the server to associate a human friendly + name with the connection. This is optional. + * Sends a `PING` as above + * For each stream the client wishes to subscribe to it sends a `REPLICATE` + with the stream_name and token it wants to subscribe from. + * On receipt of a `SERVER` command, checks that the server name matches the + expected server name. + + +# Error handling + +If either side detects an error it can send an `ERROR` command and close the +connection. + +If the client side loses the connection to the server it should reconnect, +following the steps above. + + +# Congestion + +If the server sends messaegs faster than the client can consume them the server +will first buffer a (fairly large) number of commands and then disconnect the +client. This ensure that we don't queue up an unbounded number of commands in +memory and gives us a potential oppurtunity to squawk loudly. When/if the client +recovers it can reconnect to the server and ask for missed messages. + + +# Reliability + +In general the replication stream should be consisdered an unreliable transport +since e.g. commands are not resent if the connection disappears. + +The exception to that are the replication streams, i.e. RDATA commands, since +theses include tokens which can be used to restart the stream on connection +errors. + +The client should keep track of the token in the last RDATA command received +for each stream so that on reconneciton it can start streaming from the correct +place. Note: not all RDATA have valid tokens due to batching. See +`commands.RdataCommand` for more details. + + +# Example + +An example iteraction is shown below. Each line is prefixed with '>' or '<' to +indicate which side is sending, these are *not* included on the wire:: + + * connection established * + > SERVER localhost:8823 + > PING 1490197665618 + < NAME synapse.app.appservice + < PING 1490197665618 + < REPLICATE events 1 + < REPLICATE backfill 1 + < REPLICATE caches 1 + > POSITION events 1 + > POSITION backfill 1 + > POSITION caches 1 + > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] + < PING 1490197675618 + > ERROR server stopping + * connection closed by server * + +The `POSITION` command sent by the server is used to set the clients position +without needing to send data with the `RDATA` command. +""" + +from twisted.internet import defer +from twisted.protocols.basic import LineOnlyReceiver + +from commands import ( + COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, + ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, + NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, +) +from streams import STREAMS_MAP + +from synapse.util.stringutils import random_string + +import logging +import synapse.metrics +import struct +import fcntl + + +metrics = synapse.metrics.get_metrics_for(__name__) + +inbound_commands_counter = metrics.register_counter( + "inbound_commands", labels=["command", "name", "conn_id"], +) +outbound_commands_counter = metrics.register_counter( + "outbound_commands", labels=["command", "name", "conn_id"], +) + + +# A list of all connected protocols. This allows us to send metrics about the +# connections. +connected_connections = [] + + +logger = logging.getLogger(__name__) + + +PING_TIME = 5000 + + +class ConnectionStates(object): + CONNECTING = "connecting" + ESTABLISHED = "established" + PAUSED = "paused" + CLOSED = "closed" + + +class BaseReplicationStreamProtocol(LineOnlyReceiver): + """Base replication protocol shared between client and server. + + Reads lines (ignoring blank ones) and parses them into command classes, + asserting that they are valid for the given direction, i.e. server commands + are only sent by the server. + + On receiving a new command it calls `on_` with the parsed + command. + + It also sends `PING` periodically, and correctly times out remote connections + (if they send a `PING` command) + """ + delimiter = b'\n' + + VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive + VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send + + max_line_buffer = 10000 + + def __init__(self, clock): + self.clock = clock + + self.last_received_command = self.clock.time_msec() + self.last_sent_command = 0 + self.time_we_closed = None # When we requested the connection be closed + + self.received_ping = False # Have we reecived a ping from the other side + + self.state = ConnectionStates.CONNECTING + + self.name = "anon" # The name sent by a client. + self.conn_id = random_string(5) # To dedupe in case of name clashes. + + # List of pending commands to send once we've established the connection + self.pending_commands = [] + + # The LoopingCall for sending pings. + self._send_ping_loop = None + + def connectionMade(self): + logger.info("[%s] Connection established", self.id()) + + self.state = ConnectionStates.ESTABLISHED + + connected_connections.append(self) # Register connection for metrics + + self.transport.registerProducer(self, True) # For the *Producing callbacks + + self._send_pending_commands() + + # Starts sending pings + self._send_ping_loop = self.clock.looping_call(self.send_ping, 5000) + + # Always send the initial PING so that the other side knows that they + # can time us out. + self.send_command(PingCommand(self.clock.time_msec())) + + def send_ping(self): + """Periodically sends a ping and checks if we should close the connection + due to the other side timing out. + """ + now = self.clock.time_msec() + + if self.time_we_closed: + if now - self.time_we_closed > PING_TIME * 3: + logger.info( + "[%s] Failed to close connection gracefully, aborting", self.id() + ) + self.transport.abortConnection() + else: + if now - self.last_sent_command >= PING_TIME: + self.send_command(PingCommand(now)) + + if self.received_ping and now - self.last_received_command > PING_TIME * 3: + logger.info( + "[%s] Connection hasn't received command in %r ms. Closing.", + self.id(), now - self.last_received_command + ) + self.send_error("ping timeout") + + def lineReceived(self, line): + """Called when we've received a line + """ + if line.strip() == "": + # Ignore blank lines + return + + line = line.decode("utf-8") + cmd_name, rest_of_line = line.split(" ", 1) + + if cmd_name not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd_name) + self.send_error("invalid command: %s", cmd_name) + return + + self.last_received_command = self.clock.time_msec() + + inbound_commands_counter.inc(cmd_name, self.name, self.conn_id) + + cmd_cls = COMMAND_MAP[cmd_name] + try: + cmd = cmd_cls.from_line(rest_of_line) + except Exception as e: + logger.exception( + "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line + ) + self.send_error( + "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) + ) + return + + # Now lets try and call on_ function + try: + getattr(self, "on_%s" % (cmd_name,))(cmd) + except Exception: + logger.exception("[%s] Failed to handle line: %r", self.id(), line) + + def close(self): + self.time_we_closed = self.clock.time_msec() + self.transport.loseConnection() + self.on_connection_closed() + + def send_error(self, error_string, *args): + """Send an error to remote and close the connection. + """ + self.send_command(ErrorCommand(error_string % args)) + self.close() + + def send_command(self, cmd, do_buffer=True): + """Send a command if connection has been established. + + Args: + cmd (Command) + do_buffer (bool): Whether to buffer the message or always attempt + to send the command. This is mostly used to send an error + message if we're about to close the connection due our buffers + becoming full. + """ + if self.state == ConnectionStates.CLOSED: + logger.info("[%s] Not sending, connection closed", self.id()) + return + + if do_buffer and self.state != ConnectionStates.ESTABLISHED: + self._queue_command(cmd) + return + + outbound_commands_counter.inc(cmd.NAME, self.name, self.conn_id) + + string = "%s %s" % (cmd.NAME, cmd.to_line(),) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + self.sendLine(string.encode("utf-8")) + + self.last_sent_command = self.clock.time_msec() + + def _queue_command(self, cmd): + """Queue the command until the connection is ready to write to again. + """ + logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + self.pending_commands.append(cmd) + + if len(self.pending_commands) > self.max_line_buffer: + # The other side is failing to keep up and out buffers are becoming + # full, so lets close the connection. + # XXX: should we squawk more loudly? + logger.error("[%s] Remote failed to keep up", self.id()) + self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False) + self.close() + + def _send_pending_commands(self): + """Send any queued commandes + """ + pending = self.pending_commands + self.pending_commands = [] + for cmd in pending: + self.send_command(cmd) + + def on_PING(self, line): + self.received_ping = True + + def on_ERROR(self, cmd): + logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) + + def pauseProducing(self): + """This is called when both the kernel send buffer and the twisted + tcp connection send buffers have become full. + + We don't actually have any control over those sizes, so we buffer some + commands ourselves before knifing the connection due to the remote + failing to keep up. + """ + logger.info("[%s] Pause producing", self.id()) + self.state = ConnectionStates.PAUSED + + def resumeProducing(self): + """The remote has caught up after we started buffering! + """ + logger.info("[%s] Resume producing", self.id()) + self.state = ConnectionStates.ESTABLISHED + self._send_pending_commands() + + def stopProducing(self): + """We're never going to send any more data (normally because either + we or the remote has closed the connection) + """ + logger.info("[%s] Stop producing", self.id()) + self.on_connection_closed() + + def connectionLost(self, reason): + logger.info("[%s] Replication connection closed: %r", self.id(), reason) + + try: + # Remove us from list of connections to be monitored + connected_connections.remove(self) + except ValueError: + pass + + # Stop the looping call sending pings. + if self._send_ping_loop and self._send_ping_loop.running: + self._send_ping_loop.stop() + + self.on_connection_closed() + + def on_connection_closed(self): + logger.info("[%s] Connection was closed", self.id()) + + self.state = ConnectionStates.CLOSED + self.pending_commands = [] + + if self.transport: + self.transport.unregisterProducer() + + def __str__(self): + return "ReplicationConnection" % ( + self.name, self.conn_id, self.addr, + ) + + def id(self): + return "%s-%s" % (self.name, self.conn_id) + + +class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): + VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS + VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS + + def __init__(self, server_name, clock, streamer, addr): + BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + + self.server_name = server_name + self.streamer = streamer + self.addr = addr + + # The streams the client has subscribed to and is up to date with + self.replication_streams = set() + + # The streams the client is currently subscribing to. + self.connecting_streams = set() + + # Map from stream name to list of updates to send once we've finished + # subscribing the client to the stream. + self.pending_rdata = {} + + def connectionMade(self): + self.send_command(ServerCommand(self.server_name)) + BaseReplicationStreamProtocol.connectionMade(self) + self.streamer.new_connection(self) + + def on_NAME(self, cmd): + self.name = cmd.data + + def on_USER_SYNC(self, cmd): + self.streamer.on_user_sync(self.conn_id, cmd.user_id, cmd.is_syncing) + + def on_REPLICATE(self, cmd): + stream_name = cmd.stream_name + token = cmd.token + + if stream_name == "ALL": + # Subscribe to all streams we're publishing to. + for stream in self.streamer.streams_by_name.iterkeys(): + self.subscripe_to_stream(stream, token) + else: + self.subscripe_to_stream(stream_name, token) + + def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) + + def on_REMOVE_PUSHER(self, cmd): + self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + + def onINVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + + @defer.inlineCallbacks + def subscripe_to_stream(self, stream_name, token): + """Subscribe the remote to a streams. + + This invloves checking if they've missed anything and sending those + updates down if they have. During that time new updates for the stream + are queued and sent once we've sent down any missed updates. + """ + self.replication_streams.discard(stream_name) + self.connecting_streams.add(stream_name) + + try: + # Get missing updates + updates, current_token = yield self.streamer.get_stream_updates( + stream_name, token, + ) + + # Send all the missing updates + for update in updates: + token, row = update[0], update[1] + self.send_command(RdataCommand(stream_name, token, row)) + + # Now we can send any updates that came in while we were subscribing + pending_rdata = self.pending_rdata.pop(stream_name, []) + for token, update in pending_rdata: + self.send_command(RdataCommand(stream_name, token, update)) + + # We send a POSITION command to ensure that they have an up to + # date token (especially useful if we didn't send any updates + # above) + self.send_command(PositionCommand(stream_name, current_token)) + + # They're now fully subscribed + self.replication_streams.add(stream_name) + except Exception as e: + logger.exception("[%s] Failed to handle REPLICATE command", self.id()) + self.send_error("failed to handle replicate: %r", e) + finally: + self.connecting_streams.discard(stream_name) + + def stream_update(self, stream_name, token, data): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + if stream_name in self.replication_streams: + # The client is subscribed to the stream + self.send_command(RdataCommand(stream_name, token, data)) + elif stream_name in self.connecting_streams: + # The client is being subscribed to the stream + logger.info("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) + self.pending_rdata.setdefault(stream_name, []).append((token, data)) + else: + # The client isn't subscribed + logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + + def send_sync(self, data): + self.send_command(SyncCommand(data)) + + def on_connection_closed(self): + BaseReplicationStreamProtocol.on_connection_closed(self) + logger.info("[%s] Replication connection closed", self.id()) + self.streamer.lost_connection(self) + + +class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): + VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS + VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS + + def __init__(self, client_name, server_name, clock, handler): + BaseReplicationStreamProtocol.__init__(self, clock) + + self.client_name = client_name + self.server_name = server_name + self.handler = handler + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self.pending_batches = {} + + def connectionMade(self): + self.send_command(NameCommand(self.client_name)) + BaseReplicationStreamProtocol.connectionMade(self) + + # Once we've connected subscribe to the necessary streams + for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): + self.replicate(stream_name, token) + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.handler.get_currently_syncing_users() + for user_id in currently_syncing: + self.send_command(UserSyncCommand(user_id, True)) + + # We've now finished connecting to so inform the client handler + self.handler.update_connection(self) + + def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.transport.abortConnection() + + def on_RDATA(self, cmd): + try: + row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) + except Exception: + logger.exception( + "[%s] Failed to parse RDATA: %r %r", + self.id(), cmd.stream_name, cmd.row + ) + raise + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream. Batch + # until we get an update for the stream with a non None token + self.pending_batches.setdefault(cmd.stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self.pending_batches.pop(cmd.stream_name, []) + rows.append(row) + + self.handler.on_rdata(cmd.stream_name, cmd.token, rows) + + def on_POSITION(self, cmd): + self.handler.on_position(cmd.stream_name, cmd.token) + + def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) + + def replicate(self, stream_name, token): + """Send the subscription request to the server + """ + if stream_name not in STREAMS_MAP: + raise Exception("Invalid stream name %r" % (stream_name,)) + + logger.info( + "[%s] Subscribing to replication stream: %r from %r", + self.id(), stream_name, token + ) + + self.send_command(ReplicateCommand(stream_name, token)) + + def on_connection_closed(self): + BaseReplicationStreamProtocol.on_connection_closed(self) + self.handler.update_connection(None) + + +# The following simply registers metrics for the replication connections + +metrics.register_callback( + "pending_commands", + lambda: { + (p.name, p.conn_id): len(p.pending_commands) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +def transport_buffer_size(protocol): + if protocol.transport: + size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen + return size + return 0 + + +metrics.register_callback( + "transport_send_buffer", + lambda: { + (p.name, p.conn_id): transport_buffer_size(p) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +def transport_kernel_read_buffer_size(protocol, read=True): + SIOCINQ = 0x541B + SIOCOUTQ = 0x5411 + + if protocol.transport: + fileno = protocol.transport.getHandle().fileno() + if read: + op = SIOCINQ + else: + op = SIOCOUTQ + size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + return size + return 0 + + +metrics.register_callback( + "transport_kernel_send_buffer", + lambda: { + (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) + for p in connected_connections + }, + labels=["name", "conn_id"], +) + + +metrics.register_callback( + "transport_kernel_read_buffer", + lambda: { + (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) + for p in connected_connections + }, + labels=["name", "conn_id"], +) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py new file mode 100644 index 000000000000..3e19ddd20939 --- /dev/null +++ b/synapse/replication/tcp/resource.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The server side of the replication stream. +""" + +from twisted.internet import defer, reactor +from twisted.internet.protocol import Factory + +from streams import STREAMS_MAP, FederationStream +from protocol import ServerReplicationStreamProtocol + +from synapse.util.metrics import Measure, measure_func + +import logging +import synapse.metrics + + +metrics = synapse.metrics.get_metrics_for(__name__) +stream_updates_counter = metrics.register_counter( + "stream_updates", labels=["stream_name"] +) +user_sync_counter = metrics.register_counter("user_sync") +federation_ack_counter = metrics.register_counter("federation_ack") +remove_pusher_counter = metrics.register_counter("remove_pusher") +invalidate_cache_counter = metrics.register_counter("invalidate_cache") + +logger = logging.getLogger(__name__) + + +class ReplicationStreamProtocolFactory(Factory): + """Factory for new replication connections. + """ + def __init__(self, hs): + self.streamer = ReplicationStreamer(hs) + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + + def buildProtocol(self, addr): + return ServerReplicationStreamProtocol( + self.server_name, + self.clock, + self.streamer, + addr + ) + + +class ReplicationStreamer(object): + """Handles replication connections. + + This needs to be poked when new replication data may be available. When new + data is available it will propogate to all connected clients. + """ + + def __init__(self, hs): + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + + # Current connections. + self.connections = [] + + metrics.register_callback("total_connections", lambda: len(self.connections)) + + # List of streams that clients can subscribe to. + # We only support federation stream if federation sending hase been + # disabled on the master. + self.streams = [ + stream(hs) for stream in STREAMS_MAP.itervalues() + if stream != FederationStream or not hs.config.send_federation + ] + + self.streams_by_name = {stream.NAME: stream for stream in self.streams} + + metrics.register_callback( + "connections_per_stream", + lambda: { + (stream_name,): len([ + conn for conn in self.connections + if stream_name in conn.replication_streams + ]) + for stream_name in self.streams_by_name + }, + labels=["stream_name"], + ) + + self.federation_sender = None + if not hs.config.send_federation: + self.federation_sender = hs.get_federation_sender() + + # Start listening for updates from the notifier + self.notifier_listener() + + # Keeps track of whether we are currently checking for updates + self.is_looping = False + self.pending_updates = False + + reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + + def on_shutdown(self): + # close all connections on shutdown + for conn in self.connections: + conn.send_error("server shutting down") + + @defer.inlineCallbacks + def notifier_listener(self): + """Sits forever looping on the notifier waiting for new data. + """ + while True: + yield self.notifier.wait_once_for_replication() + logger.debug("Woken up by notifier") + self.on_notifier_poke() + + @defer.inlineCallbacks + def on_notifier_poke(self): + """Checks if there is actually any new data and sends it to the + connections if there are. + """ + if not self.connections: + # Don't bother if nothing is listening + return + + # If we're in the process of checking for new updates, mark that fact + # and return + if self.is_looping: + logger.debug("Noitifier poke loop already running") + self.pending_updates = True + return + + self.pending_updates = True + self.is_looping = True + + try: + # Keep looping while there have been pokes about potential updates. + # This protects against the race where a stream we already checked + # gets an update while we're handling other streams. + while self.pending_updates: + self.pending_updates = False + + with Measure(self.clock, "repl.stream.get_updates"): + # First we tell the streams that they should update their + # current tokens. + for stream in self.streams: + stream.advance_current_token() + + for stream in self.streams: + if stream.last_token == stream.upto_token: + continue + + logger.debug( + "Getting stream: %s: %s -> %s", + stream.NAME, stream.last_token, stream.upto_token + ) + updates, current_token = yield stream.get_updates() + + logger.debug( + "Sending %d updates to %d connections", + len(updates), len(self.connections), + ) + + if updates: + logger.info( + "Streaming: %s -> %s", stream.NAME, updates[-1][0] + ) + stream_updates_counter.inc_by(len(updates), stream.NAME) + + # Some streams return multiple rows with the same stream IDs, + # we need to make sure they get sent out in batches. We do + # this by setting the current token to all but the last of + # a series of updates with the same token to have a None + # token. See RdataCommand for more details. + batched_updates = _batch_updates(updates) + + for conn in self.connections: + for token, row in batched_updates: + try: + conn.stream_update(stream.NAME, token, row) + except Exception: + logger.exception("Failed to replicate") + + logger.debug("No more pending updates, breaking poke loop") + finally: + self.pending_updates = False + self.is_looping = False + + @measure_func("repl.get_stream_updates") + def get_stream_updates(self, stream_name, token): + """For a given stream get all updates since token. This is called when + a client first subscribes to a stream. + """ + stream = self.streams_by_name.get(stream_name, None) + if not stream: + raise Exception("unknown stream %s", stream_name) + + return stream.get_updates_since(token) + + @measure_func("repl.federation_ack") + def federation_ack(self, token): + """We've received an ack for federation stream from a client. + """ + federation_ack_counter.inc() + if self.federation_sender: + self.federation_sender.federation_ack(token) + + @measure_func("repl.on_user_sync") + def on_user_sync(self, conn_id, user_id, is_syncing): + """A client has started/stopped syncing on a worker. + """ + user_sync_counter.inc() + self.presence_handler.update_external_syncs_row( + conn_id, user_id, is_syncing + ) + + @measure_func("repl.on_remove_pusher") + @defer.inlineCallbacks + def on_remove_pusher(self, app_id, push_key, user_id): + """A client has asked us to remove a pusher + """ + remove_pusher_counter.inc() + yield self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id=app_id, pushkey=push_key, user_id=user_id + ) + + self.notifier.on_new_replication_data() + + @measure_func("repl.on_invalidate_cache") + def on_invalidate_cache(self, cache_func, keys): + """The client has asked us to invalidate a cache + """ + invalidate_cache_counter.inc() + getattr(self.store, cache_func).invalidate(tuple(keys)) + + def send_sync_to_all_connections(self, data): + """Sends a SYNC command to all clients. + + Used in tests. + """ + for conn in self.connections: + conn.send_sync(data) + + def new_connection(self, connection): + """A new client connection has been established + """ + self.connections.append(connection) + + def lost_connection(self, connection): + """A client connection has been lost + """ + try: + self.connections.remove(connection) + except ValueError: + pass + + # We need to tell the presence handler that the connection has been + # lost so that it can handle any ongoing syncs on that connection. + self.presence_handler.update_external_syncs_clear(connection.conn_id) + + +def _batch_updates(updates): + """Takes a list of updates of form [(token, row)] and sets the token to + None for all rows where the next row has the same token. This is used to + implement batching. + + For example: + + [(1, _), (1, _), (2, _), (3, _), (3, _)] + + becomes: + + [(None, _), (1, _), (2, _), (None, _), (3, _)] + """ + if not updates: + return [] + + new_updates = [] + for i, update in enumerate(updates[:-1]): + if update[0] == updates[i + 1][0]: + new_updates.append((None, update[1])) + else: + new_updates.append(update) + + new_updates.append(updates[-1]) + return new_updates diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py new file mode 100644 index 000000000000..b180a666f7aa --- /dev/null +++ b/synapse/replication/tcp/streams.py @@ -0,0 +1,393 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines all the valid streams that clients can subscribe to, and the format +of the rows returned by each stream. + +Each stream is defined by the following information: + + stream name: The name of the stream + row type: The type that is used to serialise/deserialse the row + current_token: The function that returns the current token for the stream + update_function: The function that returns a list of updates between two tokens +""" + +from twisted.internet import defer +from collections import namedtuple + +import logging + + +logger = logging.getLogger(__name__) + + +MAX_EVENTS_BEHIND = 10000 + + +EventStreamRow = namedtuple("EventStreamRow", + ("event_id", "room_id", "type", "state_key", "redacts")) +BackfillStreamRow = namedtuple("BackfillStreamRow", + ("event_id", "room_id", "type", "state_key", "redacts")) +PresenceStreamRow = namedtuple("PresenceStreamRow", + ("user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active")) +TypingStreamRow = namedtuple("TypingStreamRow", + ("room_id", "user_ids")) +ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", + ("room_id", "receipt_type", "user_id", "event_id", + "data")) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) +PushersStreamRow = namedtuple("PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted",)) +CachesStreamRow = namedtuple("CachesStreamRow", + ("cache_func", "keys", "invalidation_ts",)) +PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", + ("room_id", "visibility", "appservice_id", + "network_id",)) +DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ("user_id", "destination",)) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) +FederationStreamRow = namedtuple("FederationStreamRow", ("type", "data",)) +TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", + ("user_id", "room_id", "data")) +AccountDataStreamRow = namedtuple("AccountDataStream", + ("user_id", "room_id", "data_type", "data")) + + +class Stream(object): + """Base class for the streams. + + Provides a `get_updates()` function that returns new updates since the last + time it was called up until the point `advance_current_token` was called. + """ + NAME = None # The name of the string + ROW_TYPE = None # The type of the row + _LIMITED = True # Whether the update funciton takes a limit + + def __init__(self, hs): + # The token from which we last asked for updates + self.last_token = self.current_token() + + # The token that we will get updates up to + self.upto_token = self.current_token() + + def advance_current_token(self): + """Updates `upto_token` to "now", which updates up until which point + get_updates[_since] will fetch rows till. + """ + self.upto_token = self.current_token() + + @defer.inlineCallbacks + def get_updates(self): + """Gets all updates since the last time this function was called (or + since the stream was constructed if it hadn't been called before), + until the `upto_token` + """ + updates, current_token = yield self.get_updates_since(self.last_token) + self.last_token = current_token + + defer.returnValue((updates, current_token)) + + @defer.inlineCallbacks + def get_updates_since(self, from_token): + """Like get_updates except allows specifying from when we should + stream updates + """ + if from_token in ("NOW", "now"): + defer.returnValue(([], self.upto_token)) + + current_token = self.upto_token + + from_token = int(from_token) + + if from_token == current_token: + defer.returnValue(([], current_token)) + + if self._LIMITED: + rows = yield self.update_function( + from_token, current_token, + limit=MAX_EVENTS_BEHIND + 1, + ) + + if len(rows) >= MAX_EVENTS_BEHIND: + raise Exception("stream %s has fallen behined" % (self.NAME)) + else: + rows = yield self.update_function( + from_token, current_token, + ) + + updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + + defer.returnValue((updates, current_token)) + + def current_token(self): + """Gets the current token of the underlying streams. Should be provided + by the sub classes + """ + raise NotImplementedError() + + def update_function(self, from_token, current_token, limit=None): + """Get updates between from_token and to_token. If Stream._LIMITED is + True then limit is provided, otherwise its not. + """ + raise NotImplementedError() + + +class EventsStream(Stream): + """We received a new event, or an event went from being an outlier to not + """ + NAME = "events" + ROW_TYPE = EventStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + self.current_token = store.get_current_events_token + self.update_function = store.get_all_new_forward_event_rows + + super(EventsStream, self).__init__(hs) + + +class BackfillStream(Stream): + """We fetched some old events and either we had never seen that event before + or it went from being an outlier to not. + """ + NAME = "backfill" + ROW_TYPE = BackfillStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + self.current_token = store.get_current_backfill_token + self.update_function = store.get_all_new_backfill_event_rows + + super(BackfillStream, self).__init__(hs) + + +class PresenceStream(Stream): + NAME = "presence" + _LIMITED = False + ROW_TYPE = PresenceStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + presence_handler = hs.get_presence_handler() + + self.current_token = store.get_current_presence_token + self.update_function = presence_handler.get_all_presence_updates + + super(PresenceStream, self).__init__(hs) + + +class TypingStream(Stream): + NAME = "typing" + _LIMITED = False + ROW_TYPE = TypingStreamRow + + def __init__(self, hs): + typing_handler = hs.get_typing_handler() + + self.current_token = typing_handler.get_current_token + self.update_function = typing_handler.get_all_typing_updates + + super(TypingStream, self).__init__(hs) + + +class ReceiptsStream(Stream): + NAME = "receipts" + ROW_TYPE = ReceiptsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_receipt_stream_id + self.update_function = store.get_all_updated_receipts + + super(ReceiptsStream, self).__init__(hs) + + +class PushRulesStream(Stream): + """A user has changed their push rules + """ + NAME = "push_rules" + ROW_TYPE = PushRulesStreamRow + + def __init__(self, hs): + self.store = hs.get_datastore() + super(PushRulesStream, self).__init__(hs) + + def current_token(self): + push_rules_token, _ = self.store.get_push_rules_stream_token() + return push_rules_token + + @defer.inlineCallbacks + def update_function(self, from_token, to_token, limit): + rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + defer.returnValue([(row[0], row[2]) for row in rows]) + + +class PushersStream(Stream): + """A user has added/changed/removed a pusher + """ + NAME = "pushers" + ROW_TYPE = PushersStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_pushers_stream_token + self.update_function = store.get_all_updated_pushers_rows + + super(PushersStream, self).__init__(hs) + + +class CachesStream(Stream): + """A cache was invalidated on the master and no other stream would invalidate + the cache on the workers + """ + NAME = "caches" + ROW_TYPE = CachesStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_cache_stream_token + self.update_function = store.get_all_updated_caches + + super(CachesStream, self).__init__(hs) + + +class PublicRoomsStream(Stream): + """The public rooms list changed + """ + NAME = "public_rooms" + ROW_TYPE = PublicRoomsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_current_public_room_stream_id + self.update_function = store.get_all_new_public_rooms + + super(PublicRoomsStream, self).__init__(hs) + + +class DeviceListsStream(Stream): + """Someone added/changed/removed a device + """ + NAME = "device_lists" + _LIMITED = False + ROW_TYPE = DeviceListsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token + self.update_function = store.get_all_device_list_changes_for_remotes + + super(DeviceListsStream, self).__init__(hs) + + +class ToDeviceStream(Stream): + """New to_device messages for a client + """ + NAME = "to_device" + ROW_TYPE = ToDeviceStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_to_device_stream_token + self.update_function = store.get_all_new_device_messages + + super(ToDeviceStream, self).__init__(hs) + + +class FederationStream(Stream): + """Data to be sent over federation. Only available when master has federation + sending disabled. + """ + NAME = "federation" + ROW_TYPE = FederationStreamRow + + def __init__(self, hs): + federation_sender = hs.get_federation_sender() + + self.current_token = federation_sender.get_current_token + self.update_function = federation_sender.get_replication_rows + + super(FederationStream, self).__init__(hs) + + +class TagAccountDataStream(Stream): + """Someone added/removed a tag for a room + """ + NAME = "tag_account_data" + ROW_TYPE = TagAccountDataStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_max_account_data_stream_id + self.update_function = store.get_all_updated_tags + + super(TagAccountDataStream, self).__init__(hs) + + +class AccountDataStream(Stream): + """Global or per room account data was changed + """ + NAME = "account_data" + ROW_TYPE = AccountDataStreamRow + + def __init__(self, hs): + self.store = hs.get_datastore() + + self.current_token = self.store.get_max_account_data_stream_id + + super(AccountDataStream, self).__init__(hs) + + @defer.inlineCallbacks + def update_function(self, from_token, to_token, limit): + global_results, room_results = yield self.store.get_all_updated_account_data( + from_token, from_token, to_token, limit + ) + + results = list(room_results) + results.extend( + (stream_id, user_id, None, account_data_type, content,) + for stream_id, user_id, account_data_type, content in global_results + ) + + defer.returnValue(results) + + +STREAMS_MAP = { + stream.NAME: stream + for stream in ( + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + ) +} diff --git a/synapse/server.py b/synapse/server.py index c577032041c7..6310152560f6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -132,6 +132,7 @@ def build_DEPENDENCY(self) 'federation_sender', 'receipts_handler', 'macaroon_generator', + 'tcp_replication', ] def __init__(self, hostname, **kwargs): @@ -290,6 +291,9 @@ def build_federation_sender(self): def build_receipts_handler(self): return ReceiptsHandler(self) + def build_tcp_replication(self): + raise NotImplementedError() + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 6beeff8b00a3..bcb8713c0195 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -529,7 +529,7 @@ def get_user_whose_devices_changed(self, from_key): rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row[0] for row in rows)) - def get_all_device_list_changes_for_remotes(self, from_key): + def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. @@ -537,11 +537,11 @@ def get_all_device_list_changes_for_remotes(self, from_key): sql = """ SELECT stream_id, user_id, destination FROM device_lists_stream LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE stream_id > ? + WHERE ? < stream_id AND stream_id <= ? """ return self._execute( "get_all_device_list_changes_for_remotes", None, - sql, from_key, + sql, from_key, to_key ) @defer.inlineCallbacks diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5f4fcd983239..483df8e4b352 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1756,6 +1756,94 @@ def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + return self.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + return self.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8cc9f0353bf6..0c8a9f1aa469 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -135,6 +135,36 @@ def get_all_updated_pushers_txn(txn): "get_all_updated_pushers", get_all_updated_pushers_txn ) + def get_all_updated_pushers_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed(([], [])) + + def get_all_updated_pushers_rows_txn(txn): + sql = ( + "SELECT id, user_name, app_id, pushkey" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + results = [list(row) + [False] for row in txn.fetchall()] + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + results.extend(list(row) + [True] for row in txn.fetchall()) + results.sort() + + return results + return self.runInteraction( + "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn + ) + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index b82868054d37..81063f19a152 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor from tests import unittest from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver -from synapse.replication.resource import ReplicationResource +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.replication.tcp.client import ( + ReplicationClientHandler, ReplicationClientFactory, +) class BaseSlavedStoreTestCase(unittest.TestCase): @@ -33,18 +36,29 @@ def setUp(self): ) self.hs.get_ratelimiter().send_message.return_value = (True, 0) - self.replication = ReplicationResource(self.hs) - self.master_store = self.hs.get_datastore() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 + server_factory = ReplicationStreamProtocolFactory(self.hs) + listener = reactor.listenUNIX("\0xxx", server_factory) + self.addCleanup(listener.stopListening) + self.streamer = server_factory.streamer + + self.replication_handler = ReplicationClientHandler(self.slaved_store) + client_factory = ReplicationClientFactory( + self.hs, "client_name", self.replication_handler + ) + client_connector = reactor.connectUNIX("\0xxx", client_factory) + self.addCleanup(client_factory.stopTrying) + self.addCleanup(client_connector.disconnect) + @defer.inlineCallbacks def replicate(self): - streams = self.slaved_store.stream_positions() - writer = yield self.replication.replicate(streams, 100) - result = writer.finish() - yield self.slaved_store.process_replication(result) + yield self.streamer.on_notifier_poke() + d = self.replication_handler.await_sync("replication_test") + self.streamer.send_sync_to_all_connections("replication_test") + yield d @defer.inlineCallbacks def check(self, method, args, expected_result=None):