diff --git a/CHANGELOG.md b/CHANGELOG.md index ecb01d3d..9452120d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,33 @@ Every entry has a category for which we use the following visual abbreviations: ## Unreleased +- 🎁 The `zmq-app` and `zeek` plugins now use the Unix `select` system call for + improved performance during message passing. The previous approach added a + constant delay to every message and did not scale. The new approach saves at + least that constant delay *per message*. For ZeroMQ publishing we observed a + speedup by a factor of approximately 183 for 100k events. + [#61](https://github.com/tenzir/threatbus/pull/61) + +- 🎁 The `rabbitmq` backbone plugin now uses an asynchronous + [SelectConnection](https://pika.readthedocs.io/en/stable/modules/adapters/select.html) + instead of a blocking one. We measured a speedup by a factor of approximately + 1.2 for 100k events. + [#61](https://github.com/tenzir/threatbus/pull/61) + +- 🎁 Threat Bus now has a controlled shutdown. Pressing Ctrl+C first shuts down + backbone plugins, then app plugins, and lastly Threat Bus itself. + [#61](https://github.com/tenzir/threatbus/pull/61) + +- ⚠️ There is a new base class for implementing plugin threads. Plugin + developers should extend the new `StoppableWorker` class in every plugin. + Threat Bus and all plugins in this repository now implement that class. + [#61](https://github.com/tenzir/threatbus/pull/61) + +- ⚠️ Threat Bus and all plugins now use + [multiprocessing.JoinableQueue](https://docs.python.org/3.8/library/multiprocessing.html#multiprocessing.JoinableQueue) + for message passing. + [#61](https://github.com/tenzir/threatbus/pull/61) + - 🎁 The `zmq-app` plugin now supports synchronous heartbeats. With heartbeats, both Threat Bus and the connected apps can mutually ensure that the connected party is still alive. diff --git a/apps/vast/CHANGELOG.md b/apps/vast/CHANGELOG.md index 81cd4f9e..b001873e 100644 --- a/apps/vast/CHANGELOG.md +++ b/apps/vast/CHANGELOG.md @@ -12,6 +12,13 @@ Every entry has a category for which we use the following visual abbreviations: ## Unreleased +- 🎁 `pyvast-threatbus` now uses asynchronous background tasks to query VAST + concurrently. Prior to this change, VAST queries were executed sequentially. + This boosts performance by up to the number of allowed concurrent background + tasks. Users can control the maximum number of concurrent background tasks + with the new `max-background-tasks` configuration option. + [#61](https://github.com/tenzir/threatbus/pull/61) + - 🎁 The Python app to connect [VAST](https://github.com/tenzir/vast) with Threat Bus is now packaged and published on [PyPI](https://pypi.org/). You can install the package via `pip install pyvast-threatbus`. diff --git a/apps/vast/config.yaml.example b/apps/vast/config.yaml.example index 570f9657..e596a049 100644 --- a/apps/vast/config.yaml.example +++ b/apps/vast/config.yaml.example @@ -10,3 +10,5 @@ unflatten: true transform_context: fever alertify --alert-prefix 'MY PREFIX' --extra-key my-ioc --ioc %ioc # optional.
remove the field if you simply want to report back sightings to Threat Bus sink: STDOUT +# limits the number of concurrent background tasks for querying VAST +max_background_tasks: 100 diff --git a/apps/vast/message_mapping.py b/apps/vast/message_mapping.py index ddd1520e..d7c92d1b 100644 --- a/apps/vast/message_mapping.py +++ b/apps/vast/message_mapping.py @@ -1,4 +1,3 @@ -from datetime import datetime from dateutil import parser as dateutil_parser from ipaddress import ip_address import json diff --git a/apps/vast/pyvast_threatbus.py b/apps/vast/pyvast_threatbus.py index 29dd9c5d..bd7bb7a4 100755 --- a/apps/vast/pyvast_threatbus.py +++ b/apps/vast/pyvast_threatbus.py @@ -24,8 +24,11 @@ logger = logging.getLogger(__name__) matcher_name = None -async_tasks = [] # list of all running async tasks of the app -p2p_topic = None # the p2p topic that was given to the app upon successful subscription +# list of all running async tasks of the bridge +async_tasks = [] +# the p2p topic that was given to the vast-bridge upon successful subscription +p2p_topic = None +max_open_tasks = None def setup_logging(level): @@ -53,6 +56,7 @@ def validate_config(config: confuse.Subview): config["retro_match"].get(bool) config["retro_match_max_events"].get(int) config["unflatten"].get(bool) + config["max_background_tasks"].get(int) # fallback values for the optional arguments config["transform_context"].add(None) @@ -71,13 +75,14 @@ def cancel_async_tasks(): async def start( - cmd: str, + vast_binary: str, vast_endpoint: str, zmq_endpoint: str, snapshot: int, retro_match: bool, retro_match_max_events: int, unflatten: bool, + max_open_files: int, transform_cmd: str = None, sink: str = None, ): @@ -92,12 +97,15 @@ async def start( @param retro_match Boolean flag to use retro-matching over live-matching @param retro_match_max_events Max amount of retro match results @param unflatten Boolean flag to unflatten JSON when received from VAST + @param max_open_files The maximum number of concurrent background tasks for VAST queries.
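+ Enforced via an asyncio.Semaphore that retro_match_vast() acquires for each query.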
@param transform_cmd The command to use to transform Sighting context with @param sink Forward sighting context to this sink (subprocess) instead of reporting back to Threat Bus """ - global logger, async_tasks, p2p_topic - vast = VAST(binary=cmd, endpoint=vast_endpoint, logger=logger) + global logger, async_tasks, p2p_topic, max_open_tasks + # needs to be created inside the same event loop where it is used + max_open_tasks = asyncio.Semaphore(max_open_files) + vast = VAST(binary=vast_binary, endpoint=vast_endpoint, logger=logger) assert await vast.test_connection() is True, "Cannot connect to VAST" logger.debug(f"Calling Threat Bus management endpoint {zmq_endpoint}") @@ -140,7 +148,7 @@ async def start( async_tasks.append( asyncio.create_task( match_intel( - cmd, + vast_binary, vast_endpoint, intel_queue, sightings_queue, @@ -153,7 +161,9 @@ async def start( if not retro_match: async_tasks.append( - asyncio.create_task(live_match_vast(cmd, vast_endpoint, sightings_queue)) + asyncio.create_task( + live_match_vast(vast_binary, vast_endpoint, sightings_queue) + ) ) atexit.register(cancel_async_tasks) @@ -196,8 +206,84 @@ async def receive(pub_endpoint: str, topic: str, intel_queue: asyncio.Queue): await asyncio.sleep(0.05) # free event loop for other tasks +async def retro_match_vast( + vast_binary, + vast_endpoint, + retro_match_max_events, + intel, + sightings_queue, + unflatten, +): + """ + Turns the given intel into a valid VAST query and forwards all query + results (sightings) to the sightings_queue. + @param vast_binary The vast binary command to use with PyVAST + @param vast_endpoint The endpoint of a running vast node ('host:port') + @param retro_match_max_events Max amount of retro match results + @param intel The IoC to query VAST for + @param sightings_queue The queue to put new sightings into + @param unflatten Boolean flag to unflatten JSON when received from VAST + """ + query = to_vast_query(intel) + if not query: + return + global logger, max_open_tasks + async with max_open_tasks: + vast = VAST(binary=vast_binary, endpoint=vast_endpoint, logger=logger) + proc = await vast.export(max_events=retro_match_max_events).json(query).exec() + reported = 0 + while not proc.stdout.at_eof(): + line = (await proc.stdout.readline()).decode().rstrip() + if line: + sighting = query_result_to_threatbus_sighting(line, intel, unflatten) + if not sighting: + logger.error(f"Could not parse VAST query result: {line}") + continue + reported += 1 + await sightings_queue.put(sighting) + logger.debug(f"Retro-matched {reported} sighting(s) for intel: {intel}") + + +async def ingest_vast_ioc(vast_binary, vast_endpoint, intel): + """ + Ingests the given intel as IoC into a VAST matcher. + @param vast_binary The vast binary command to use with PyVAST + @param vast_endpoint The endpoint of a running vast node ('host:port') + @param intel The IoC to ingest + """ + global logger + ioc = to_vast_ioc(intel) + if not ioc: + logger.error(f"Unable to convert Intel to VAST compatible IoC: {intel}") + return + vast = VAST(binary=vast_binary, endpoint=vast_endpoint, logger=logger) + proc = await vast.import_().json(type="intel.indicator").exec(stdin=ioc) + await proc.wait() + logger.debug(f"Ingested intel for live matching: {intel}") + + +async def remove_vast_ioc(vast_binary, vast_endpoint, intel): + """ + Removes the given intel as IoC from a VAST matcher.
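+ Uses the global matcher_name that live_match_vast() sets at startup.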
+ @param vast_binary The vast binary command to use with PyVAST + @param vast_endpoint The endpoint of a running vast node ('host:port') + @param intel The IoC to remove + """ + global logger, matcher_name + intel_type = get_vast_intel_type(intel) + ioc = get_ioc(intel) + if not ioc or not intel_type: + logger.error( + f"Cannot remove intel with missing intel_type or indicator: {intel}" + ) + return + vast = VAST(binary=vast_binary, endpoint=vast_endpoint, logger=logger) + await vast.matcher().ioc_remove(matcher_name, ioc, intel_type).exec() + logger.debug(f"Removed indicator {intel}") + + async def match_intel( - cmd: str, + vast_binary: str, vast_endpoint: str, intel_queue: asyncio.Queue, sightings_queue: asyncio.Queue, @@ -208,7 +294,7 @@ async def match_intel( """ Reads from the intel_queue and matches all IoCs, either via VAST's live-matching or retro-matching. - @param cmd The vast binary command to use with PyVAST + @param vast_binary The vast binary command to use with PyVAST @param vast_endpoint The endpoint of a running vast node ('host:port') @param intel_queue The queue to read new IoCs from @param sightings_queue The queue to put new sightings into @@ -216,8 +302,7 @@ @param retro_match_max_events Max amount of retro match results @param unflatten Boolean flag to unflatten JSON when received from VAST """ - global logger - vast = VAST(binary=cmd, endpoint=vast_endpoint, logger=logger) + global logger while True: msg = await intel_queue.get() try: @@ -232,66 +317,40 @@ continue if intel.operation == Operation.ADD: if retro_match: - query = to_vast_query(intel) - if not query: - continue - proc = ( - await vast.export(max_events=retro_match_max_events) - .json(query) - .exec() + asyncio.create_task( + retro_match_vast( + vast_binary, + vast_endpoint, + retro_match_max_events, + intel, + sightings_queue, + unflatten, + ) ) - reported = 0 - while not proc.stdout.at_eof(): - line = (await proc.stdout.readline()).decode().rstrip() - if line: - sighting = query_result_to_threatbus_sighting( - line, intel, unflatten - ) - if not sighting: - logger.error(f"Could not parse VAST query result: {line}") - continue - reported += 1 - await sightings_queue.put(sighting) - logger.debug(f"Retro-matched {reported} sighting(s) for intel: {intel}") else: - ioc = to_vast_ioc(intel) - if not ioc: - logger.error( - f"Unable to convert Intel to VAST compatible IoC: {intel}" - ) - continue - proc = await vast.import_().json(type="intel.indicator").exec(stdin=ioc) - await proc.wait() - logger.debug(f"Ingested intel for live matching: {intel}") + asyncio.create_task(ingest_vast_ioc(vast_binary, vast_endpoint, intel)) elif intel.operation == Operation.REMOVE: if retro_match: continue - intel_type = get_vast_intel_type(intel) - ioc = get_ioc(intel) - if not ioc or not intel_type: - logger.error( - f"Cannot remove intel with missing intel_type or indicator: {intel}" - ) - continue - global matcher_name - await vast.matcher().ioc_remove(matcher_name, ioc, intel_type).exec() - logger.debug(f"Removed indicator {intel}") + asyncio.create_task(remove_vast_ioc(vast_binary, vast_endpoint, intel)) else: logger.warning(f"Unsupported operation for indicator: {intel}") intel_queue.task_done() -async def live_match_vast(cmd: str, vast_endpoint: str, sightings_queue: asyncio.Queue): +async def live_match_vast( + vast_binary: str, vast_endpoint: str, sightings_queue: asyncio.Queue +): """ Starts a VAST matcher.
Enqueues all matches from VAST to the sightings_queue. - @param cmd The VAST binary command to use with PyVAST + @param vast_binary The VAST binary command to use with PyVAST @param vast_endpoint The endpoint of a running VAST node @param sightings_queue The queue to put new sightings into @param retro_match Boolean flag to use retro-matching over live-matching """ global logger, matcher_name - vast = VAST(binary=cmd, endpoint=vast_endpoint, logger=logger) + vast = VAST(binary=vast_binary, endpoint=vast_endpoint, logger=logger) matcher_name = "threatbus-" + "".join(random.choice(letters) for i in range(10)) proc = await vast.matcher().start(name=matcher_name).exec() while True: @@ -565,6 +624,14 @@ def main(): default=None, help="If sink is specified, sightings are not reported back to Threat Bus. Instead, the context of a sighting (only the contents without the Threat Bus specific sighting structure) is forwarded to the specified sink via a UNIX pipe. This option takes a command line string to use and invokes it as direct subprocess without shell / globbing support.", ) + parser.add_argument( + "--max-background-tasks", + "-U", + dest="max-background-tasks", + default=100, + type=int, + help="Controls the maximum number of concurrent background tasks for VAST queries. Default is 100.", + ) args = parser.parse_args() config = confuse.Configuration("pyvast-threatbus") @@ -589,6 +656,7 @@ def main(): config["retro_match"].get(), config["retro_match_max_events"].get(), config["unflatten"].get(), + config["max_background_tasks"].get(), config["transform_context"].get(), config["sink"].get(), ) diff --git a/apps/vast/test_message_mapping.py b/apps/vast/test_message_mapping.py index 2bcf3ead..487e9bf6 100644 --- a/apps/vast/test_message_mapping.py +++ b/apps/vast/test_message_mapping.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone import unittest import json diff --git a/plugins/apps/threatbus_cif3/message_mapping.py b/plugins/apps/threatbus_cif3/message_mapping.py index 36e4bd04..43c5b44f 100644 --- a/plugins/apps/threatbus_cif3/message_mapping.py +++ b/plugins/apps/threatbus_cif3/message_mapping.py @@ -20,7 +20,8 @@ def map_to_cif(intel: Intel, logger, confidence, tags, tlp, group): - """Maps an Intel item to a CIFv3 compatible indicator format. + """ + Maps an Intel item to a CIFv3 compatible indicator format. 
@param intel The item to map @return the mapped intel item or None """ @@ -49,10 +50,8 @@ def map_to_cif(intel: Intel, logger, confidence, tags, tlp, group): } try: - ii = Indicator(**ii) + return Indicator(**ii) except InvalidIndicator as e: logger.error(f"Invalid CIF indicator {e}") except Exception as e: logger.error(f"CIF indicator error: {e}") - - return ii diff --git a/plugins/apps/threatbus_cif3/plugin.py b/plugins/apps/threatbus_cif3/plugin.py index b3c50553..4ad5090e 100644 --- a/plugins/apps/threatbus_cif3/plugin.py +++ b/plugins/apps/threatbus_cif3/plugin.py @@ -1,16 +1,69 @@ -from queue import Queue -import threading from cifsdk.client.http import HTTP as Client - -from threatbus_cif3.message_mapping import map_to_cif +from confuse import Subview +from multiprocessing import JoinableQueue +from queue import Empty import threatbus +from threatbus_cif3.message_mapping import map_to_cif +from typing import Callable, List """Threatbus - Open Source Threat Intelligence Platform - plugin for CIFv3""" plugin_name = "cif3" +workers: List[threatbus.StoppableWorker] = list() -def validate_config(config): +class CIFPublisher(threatbus.StoppableWorker): + """ + Reports / publishes intel items back to the given CIF endpoint. + """ + + def __init__(self, intel_outq: JoinableQueue, cif: Client, config: Subview): + """ + @param intel_outq Publish all intel from this queue to CIF + @param cif The CIF client to use + @param config The plugin config + """ + super(CIFPublisher, self).__init__() + self.intel_outq = intel_outq + self.cif = cif + self.config = config + + def run(self): + global logger + if not self.cif: + logger.error("CIF is not properly configured. Exiting.") + return + confidence = self.config["confidence"].as_number() + if not confidence: + confidence = 5 + tags = self.config["tags"].get(list) + tlp = self.config["tlp"].get(str) + group = self.config["group"].get(str) + + while self._running(): + try: + intel = self.intel_outq.get(block=True, timeout=1) + except Empty: + continue + if not intel: + logger.warning("Received unparsable intel item") + self.intel_outq.task_done() + continue + cif_mapped_intel = map_to_cif(intel, logger, confidence, tags, tlp, group) + if not cif_mapped_intel: + self.intel_outq.task_done() + continue + try: + logger.debug(f"Adding intel to CIF: {cif_mapped_intel}") + self.cif.indicators_create(cif_mapped_intel) + except Exception as err: + logger.error(f"CIF submission error: {err}") + finally: + self.intel_outq.task_done() + + +def validate_config(config: Subview): assert config, "config must not be None" config["tags"].get(list) config["tlp"].get(str) @@ -22,43 +75,15 @@ def validate_config(config): config["api"]["token"].get(str) -def receive_intel_from_backbone(watched_queue, cif, config): - """ - Reports / publishes intel items back to the given CIF endpoint. - @param watched_queue The py queue from which to read messages to submit on to CIF - """ - global logger - if not cif: - logger.error("CIF is not properly configured.
Exiting.") - return - - confidence = config["confidence"].as_number() - if not confidence: - confidence = 5 - - tags = config["tags"].get(list) - tlp = config["tlp"].get(str) - group = config["group"].get(str) - - while True: - intel = watched_queue.get() - if not intel: - logger.warning("Received unparsable intel item") - continue - cif_mapped_intel = map_to_cif(intel, logger, confidence, tags, tlp, group) - if not cif_mapped_intel: - logger.warning("Could not map intel item") - continue - try: - logger.debug(f"Adding intel to CIF: {cif_mapped_intel}") - cif.indicators_create(cif_mapped_intel) - except Exception as err: - logger.error(f"CIF submission error: {err}") - - @threatbus.app -def run(config, logging, inq, subscribe_callback, unsubscribe_callback): - global logger +def run( + config: Subview, + logging: Subview, + inq: JoinableQueue, + subscribe_callback: Callable, + unsubscribe_callback: Callable, +): + global logger, workers logger = threatbus.logger.setup(logging, __name__) config = config[plugin_name] try: @@ -81,14 +106,20 @@ def run(config, logging, inq, subscribe_callback, unsubscribe_callback): ) return - from_backbone_to_cifq = Queue() + intel_outq = JoinableQueue() topic = "threatbus/intel" - subscribe_callback(topic, from_backbone_to_cifq) + subscribe_callback(topic, intel_outq) - threading.Thread( - target=receive_intel_from_backbone, - args=[from_backbone_to_cifq, cif, config], - daemon=True, - ).start() + workers.append(CIFPublisher(intel_outq, cif, config)) + for w in workers: + w.start() logger.info("CIF3 plugin started") + + +@threatbus.app +def stop(): + global logger, workers + for w in workers: + w.join() + logger.info("CIF3 plugin stopped") diff --git a/plugins/apps/threatbus_misp/plugin.py b/plugins/apps/threatbus_misp/plugin.py index cc84721e..1df1b36a 100644 --- a/plugins/apps/threatbus_misp/plugin.py +++ b/plugins/apps/threatbus_misp/plugin.py @@ -3,8 +3,9 @@ from datetime import datetime from itertools import product import json +from multiprocessing import JoinableQueue import pymisp -from queue import Queue +from queue import Empty import threading import threatbus from threatbus.data import MessageType, SnapshotEnvelope, SnapshotRequest @@ -22,9 +23,122 @@ plugin_name: str = "misp" misp: pymisp.api.PyMISP = None lock: threading.Lock = threading.Lock() -filter_config: List[ - Dict -] = None # required for message mapping, not available when Threat Bus invokes `snapshot()` -> global, initialized on startup +# filter_config is required for message mapping, but not available when Threat Bus invokes `snapshot()` -> global, initialized on startup +filter_config: List[Dict] = None +workers: List[threatbus.StoppableWorker] = list() + + +class SightingsPublisher(threatbus.StoppableWorker): + """ + Reports / publishes true-positive sightings of intelligence items back to the given MISP endpoint. + """ + + def __init__(self, outq: JoinableQueue): + """ + @param outq The queue from which to forward messages to MISP + """ + super(SightingsPublisher, self).__init__() + self.outq = outq + + def run(self): + global logger, misp, lock + if not misp: + return + while self._running(): + try: + sighting = self.outq.get(block=True, timeout=1) + except Empty: + continue + logger.debug(f"Reporting sighting: {sighting}") + misp_sighting = map_to_misp(sighting) + lock.acquire() + misp.add_sighting(misp_sighting) + lock.release() + self.outq.task_done() + + +class KafkaReceiver(threatbus.StoppableWorker): + """ + Binds a Kafka consumer to the the given host/port. 
Forwards all received messages to the inq. + """ + + def __init__(self, kafka_config: Subview, inq: JoinableQueue): + """ + @param kafka_config A configuration object for Kafka binding + @param inq The queue to which intel items from MISP are forwarded to + """ + super(KafkaReceiver, self).__init__() + self.kafka_config = kafka_config + self.inq = inq + + def run(self): + consumer = Consumer(self.kafka_config["config"].get(dict)) + consumer.subscribe(self.kafka_config["topics"].get(list)) + global logger, filter_config + while self._running(): + message = consumer.poll( + timeout=self.kafka_config["poll_interval"].get(float) + ) + if message is None: + continue + if message.error(): + logger.error(f"Kafka error: {message.error()}") + continue + try: + msg = json.loads(message.value()) + except Exception as e: + logger.error(f"Error decoding Kafka message: {e}") + continue + if not is_whitelisted(msg, filter_config): + continue + intel = map_to_internal(msg["Attribute"], msg.get("action", None), logger) + if not intel: + logger.debug(f"Discarding unparsable intel {msg['Attribute']}") + else: + self.inq.put(intel) + + +class ZmqReceiver(threatbus.StoppableWorker): + """ + Binds a ZMQ poller to the given host/port. Forwards all received messages to the inq. + """ + + def __init__(self, zmq_config: Subview, inq: JoinableQueue): + """ + @param zmq_config A configuration object for ZeroMQ binding + @param inq The queue to which intel items from MISP are forwarded to + """ + super(ZmqReceiver, self).__init__() + self.inq = inq + self.zmq_config = zmq_config + + def run(self): + global logger, filter_config + socket = zmq.Context().socket(zmq.SUB) + socket.connect(f"tcp://{self.zmq_config['host']}:{self.zmq_config['port']}") + # TODO: allow reception of more topics, i.e. handle events. + socket.setsockopt(zmq.SUBSCRIBE, b"misp_json_attribute") + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + + while self._running(): + socks = dict(poller.poll(timeout=1000)) + if socket not in socks or socks[socket] != zmq.POLLIN: + continue + raw = socket.recv() + _, message = raw.decode("utf-8").split(" ", 1) + try: + msg = json.loads(message) + except Exception as e: + logger.error(f"Error decoding message {message}: {e}") + continue + if not is_whitelisted(msg, filter_config): + continue + intel = map_to_internal(msg["Attribute"], msg.get("action", None), logger) + if not intel: + logger.debug(f"Discarding unparsable intel {msg['Attribute']}") + else: + self.inq.put(intel) def validate_config(config: Subview): @@ -55,89 +169,8 @@ config["kafka"]["config"].get(dict) -def publish_sightings(outq: Queue): - """ - Reports / publishes true-positive sightings of intelligence items back to the given MISP endpoint. - @param outq The queue from which to forward messages to MISP - """ - global logger, misp, lock - if not misp: - return - while True: - sighting = outq.get(block=True) - logger.debug(f"Reporting sighting: {sighting}") - misp_sighting = map_to_misp(sighting) - lock.acquire() - misp.add_sighting(misp_sighting) - lock.release() - outq.task_done() - - -def receive_kafka(kafka_config: Subview, inq: Queue): - """ - Binds a Kafka consumer to the the given host/port. Forwards all received messages to the inq.
- @param kafka_config A configuration object for Kafka binding - @param inq The queue to which intel items from MISP are forwarded to - """ - consumer = Consumer(kafka_config["config"].get(dict)) - consumer.subscribe(kafka_config["topics"].get(list)) - global logger, filter_config - while True: - message = consumer.poll(timeout=kafka_config["poll_interval"].get(float)) - if message is None: - continue - if message.error(): - logger.error(f"Kafka error: {message.error()}") - continue - try: - msg = json.loads(message.value()) - except Exception as e: - logger.error(f"Error decoding Kafka message: {e}") - continue - if not is_whitelisted(msg, filter_config): - continue - intel = map_to_internal(msg["Attribute"], msg.get("action", None), logger) - if not intel: - logger.debug(f"Discarding unparsable intel {msg['Attribute']}") - else: - inq.put(intel) - - -def receive_zmq(zmq_config: Subview, inq: Queue): - """ - Binds a ZMQ poller to the the given host/port. Forwards all received messages to the inq. - @param zmq_config A configuration object for ZeroMQ binding - @param inq The queue to which intel items from MISP are forwarded to - """ - global logger, filter_config - socket = zmq.Context().socket(zmq.SUB) - socket.connect(f"tcp://{zmq_config['host']}:{zmq_config['port']}") - # TODO: allow reception of more topics, i.e. handle events. - socket.setsockopt(zmq.SUBSCRIBE, b"misp_json_attribute") - poller = zmq.Poller() - poller.register(socket, zmq.POLLIN) - - while True: - socks = dict(poller.poll(timeout=None)) - if socket in socks and socks[socket] == zmq.POLLIN: - raw = socket.recv() - _, message = raw.decode("utf-8").split(" ", 1) - try: - msg = json.loads(message) - except Exception as e: - logger.error(f"Error decoding message {message}: {e}") - continue - if not is_whitelisted(msg, filter_config): - continue - intel = map_to_internal(msg["Attribute"], msg.get("action", None), logger) - if not intel: - logger.debug(f"Discarding unparsable intel {msg['Attribute']}") - else: - inq.put(intel) - - @threatbus.app -def snapshot(snapshot_request: SnapshotRequest, result_q: Queue): +def snapshot(snapshot_request: SnapshotRequest, result_q: JoinableQueue): global logger, misp, lock, filter_config if snapshot_request.snapshot_type != MessageType.INTEL: logger.debug("Sighting snapshot feature not yet implemented.") @@ -190,11 +223,11 @@ def snapshot(snapshot_request: SnapshotRequest, result_q: Queue): def run( config: Subview, logging: Subview, - inq: Queue, + inq: JoinableQueue, subscribe_callback: Callable, unsubscribe_callback: Callable, ): - global logger, filter_config + global logger, filter_config, workers logger = threatbus.logger.setup(logging, __name__) config = config[plugin_name] try: @@ -205,15 +238,10 @@ def run( filter_config = config["filter"].get(list) # start Attribute-update receiver - receiver_thread = None if config["zmq"].get(): - receiver_thread = threading.Thread( - target=receive_zmq, args=(config["zmq"], inq), daemon=True - ) + workers.append(ZmqReceiver(config["zmq"], inq)) elif config["kafka"].get(): - receiver_thread = threading.Thread( - target=receive_kafka, args=(config["kafka"], inq), daemon=True - ) + workers.append(KafkaReceiver(config["kafka"], inq)) # bind to MISP if config["api"].get(dict): @@ -242,9 +270,17 @@ def run( "Starting MISP plugin without API connection, cannot report back sightings or request snapshots." 
) - outq = Queue() + outq = JoinableQueue() subscribe_callback("threatbus/sighting", outq) - threading.Thread(target=publish_sightings, args=(outq,), daemon=True).start() - if receiver_thread is not None: - receiver_thread.start() + workers.append(SightingsPublisher(outq)) + for w in workers: + w.start() logger.info("MISP plugin started") + + +@threatbus.app +def stop(): + global logger, workers + for w in workers: + w.join() + logger.info("MISP plugin stopped") diff --git a/plugins/apps/threatbus_zeek/plugin.py b/plugins/apps/threatbus_zeek/plugin.py index 6d5519be..9f4564e9 100644 --- a/plugins/apps/threatbus_zeek/plugin.py +++ b/plugins/apps/threatbus_zeek/plugin.py @@ -1,5 +1,6 @@ import broker -from queue import Queue +from confuse import Subview +from multiprocessing import JoinableQueue import random import select import string @@ -11,129 +12,180 @@ map_to_internal, map_management_message, ) -import time +from typing import Callable, Dict, List, Union """Zeek network monitor - plugin for Threat Bus""" plugin_name = "zeek" lock = threading.Lock() -subscriptions = dict() - - -def validate_config(config): - assert config, "config must not be None" - config["host"].get(str) - config["port"].get(int) - config["module_namespace"].get(str) - - -def rand_string(length): - """Generates a pseudo-random string with the requested length""" - letters = string.ascii_lowercase - return "".join(random.choice(letters) for i in range(length)) - - -def manage_subscription( - ep, module_namespace, task, subscribe_callback, unsubscribe_callback -): - global lock, subscriptions - rand_suffix_length = 10 - if type(task) is Subscription: - # point-to-point topic and queue for that particular subscription - logger.info(f"Received subscription for topic: {task.topic}") - p2p_topic = task.topic + rand_string(rand_suffix_length) - p2p_q = Queue() - ack = broker.zeek.Event( - f"{module_namespace}::subscription_acknowledged", p2p_topic - ) - ep.publish("threatbus/manage", ack) - subscribe_callback(task.topic, p2p_q, task.snapshot) - lock.acquire() - subscriptions[p2p_topic] = p2p_q - lock.release() - elif type(task) is Unsubscription: - logger.info(f"Received unsubscription from topic: {task.topic}") - threatbus_topic = task.topic[: len(task.topic) - rand_suffix_length] - p2p_q = subscriptions.get(task.topic, None) - if p2p_q: - unsubscribe_callback(threatbus_topic, p2p_q) +subscriptions: Dict[str, JoinableQueue] = dict() # p2p_topic => queue +workers: List[threatbus.StoppableWorker] = list() + + +class SubscriptionManager(threatbus.StoppableWorker): + def __init__( + self, + module_namespace: str, + ep: broker.Endpoint, + subscribe_callback: Callable, + unsubscribe_callback: Callable, + ): + """ + @param module_namespace A Zeek namespace to accept events from + @param ep The broker endpoint used for listening + @param subscribe_callback The callback to invoke for new subscriptions + @param unsubscribe_callback The callback to invoke for revoked subscriptions + """ + super(SubscriptionManager, self).__init__() + self.ep = ep + self.module_namespace = module_namespace + self.subscribe_callback = subscribe_callback + self.unsubscribe_callback = unsubscribe_callback + self.rand_suffix_length = 10 + + def run(self): + """ + Binds a broker subscriber to the given endpoint. Only listens for management + messages, such as un/subscriptions of new clients. 
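+        Waits on the subscriber's file descriptor via select() with a 1 second timeout, so the loop can re-check self._running() and shut down cleanly.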
+ """ + global logger + sub = self.ep.make_subscriber("threatbus/manage") + while self._running(): + (ready_readers, [], []) = select.select([sub.fd()], [], [], 1) + if not ready_readers: + continue + (topic, broker_data) = sub.get() + msg = map_management_message(broker_data, self.module_namespace) + if msg: + self.manage_subscription(msg) + + def rand_string(self, length): + """ + Generates a pseudo-random string with the requested length + """ + letters = string.ascii_lowercase + return "".join(random.choice(letters) for i in range(length)) + + def manage_subscription(self, task: Union[Subscription, Unsubscription]): + global lock, subscriptions + if type(task) is Subscription: + # point-to-point topic and queue for that particular subscription + logger.info(f"Received subscription for topic: {task.topic}") + p2p_topic = task.topic + self.rand_string(self.rand_suffix_length) + p2p_q = JoinableQueue() + ack = broker.zeek.Event( + f"{self.module_namespace}::subscription_acknowledged", p2p_topic + ) + self.ep.publish("threatbus/manage", ack) + self.subscribe_callback(task.topic, p2p_q, task.snapshot) lock.acquire() - del subscriptions[task.topic] + subscriptions[p2p_topic] = p2p_q lock.release() - else: - logger.debug(f"Skipping unknown management message of type: {type(task)}") - - -def publish(module_namespace, ep): + elif type(task) is Unsubscription: + logger.info(f"Received unsubscription from topic: {task.topic}") + threatbus_topic = task.topic[: len(task.topic) - self.rand_suffix_length] + p2p_q = subscriptions.get(task.topic, None) + if p2p_q: + self.unsubscribe_callback(threatbus_topic, p2p_q) + lock.acquire() + del subscriptions[task.topic] + lock.release() + else: + logger.debug(f"Skipping unknown management message of type: {type(task)}") + + +class BrokerPublisher(threatbus.StoppableWorker): """ Publishes messages for all subscriptions in a round-robin fashion to via broker. - @param module_namespace A Zeek namespace to use for event sending - @param ep The broker endpoint used for publishing """ - global subscriptions, lock, logger - while True: - lock.acquire() - for topic, q in subscriptions.items(): - if q.empty(): - continue - msg = q.get() - if not msg: - continue - event = map_to_broker(msg, module_namespace) - if event: - ep.publish(topic, event) - logger.debug(f"Published {msg} on topic {topic}") - q.task_done() - lock.release() - time.sleep(0.05) - - -def manage(module_namespace, ep, subscribe_callback, unsubscribe_callback): - """Binds a broker subscriber to the given endpoint. Only listens for - management messages, such as un/subscriptions of new clients. 
- @param module_namespace A Zeek namespace to accept events from - @param ep The broker endpoint used for listening - @param subscribe_callback The callback to invoke for new subscriptions - @param unsubscribe_callback The callback to invoke for revoked subscriptions + + def __init__(self, module_namespace: str, ep: broker.Endpoint): + """ + @param module_namespace A Zeek namespace to use for event sending + @param ep The broker endpoint used for publishing + """ + super(BrokerPublisher, self).__init__() + self.module_namespace = module_namespace + self.ep = ep + + def run(self): + global subscriptions, lock, logger + while self._running(): + lock.acquire() + # subscriptions is a dict with p2p_topic => queue + # qt_lookup is a dict with queue-reader => (p2p_topic, queue) + qt_lookup = { + sub[1]._reader: (sub[0], sub[1]) for sub in subscriptions.items() + } + readers = [q._reader for q in subscriptions.values()] + lock.release() + (ready_readers, [], []) = select.select(readers, [], [], 1) + for fd in ready_readers: + topic, q = qt_lookup[fd] + if q.empty(): + continue + msg = q.get() + if not msg: + q.task_done() + continue + try: + event = map_to_broker(msg, self.module_namespace) + if event: + self.ep.publish(topic, event) + logger.debug(f"Published {msg} on topic {topic}") + except Exception as e: + logger.error(f"Error publishing message to broker {msg}: {e}") + finally: + q.task_done() + + +class BrokerReceiver(threatbus.StoppableWorker): + """ + Binds a broker subscriber to the given endpoint. Forwards all received intel + and sightings to the inq. """ - global logger - sub = ep.make_subscriber("threatbus/manage") - while True: - ready = select.select([sub.fd()], [], []) - if not ready[0]: - logger.critical("Broker management subscriber filedescriptor error.") - (topic, broker_data) = sub.get() - msg = map_management_message(broker_data, module_namespace) - if msg: - manage_subscription( - ep, module_namespace, msg, subscribe_callback, unsubscribe_callback - ) + def __init__(self, module_namespace: str, ep: broker.Endpoint, inq: JoinableQueue): + """ + @param module_namespace A Zeek namespace to accept events from + @param ep The broker endpoint used for listening + @param inq The queue to forward messages to + """ + super(BrokerReceiver, self).__init__() + self.module_namespace = module_namespace + self.ep = ep + self.inq = inq + + def run(self): + sub = self.ep.make_subscriber(["threatbus/intel", "threatbus/sighting"]) + global logger + while self._running(): + (ready_readers, [], []) = select.select([sub.fd()], [], [], 1) + if not ready_readers: + continue + (topic, broker_data) = sub.get() + msg = map_to_internal(broker_data, self.module_namespace) + if msg: + self.inq.put(msg) -def listen(module_namespace, ep, inq): - """Binds a broker subscriber to the given endpoint. Forwards all received - intel and sightings to the inq. 
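+    Waits on the subscriber's file descriptor via select() with a timeout, so the worker can stop gracefully.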
- @param logger A logging.logger object - @param module_namespace A Zeek namespace to accept events from - @param ep The broker endpoint used for listening - @param inq The queue to forward messages to - """ - sub = ep.make_subscriber(["threatbus/intel", "threatbus/sighting"]) - global logger - while True: - ready = select.select([sub.fd()], [], []) - if not ready[0]: - logger.critical("Broker intel/sightings subscriber filedescriptor error.") - (topic, broker_data) = sub.get() - msg = map_to_internal(broker_data, module_namespace) - if msg: - inq.put(msg) + +def validate_config(config): + assert config, "config must not be None" + config["host"].get(str) + config["port"].get(int) + config["module_namespace"].get(str) @threatbus.app -def run(config, logging, inq, subscribe_callback, unsubscribe_callback): - global logger +def run( + config: Subview, + logging: Subview, + inq: JoinableQueue, + subscribe_callback: Callable, + unsubscribe_callback: Callable, +): + global logger, workers logger = threatbus.logger.setup(logging, __name__) config = config[plugin_name] try: @@ -150,11 +202,19 @@ def run(config, logging, inq, subscribe_callback, unsubscribe_callback): ep = broker.Endpoint(broker.Configuration(broker_opts)) ep.listen(host, port) - threading.Thread(target=listen, args=(namespace, ep, inq), daemon=True).start() - threading.Thread( - target=manage, - args=(namespace, ep, subscribe_callback, unsubscribe_callback), - daemon=True, - ).start() - threading.Thread(target=publish, args=(namespace, ep), daemon=True).start() + workers.append( + SubscriptionManager(namespace, ep, subscribe_callback, unsubscribe_callback) + ) + workers.append(BrokerReceiver(namespace, ep, inq)) + workers.append(BrokerPublisher(namespace, ep)) + for w in workers: + w.start() logger.info("Zeek plugin started") + + +@threatbus.app +def stop(): + global logger, workers + for w in workers: + w.join() + logger.info("Zeek plugin stopped") diff --git a/plugins/apps/threatbus_zmq_app/plugin.py b/plugins/apps/threatbus_zmq_app/plugin.py index 2960782c..2466f6c8 100644 --- a/plugins/apps/threatbus_zmq_app/plugin.py +++ b/plugins/apps/threatbus_zmq_app/plugin.py @@ -1,7 +1,8 @@ from confuse import Subview import json -from queue import Queue +from multiprocessing import JoinableQueue import random +import select import string import threading from threatbus_zmq_app.message_mapping import Heartbeat, map_management_message @@ -20,8 +21,7 @@ Subscription, Unsubscription, ) -import time -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, List, Tuple import zmq @@ -32,171 +32,206 @@ plugin_name = "zmq-app" subscriptions_lock = threading.Lock() -subscriptions: Dict[str, Tuple[str, Queue]] = dict() # p2p_topic => (topic, queue) +# subscriptions: p2p_topic => (topic, queue) +subscriptions: Dict[str, Tuple[str, JoinableQueue]] = dict() snapshots_lock = threading.Lock() snapshots: Dict[str, str] = dict() # snapshot_id => topic p2p_topic_prefix_length = 32 # length of random topic prefix +workers: List[threatbus.StoppableWorker] = list() -def validate_config(config: Subview): - assert config, "config must not be None" - config["host"].get(str) - config["manage"].get(int) - config["pub"].get(int) - config["sub"].get(int) +class SubscriptionManager(threatbus.StoppableWorker): + """ + Management endpoint to handle (un)subscriptions of apps. 
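+    Replies on a ZeroMQ REP socket, acknowledging each subscription with a unique point-to-point topic.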
+ """ + def __init__( + self, + zmq_config: Subview, + subscribe_callback: Callable, + unsubscribe_callback: Callable, + ): + """ + @param zmq_config Config object for the ZeroMQ endpoints + @param subscribe_callback Callback from Threat Bus to unsubscribe new apps + @param unsubscribe_callback Callback from Threat Bus to unsubscribe apps + """ + super(SubscriptionManager, self).__init__() + self.zmq_config = zmq_config + self.subscribe_callback = subscribe_callback + self.unsubscribe_callback = unsubscribe_callback -def rand_string(length: int): - """Generates a pseudo-random string with the requested length""" - letters = string.ascii_lowercase - return "".join(random.choice(letters) for i in range(length)) + def run(self): + global logger, subscriptions_lock, subscriptions, snapshots_lock, snapshots + context = zmq.Context() + socket = context.socket(zmq.REP) # REP socket for point-to-point reply + socket.bind(f"tcp://{self.zmq_config['host']}:{self.zmq_config['manage']}") + pub_endpoint = f"{self.zmq_config['host']}:{self.zmq_config['pub']}" + sub_endpoint = f"{self.zmq_config['host']}:{self.zmq_config['sub']}" -def receive_management( - zmq_config: Subview, subscribe_callback: Callable, unsubscribe_callback: Callable -): - """ - Management endpoint to handle (un)subscriptions of apps. - @param zmq_config Config object for the ZeroMQ endpoints - @param subscribe_callback Callback from Threat Bus to unsubscribe new apps - @param unsubscribe_callback Callback from Threat Bus to unsubscribe apps - """ - global logger, subscriptions_lock, subscriptions, snapshots_lock, snapshots - - context = zmq.Context() - socket = context.socket(zmq.REP) # REP socket for point-to-point reply - socket.bind(f"tcp://{zmq_config['host']}:{zmq_config['manage']}") - pub_endpoint = f"{zmq_config['host']}:{zmq_config['pub']}" - sub_endpoint = f"{zmq_config['host']}:{zmq_config['sub']}" - - while True: - # Wait for next request from client - try: - msg = None - msg = socket.recv_json() - task = map_management_message(msg) - - if type(task) is Subscription: - # point-to-point topic and queue for that particular subscription - logger.info( - f"Received subscription for topic {task.topic}, snapshot {task.snapshot}" - ) - try: - p2p_topic = rand_string(p2p_topic_prefix_length) - p2p_q = Queue() + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + + while self._running(): + socks = dict(poller.poll(timeout=1000)) + if socket not in socks or socks[socket] != zmq.POLLIN: + continue + try: + msg = None + msg = socket.recv_json() + task = map_management_message(msg) + + if type(task) is Subscription: + # point-to-point topic and queue for that particular subscription + logger.info( + f"Received subscription for topic {task.topic}, snapshot {task.snapshot}" + ) + try: + p2p_topic = rand_string(p2p_topic_prefix_length) + p2p_q = JoinableQueue() + subscriptions_lock.acquire() + subscriptions[p2p_topic] = (task.topic, p2p_q) + subscriptions_lock.release() + snapshot_id = self.subscribe_callback( + task.topic, p2p_q, task.snapshot + ) + if snapshot_id: + # remember that this snapshot was requested by this particular + # subscriber (identified by unique topic), so it is not asked to + # execute it's own request + snapshots_lock.acquire() + snapshots[snapshot_id] = p2p_topic + snapshots_lock.release() + # send success message for reconnecting + socket.send_json( + { + "topic": p2p_topic, + "pub_endpoint": pub_endpoint, + "sub_endpoint": sub_endpoint, + "status": "success", + } + ) + except Exception as e: + 
logger.error(f"Error handling subscription request {task}: {e}") + socket.send_json({"status": "error"}) + elif type(task) is Unsubscription: + logger.info(f"Received unsubscription from topic {task.topic}") + threatbus_topic, p2p_q = subscriptions.get(task.topic, (None, None)) + if not p2p_q: + logger.warn("No one was subscribed for that topic. Skipping.") + socket.send_json({"status": "error"}) + continue + self.unsubscribe_callback(threatbus_topic, p2p_q) subscriptions_lock.acquire() - subscriptions[p2p_topic] = (task.topic, p2p_q) + del subscriptions[task.topic] subscriptions_lock.release() - snapshot_id = subscribe_callback(task.topic, p2p_q, task.snapshot) - if snapshot_id: - # remember that this snapshot was requested by this particular - # subscriber (identified by unique topic), so it is not asked to - # execute it's own request - snapshots_lock.acquire() - snapshots[snapshot_id] = p2p_topic - snapshots_lock.release() - # send success message for reconnecting - socket.send_json( - { - "topic": p2p_topic, - "pub_endpoint": pub_endpoint, - "sub_endpoint": sub_endpoint, - "status": "success", - } - ) - except Exception as e: - logger.error(f"Error handling subscription request {task}: {e}") - socket.send_json({"status": "error"}) - elif type(task) is Unsubscription: - logger.info(f"Received unsubscription from topic {task.topic}") - threatbus_topic, p2p_q = subscriptions.get(task.topic, (None, None)) - if not p2p_q: - logger.warn("No one was subscribed for that topic. Skipping.") - socket.send_json({"status": "error"}) - continue - unsubscribe_callback(threatbus_topic, p2p_q) - subscriptions_lock.acquire() - del subscriptions[task.topic] - subscriptions_lock.release() - socket.send_json({"status": "success"}) - elif type(task) is Heartbeat: - if task.topic not in subscriptions: - socket.send_json({"status": "error"}) - continue - socket.send_json({"status": "success"}) - else: - socket.send_json({"status": "unknown request"}) - except Exception as e: - socket.send_json({"status": "error"}) - logger.error(f"Error handling management message {msg}: {e}") + socket.send_json({"status": "success"}) + elif type(task) is Heartbeat: + if task.topic not in subscriptions: + socket.send_json({"status": "error"}) + continue + socket.send_json({"status": "success"}) + else: + socket.send_json({"status": "unknown request"}) + except Exception as e: + socket.send_json({"status": "error"}) + logger.error(f"Error handling management message {msg}: {e}") -def pub_zmq(zmq_config: Subview): +class ZmqPublisher(threatbus.StoppableWorker): """ Publshes messages to all registered subscribers via ZeroMQ. 
- @param zmq_config ZeroMQ configuration properties """ - global subscriptions, subscriptions_lock, logger - context = zmq.Context() - socket = context.socket(zmq.PUB) - socket.bind(f"tcp://{zmq_config['host']}:{zmq_config['pub']}") - - while True: - subscriptions_lock.acquire() - subs_copy = subscriptions.copy() - subscriptions_lock.release() - # the queues are filled by the backbone, the plugin distributes all - # messages in round-robin fashion to all subscribers - for topic, (_, q) in subs_copy.items(): - if q.empty(): - continue - msg = q.get() - if not msg: - q.task_done() - continue - if type(msg) is Intel: - encoded = json.dumps(msg, cls=IntelEncoder) - topic += "intel" - elif type(msg) is Sighting: - encoded = json.dumps(msg, cls=SightingEncoder) - topic += "sighting" - elif type(msg) is SnapshotRequest: - encoded = json.dumps(msg, cls=SnapshotRequestEncoder) - topic += "snapshotrequest" - else: - logger.warn( - f"Skipping unknown message type '{type(msg)}' for topic subscription {topic}." - ) - continue - socket.send((f"{topic} {encoded}").encode()) - logger.debug(f"Published {encoded} on topic {topic}") - q.task_done() - time.sleep(0.05) + def __init__(self, zmq_config: Subview): + """ + @param zmq_config ZeroMQ configuration properties + """ + super(ZmqPublisher, self).__init__() + self.zmq_config = zmq_config + + def run(self): + global subscriptions, subscriptions_lock, logger + context = zmq.Context() + socket = context.socket(zmq.PUB) + socket.bind(f"tcp://{self.zmq_config['host']}:{self.zmq_config['pub']}") -def sub_zmq(zmq_config: Subview, inq: Queue): + while self._running(): + subscriptions_lock.acquire() + # subscriptions is a dict with p2p_topic => (topic, queue) + # qt_lookup is a dict with queue-reader => (topic, queue) + qt_lookup = { + sub[1][1]._reader: (sub[0], sub[1][1]) for sub in subscriptions.items() + } + readers = [tq[1]._reader for tq in subscriptions.values()] + subscriptions_lock.release() + (ready_readers, [], []) = select.select(readers, [], [], 1) + for fd in ready_readers: + topic, q = qt_lookup[fd] + if q.empty(): + continue + msg = q.get() + if not msg: + q.task_done() + continue + if type(msg) is Intel: + encoded = json.dumps(msg, cls=IntelEncoder) + topic += "intel" + elif type(msg) is Sighting: + encoded = json.dumps(msg, cls=SightingEncoder) + topic += "sighting" + elif type(msg) is SnapshotRequest: + encoded = json.dumps(msg, cls=SnapshotRequestEncoder) + topic += "snapshotrequest" + else: + logger.warn( + f"Skipping unknown message type '{type(msg)}' for topic subscription {topic}." + ) + q.task_done() + continue + try: + socket.send((f"{topic} {encoded}").encode()) + logger.debug(f"Published {encoded} on topic {topic}") + except Exception as e: + logger.error(f"Error sending {encoded} on topic {topic}: {e}") + finally: + q.task_done() + + +class ZmqReceiver(threatbus.StoppableWorker): """ Forwards messages that are received via ZeroMQ from connected applications to the plugin's in-queue. 
- @param zmq_config ZeroMQ configuration properties """ - global logger - context = zmq.Context() - socket = context.socket(zmq.SUB) - socket.bind(f"tcp://{zmq_config['host']}:{zmq_config['sub']}") - intel_topic = "threatbus/intel" - sighting_topic = "threatbus/sighting" - snapshotenvelope_topic = "threatbus/snapshotenvelope" - socket.setsockopt(zmq.SUBSCRIBE, intel_topic.encode()) - socket.setsockopt(zmq.SUBSCRIBE, sighting_topic.encode()) - socket.setsockopt(zmq.SUBSCRIBE, snapshotenvelope_topic.encode()) - - poller = zmq.Poller() - poller.register(socket, zmq.POLLIN) - - while True: - socks = dict(poller.poll(timeout=None)) - if socket in socks and socks[socket] == zmq.POLLIN: + + def __init__(self, zmq_config: Subview, inq: JoinableQueue): + """ + @param zmq_config ZeroMQ configuration properties + """ + super(ZmqReceiver, self).__init__() + self.zmq_config = zmq_config + self.inq = inq + + def run(self): + global logger + context = zmq.Context() + socket = context.socket(zmq.SUB) + socket.bind(f"tcp://{self.zmq_config['host']}:{self.zmq_config['sub']}") + intel_topic = "threatbus/intel" + sighting_topic = "threatbus/sighting" + snapshotenvelope_topic = "threatbus/snapshotenvelope" + socket.setsockopt(zmq.SUBSCRIBE, intel_topic.encode()) + socket.setsockopt(zmq.SUBSCRIBE, sighting_topic.encode()) + socket.setsockopt(zmq.SUBSCRIBE, snapshotenvelope_topic.encode()) + + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + + while self._running(): + socks = dict(poller.poll(timeout=1000)) + if socket not in socks or socks[socket] != zmq.POLLIN: + continue try: topic, msg = socket.recv().decode().split(" ", 1) if topic == intel_topic: @@ -221,16 +256,29 @@ def sub_zmq(zmq_config: Subview, inq: Queue): ) continue - inq.put(decoded) + self.inq.put(decoded) except Exception as e: logger.error(f"Error decoding message: {e}") continue +def validate_config(config: Subview): + assert config, "config must not be None" + config["host"].get(str) + config["manage"].get(int) + config["pub"].get(int) + config["sub"].get(int) + + +def rand_string(length: int): + """Generates a pseudo-random string with the requested length""" + letters = string.ascii_lowercase + return "".join(random.choice(letters) for i in range(length)) + + @threatbus.app -def snapshot(snapshot_request: SnapshotRequest, result_q: Queue): +def snapshot(snapshot_request: SnapshotRequest, result_q: JoinableQueue): global logger, snapshots, snapshots_lock, subscriptions, subscriptions_lock - logger.info(f"Executing snapshot for time delta {snapshot_request.snapshot}") snapshots_lock.acquire() requester = snapshots.get(snapshot_request.snapshot_id, None) @@ -250,22 +298,30 @@ def snapshot(snapshot_request: SnapshotRequest, result_q: Queue): def run( config: Subview, logging: Subview, - inq: Queue, + inq: JoinableQueue, subscribe_callback: Callable, unsubscribe_callback: Callable, ): - global logger + global logger, workers logger = threatbus.logger.setup(logging, __name__) config = config[plugin_name] try: validate_config(config) except Exception as e: logger.fatal("Invalid config for plugin {}: {}".format(plugin_name, str(e))) - threading.Thread(target=pub_zmq, args=(config,), daemon=True).start() - threading.Thread(target=sub_zmq, args=(config, inq), daemon=True).start() - threading.Thread( - target=receive_management, - args=(config, subscribe_callback, unsubscribe_callback), - daemon=True, - ).start() + workers.append(ZmqPublisher(config)) + workers.append(ZmqReceiver(config, inq)) + workers.append( + 
SubscriptionManager(config, subscribe_callback, unsubscribe_callback) + ) + for w in workers: + w.start() logger.info("ZeroMQ app plugin started") + + +@threatbus.app +def stop(): + global logger, workers + for w in workers: + w.join() + logger.info("ZeroMQ app plugin stopped") diff --git a/plugins/backbones/threatbus_inmem/plugin.py b/plugins/backbones/threatbus_inmem/plugin.py index ae2e5172..64ce4085 100644 --- a/plugins/backbones/threatbus_inmem/plugin.py +++ b/plugins/backbones/threatbus_inmem/plugin.py @@ -1,39 +1,55 @@ -import threading from collections import defaultdict +from confuse import Subview +from multiprocessing import JoinableQueue +from queue import Empty +import threading import threatbus +from typing import Dict, List, Set """In-Memory backbone plugin for Threat Bus""" plugin_name = "inmem" -subscriptions = defaultdict(set) +subscriptions: Dict[str, Set[JoinableQueue]] = defaultdict(set) lock = threading.Lock() +workers: List[threatbus.StoppableWorker] = list() def validate_config(config): return True -def provision(inq): +class Provisioner(threatbus.StoppableWorker): """ Provisions all messages that arrive on the inq to all subscribers of that topic. - @param inq The in-Queue to read messages from + @param inq The in-queue to read messages from """ - global subscriptions, lock, logger - while True: - msg = inq.get(block=True) - logger.debug(f"Backbone got message {msg}") - topic = f"threatbus/{type(msg).__name__.lower()}" - lock.acquire() - for t in filter(lambda t: str(topic).startswith(str(t)), subscriptions.keys()): - for outq in subscriptions[t]: - outq.put(msg) - lock.release() - inq.task_done() + + def __init__(self, inq: JoinableQueue): + self.inq = inq + super(Provisioner, self).__init__() + + def run(self): + global subscriptions, lock, logger + while self._running(): + try: + msg = self.inq.get(block=True, timeout=1) + except Empty: + continue + logger.debug(f"Backbone got message {msg}") + topic = f"threatbus/{type(msg).__name__.lower()}" + lock.acquire() + for t in filter( + lambda t: str(topic).startswith(str(t)), subscriptions.keys() + ): + for outq in subscriptions[t]: + outq.put(msg) + lock.release() + self.inq.task_done() @threatbus.backbone -def subscribe(topic, q): +def subscribe(topic: str, q: JoinableQueue): global logger, subscriptions, lock logger.info(f"Adding subscription to: {topic}") lock.acquire() @@ -42,7 +58,7 @@ def subscribe(topic, q): @threatbus.backbone -def unsubscribe(topic, q): +def unsubscribe(topic: str, q: JoinableQueue): global logger, subscriptions, lock logger.info(f"Removing subscription from: {topic}") lock.acquire() @@ -52,13 +68,23 @@ def unsubscribe(topic, q): @threatbus.backbone -def run(config, logging, inq): - global logger +def run(config: Subview, logging: Subview, inq: JoinableQueue): + global logger, workers logger = threatbus.logger.setup(logging, __name__) config = config[plugin_name] try: validate_config(config) except Exception as e: logger.fatal("Invalid config for plugin {}: {}".format(plugin_name, str(e))) - threading.Thread(target=provision, args=(inq,), daemon=True).start() + workers.append(Provisioner(inq)) + for w in workers: + w.start() logger.info("In-memory backbone started.") + + +@threatbus.backbone +def stop(): + global logger, workers + for w in workers: + w.join() + logger.info("In-memory backbone stopped") diff --git a/plugins/backbones/threatbus_rabbitmq/__init__.py b/plugins/backbones/threatbus_rabbitmq/__init__.py index e69de29b..6a482c01 100644 --- 
a/plugins/backbones/threatbus_rabbitmq/__init__.py +++ b/plugins/backbones/threatbus_rabbitmq/__init__.py @@ -0,0 +1,3 @@ +from .helpers import get_exchange_name, get_queue_name +from .rabbitmq_consumer import RabbitMQConsumer +from .rabbitmq_publisher import RabbitMQPublisher diff --git a/plugins/backbones/threatbus_rabbitmq/helpers.py b/plugins/backbones/threatbus_rabbitmq/helpers.py new file mode 100644 index 00000000..69465e4a --- /dev/null +++ b/plugins/backbones/threatbus_rabbitmq/helpers.py @@ -0,0 +1,21 @@ +from socket import gethostname + + +def get_queue_name(join_symbol: str, data_type: str, suffix: str = gethostname()): + """ + Returns a queue name according to the desired pattern. + @param join_symbol The symbol to use when concatenating the name + @param data_type The type of data that goes through the queue (e.g., "intel") + @param suffix A suffix to append to the name. Default: the hostname + """ + return join_symbol.join(["threatbus", data_type, suffix]) + + +def get_exchange_name(join_symbol: str, data_type: str): + """ + Returns an exchange name according to the desired pattern. + @param join_symbol The symbol to use when concatenating the name + @param data_type The type of data that goes through the queue (e.g., "intel") + """ + return join_symbol.join(["threatbus", data_type]) diff --git a/plugins/backbones/threatbus_rabbitmq/plugin.py b/plugins/backbones/threatbus_rabbitmq/plugin.py index bb1a6532..869aea57 100644 --- a/plugins/backbones/threatbus_rabbitmq/plugin.py +++ b/plugins/backbones/threatbus_rabbitmq/plugin.py @@ -1,8 +1,8 @@ from collections import defaultdict -import json +from confuse import Subview import pika +from multiprocessing import JoinableQueue from retry import retry -from socket import gethostname import threading import threatbus from threatbus.data import ( @@ -10,46 +10,21 @@ Sighting, SnapshotRequest, SnapshotEnvelope, - IntelEncoder, - IntelDecoder, - SightingEncoder, - SightingDecoder, - SnapshotRequestEncoder, - SnapshotRequestDecoder, - SnapshotEnvelopeEncoder, - SnapshotEnvelopeDecoder, ) +from threatbus_rabbitmq import RabbitMQConsumer, RabbitMQPublisher +from typing import Dict, List, Union """RabbitMQ backbone plugin for Threat Bus""" plugin_name = "rabbitmq" -subscriptions = defaultdict(set) +subscriptions: Dict[str, set] = defaultdict(set) lock = threading.Lock() +workers: List[threatbus.StoppableWorker] = list() -def get_queue_name(join_symbol, data_type, suffix=gethostname()): - """ - Returns a queue name accroding to the desired pattern. - @param join_symbol The symbol to use when concatenating the name - @param data_type The type of data that goes through the queue (e.g., "intel") - @param suffix A suffix to append to the name. Default: the hostname - """ - return join_symbol.join(["threatbus", data_type, suffix]) - - -def get_exchange_name(join_symbol, data_type): - """ - Returns an exchange name accroding to the desired pattern. - @param join_symbol The symbol to use when concatenating the name - @param data_type The type of data that goes through the queue (e.g., "intel") - @param suffix A suffix to append to the name.
Default: the hostname - """ - return join_symbol.join(["threatbus", data_type]) - - -def validate_config(config): +def validate_config(config: Subview): assert config, "config must not be None" config["host"].get(str) config["port"].get(int) @@ -68,7 +43,9 @@ def validate_config(config): config["queue"]["max_items"].get(int) -def __provision(topic, msg): +def provision( + topic: str, msg: Union[Intel, Sighting, SnapshotEnvelope, SnapshotRequest] +): """ Provisions the given `msg` to all subscribers of `topic`. @param topic The topic string to use for provisioning @@ -83,218 +60,9 @@ def __provision(topic, msg): logger.debug(f"Relayed message from RabbitMQ: {msg}") -def __decode(msg, decoder): - """ - Decodes a JSON message with the given decoder. Returns the decoded object or - None and logs an error. - @param msg The message to decode - @param decoder The decoder class to use for decoding - """ - global logger - try: - return json.loads(msg, cls=decoder) - except Exception as e: - logger.error(f"Error decoding message {msg}: {e}") - return None - - -def __provision_intel(channel, method_frame, header_frame, body): - """ - Callback to be invoked by the Pika library whenever a new message `body` has - been received from RabbitMQ on the intel queue. - @param channel: pika.Channel The channel that was received on - @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) - @param properties: pika.spec.BasicProperties Pika properties - @param body: bytes The received message - """ - msg = __decode(body, IntelDecoder) - if msg: - __provision("threatbus/intel", msg) - channel.basic_ack(delivery_tag=method_frame.delivery_tag) - - -def __provision_sighting(channel, method_frame, header_frame, body): - """ - Callback to be invoked by the Pika library whenever a new message `body` has - been received from RabbitMQ on the sighting queue. - @param channel: pika.Channel The channel that was received on - @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) - @param properties: pika.spec.BasicProperties Pika properties - @param body: bytes The received message - """ - msg = __decode(body, SightingDecoder) - if msg: - __provision("threatbus/sighting", msg) - channel.basic_ack(delivery_tag=method_frame.delivery_tag) - - -def __provision_snapshot_request(channel, method_frame, header_frame, body): - """ - Callback to be invoked by the Pika library whenever a new message `body` has - been received from RabbitMQ on the snapshot-request queue. - @param channel: pika.Channel The channel that was received on - @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) - @param properties: pika.spec.BasicProperties Pika properties - @param body: bytes The received message - """ - msg = __decode(body, SnapshotRequestDecoder) - if msg: - __provision("threatbus/snapshotrequest", msg) - channel.basic_ack(delivery_tag=method_frame.delivery_tag) - - -def __provision_snapshot_envelope(channel, method_frame, header_frame, body): - """ - Callback to be invoked by the Pika library whenever a new message `body` has - been received from RabbitMQ on the snapshot-envelope queue. 
- @param channel: pika.Channel The channel that was received on - @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) - @param properties: pika.spec.BasicProperties Pika properties - @param body: bytes The received message - """ - msg = __decode(body, SnapshotEnvelopeDecoder) - if msg: - __provision("threatbus/snapshotenvelope", msg) - channel.basic_ack(delivery_tag=method_frame.delivery_tag) - - -@retry(delay=5) -def consume_rabbitmq(conn_params, join_symbol, queue_params): - """ - Connects to RabbitMQ on the given host/port endpoint. Registers callbacks to - consumes all messages and initiates further provisioning. - @param conn_params Pika.ConnectionParameters to connect to RabbitMQ - @param join_symbol The symbol to use when determining queue and exchange names - @param queue_params Confuse view of parameters to use for declaring queues - """ - global logger - logger.debug("Connecting RabbitMQ consumer...") - # RabbitMQ connection - connection = pika.BlockingConnection(conn_params) - channel = connection.channel() - - # create names and parameters - exchange_intel = get_exchange_name(join_symbol, "intel") - exchange_sighting = get_exchange_name(join_symbol, "sighting") - exchange_snapshotrequest = get_exchange_name(join_symbol, "snapshotrequest") - exchange_snapshotenvelope = get_exchange_name(join_symbol, "snapshotenvelope") - queue_name_suffix = queue_params["name_suffix"].get() - queue_name_suffix = queue_name_suffix if queue_name_suffix else gethostname() - intel_queue = get_queue_name(join_symbol, "intel", queue_name_suffix) - sighting_queue = get_queue_name(join_symbol, "sighting", queue_name_suffix) - snapshot_request_queue = get_queue_name( - join_symbol, "snapshotrequest", queue_name_suffix - ) - snapshot_envelope_queue = get_queue_name( - join_symbol, "snapshotenvelope", queue_name_suffix - ) - queue_mode = "default" if not queue_params["lazy"].get(bool) else "lazy" - queue_kwargs = { - "durable": queue_params["durable"].get(bool), - "exclusive": queue_params["exclusive"].get(bool), - "auto_delete": queue_params["auto_delete"].get(bool), - "arguments": {"x-queue-mode": queue_mode}, - } - max_items = queue_params["max_items"].get() - if max_items: - queue_kwargs["arguments"]["x-max-length"] = max_items - - # bind callbacks to RabbitMQ - channel.exchange_declare(exchange=exchange_intel, exchange_type="fanout") - channel.queue_declare(intel_queue, **queue_kwargs) - channel.queue_bind(exchange=exchange_intel, queue=intel_queue) - channel.basic_consume(intel_queue, __provision_intel) - - channel.exchange_declare(exchange=exchange_sighting, exchange_type="fanout") - channel.queue_declare(sighting_queue, **queue_kwargs) - channel.queue_bind(exchange=exchange_sighting, queue=sighting_queue) - channel.basic_consume(sighting_queue, __provision_sighting) - - channel.exchange_declare(exchange=exchange_snapshotrequest, exchange_type="fanout") - channel.queue_declare(snapshot_request_queue, **queue_kwargs) - channel.queue_bind(exchange=exchange_snapshotrequest, queue=snapshot_request_queue) - channel.basic_consume(snapshot_request_queue, __provision_snapshot_request) - - channel.exchange_declare(exchange=exchange_snapshotenvelope, exchange_type="fanout") - channel.queue_declare(snapshot_envelope_queue, **queue_kwargs) - channel.queue_bind( - exchange=exchange_snapshotenvelope, queue=snapshot_envelope_queue - ) - channel.basic_consume(snapshot_envelope_queue, __provision_snapshot_envelope) - - try: - channel.start_consuming() - except KeyboardInterrupt: - 
channel.stop_consuming() - connection.close() - except Exception as e: - logger.error(f"Consumer lost connection to RabbitMQ: {e}") - raise e # let @retry handle the reconnect - - @retry(delay=5) -def publish_rabbitmq(conn_params, join_symbol, inq): - """ - Connects to RabbitMQ on the given host/port endpoint. Forwards all messages - from the `inq`, based on their type, to the appropriate RabbitMQ exchange. - @param conn_params Pika.ConnectionParameters to connect to RabbitMQ - @param join_symbol The symbol to use when determining queue and exchange names - @param inq A Queue object to read messages from and publish them to RabbitMQ - """ - global logger - logger.debug("Connecting RabbitMQ publisher...") - connection = pika.BlockingConnection(conn_params) - - # create names and parameters - exchange_intel = get_exchange_name(join_symbol, "intel") - exchange_sighting = get_exchange_name(join_symbol, "sighting") - exchange_snapshotrequest = get_exchange_name(join_symbol, "snapshotrequest") - exchange_snapshotenvelope = get_exchange_name(join_symbol, "snapshotenvelope") - channel = connection.channel() - channel.exchange_declare(exchange=exchange_intel, exchange_type="fanout") - channel.exchange_declare(exchange=exchange_sighting, exchange_type="fanout") - channel.exchange_declare(exchange=exchange_snapshotrequest, exchange_type="fanout") - channel.exchange_declare(exchange=exchange_snapshotenvelope, exchange_type="fanout") - - # forward messages to RabbitMQ - while True: - msg = inq.get(block=True) - exchange = None - encoded = None - try: - if type(msg) == Intel: - exchange = exchange_intel - encoded = json.dumps(msg, cls=IntelEncoder) - elif type(msg) == Sighting: - exchange = exchange_sighting - encoded = json.dumps(msg, cls=SightingEncoder) - elif type(msg) == SnapshotRequest: - exchange = exchange_snapshotrequest - encoded = json.dumps(msg, cls=SnapshotRequestEncoder) - elif type(msg) == SnapshotEnvelope: - exchange = exchange_snapshotenvelope - encoded = json.dumps(msg, cls=SnapshotEnvelopeEncoder) - except Exception as e: - logger.warn(f"Discarding unparsable message {msg}: {e}") - continue - try: - channel.basic_publish(exchange=exchange, routing_key="", body=encoded) - logger.debug(f"Forwarded message to RabbitMQ: {msg}") - inq.task_done() - except KeyboardInterrupt: - connection.close() - break - except Exception as e: - # push back message - logger.error(f"Failed to send, pushing back message: {msg}") - logger.error(f"Publisher lost connection to RabbitMQ: {e}") - if msg: - inq.put(msg) - raise e # let @retry handle the reconnect - - @threatbus.backbone -def subscribe(topic, q): +def subscribe(topic: str, q: JoinableQueue): """ Threat Bus' subscribe hook. Used to register new app-queues for certain topics. @@ -307,7 +75,7 @@ def subscribe(topic, q): @threatbus.backbone -def unsubscribe(topic, q): +def unsubscribe(topic: str, q: JoinableQueue): """ Threat Bus' unsubscribe hook. Used to deregister app-queues from certain topics. 
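For context, the `subscribe` and `unsubscribe` hooks hand a `multiprocessing.JoinableQueue` across the plugin boundary. A minimal sketch of the consumption pattern this patch standardizes on, i.e., polling with a timeout so a worker can notice a stop signal; the `drain` helper and the `stop` event are illustrative names, not part of the patch:

```python
from multiprocessing import JoinableQueue
from queue import Empty
from threading import Event

stop = Event()

def drain(q: JoinableQueue) -> None:
    # Poll with a 1s timeout instead of blocking forever, so the loop can
    # periodically re-check the stop condition and exit cleanly.
    while not stop.is_set():
        try:
            msg = q.get(block=True, timeout=1)
        except Empty:
            continue
        print(f"received: {msg}")
        q.task_done()  # JoinableQueue bookkeeping: one task_done() per get()
```

Polling with a short timeout is what allows `join()` to terminate such loops instead of leaving them blocked forever on `get()`.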
@@ -321,8 +89,8 @@ @threatbus.backbone -def run(config, logging, inq): - global logger +def run(config: Subview, logging: Subview, inq: JoinableQueue): + global subscriptions, lock, logger, workers logger = threatbus.logger.setup(logging, __name__) config = config[plugin_name] try: @@ -337,14 +105,18 @@ def run(config, logging, inq): credentials = pika.PlainCredentials(username, password) conn_params = pika.ConnectionParameters(host, port, vhost, credentials) name_pattern = config["naming_join_pattern"].get(str) - threading.Thread( - target=consume_rabbitmq, - args=(conn_params, name_pattern, config["queue"]), - daemon=True, - ).start() - threading.Thread( - target=publish_rabbitmq, - args=(conn_params, name_pattern, inq), - daemon=True, - ).start() + workers.append( + RabbitMQConsumer(conn_params, name_pattern, config["queue"], provision, logger) + ) + workers.append(RabbitMQPublisher(conn_params, name_pattern, inq, logger)) + for w in workers: + w.start() logger.info("RabbitMQ backbone started.") + + +@threatbus.backbone +def stop(): + global logger, workers + for w in workers: + w.join() + logger.info("RabbitMQ backbone stopped.") diff --git a/plugins/backbones/threatbus_rabbitmq/rabbitmq_consumer.py b/plugins/backbones/threatbus_rabbitmq/rabbitmq_consumer.py new file mode 100644 index 00000000..e10fc104 --- /dev/null +++ b/plugins/backbones/threatbus_rabbitmq/rabbitmq_consumer.py @@ -0,0 +1,311 @@ +from confuse import Subview +from functools import partial +import json +import pika +from logging import Logger +from threatbus_rabbitmq import get_exchange_name, get_queue_name +from socket import gethostname +import threatbus +from threatbus.data import ( + Intel, + Sighting, + SnapshotRequest, + SnapshotEnvelope, + IntelDecoder, + SightingDecoder, + SnapshotRequestDecoder, + SnapshotEnvelopeDecoder, +) +import time +from typing import Callable, List, Tuple, Union + + +class RabbitMQConsumer(threatbus.StoppableWorker): + """ + Connects to RabbitMQ on the given host/port endpoint. Registers callbacks to + consume all messages and initiate further provisioning.
+ """ + + def __init__( + self, + conn_params: pika.ConnectionParameters, + join_symbol: str, + queue_params: Subview, + provision_callback: Callable[ + [str, Union[Intel, Sighting, SnapshotEnvelope, SnapshotRequest]], None + ], + logger: Logger, + ): + """ + @param conn_params Pika.ConnectionParameters to connect to RabbitMQ + @param join_symbol The symbol to use when determining queue and exchange names + @param queue_params Confuse view of parameters to use for declaring queues + @param provision_callback A callback to invoke after messages are retrieved and parsed successfully + @param logger A pre-configured Logger instance + """ + super(RabbitMQConsumer, self).__init__() + self.conn_params: pika.ConnectionParameters = conn_params + self.__provision: Callable[ + [str, Union[Intel, Sighting, SnapshotEnvelope, SnapshotRequest]], None + ] = provision_callback + self.logger: Logger = logger + self.consumers: List[str] = list() # RabbitMQ consumer tags + self._reconnect_delay: int = 5 + self._connection: Union[pika.SelectConnection, None] = None + self._channel: Union[pika.channel.Channel, None] = None + + # Create names and parameters for exchanges and queues + self.intel_exchange = get_exchange_name(join_symbol, "intel") + self.sighting_exchange = get_exchange_name(join_symbol, "sighting") + self.snapshot_request_exchange = get_exchange_name( + join_symbol, "snapshotrequest" + ) + self.snapshot_envelope_exchange = get_exchange_name( + join_symbol, "snapshotenvelope" + ) + queue_name_suffix = queue_params["name_suffix"].get() + queue_name_suffix = queue_name_suffix if queue_name_suffix else gethostname() + self.intel_queue = get_queue_name(join_symbol, "intel", queue_name_suffix) + self.sighting_queue = get_queue_name(join_symbol, "sighting", queue_name_suffix) + self.snapshot_request_queue = get_queue_name( + join_symbol, "snapshotrequest", queue_name_suffix + ) + self.snapshot_envelope_queue = get_queue_name( + join_symbol, "snapshotenvelope", queue_name_suffix + ) + queue_mode = "default" if not queue_params["lazy"].get(bool) else "lazy" + self.queue_kwargs = { + "durable": queue_params["durable"].get(bool), + "exclusive": queue_params["exclusive"].get(bool), + "auto_delete": queue_params["auto_delete"].get(bool), + "arguments": {"x-queue-mode": queue_mode}, + } + max_items = queue_params["max_items"].get() + if max_items: + self.queue_kwargs["arguments"]["x-max-length"] = max_items + + def __decode(self, msg: str, decoder: json.JSONDecoder): + """ + Decodes a JSON message with the given decoder. Returns the decoded object or + None and logs an error. + @param msg The message to decode + @param decoder The decoder class to use for decoding + """ + try: + return json.loads(msg, cls=decoder) + except Exception as e: + self.logger.error(f"RabbitMQ consumer: error decoding message {msg}: {e}") + return None + + def __provision_intel(self, channel, method_frame, header_frame, body): + """ + Callback to be invoked by the Pika library whenever a new message `body` has + been received from RabbitMQ on the intel queue. 
+ @param channel: pika.Channel The channel that was received on + @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) + @param properties: pika.spec.BasicProperties Pika properties + @param body: bytes The received message + """ + msg = self.__decode(body, IntelDecoder) + if msg: + self.__provision("threatbus/intel", msg) + self._channel.basic_ack(delivery_tag=method_frame.delivery_tag) + + def __provision_sighting(self, channel, method_frame, header_frame, body): + """ + Callback to be invoked by the Pika library whenever a new message `body` has + been received from RabbitMQ on the sighting queue. + @param channel: pika.Channel The channel that was received on + @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) + @param properties: pika.spec.BasicProperties Pika properties + @param body: bytes The received message + """ + msg = self.__decode(body, SightingDecoder) + if msg: + self.__provision("threatbus/sighting", msg) + self._channel.basic_ack(delivery_tag=method_frame.delivery_tag) + + def __provision_snapshot_request(self, channel, method_frame, header_frame, body): + """ + Callback to be invoked by the Pika library whenever a new message `body` has + been received from RabbitMQ on the snapshot-request queue. + @param channel: pika.Channel The channel that was received on + @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) + @param properties: pika.spec.BasicProperties Pika properties + @param body: bytes The received message + """ + msg = self.__decode(body, SnapshotRequestDecoder) + if msg: + self.__provision("threatbus/snapshotrequest", msg) + self._channel.basic_ack(delivery_tag=method_frame.delivery_tag) + + def __provision_snapshot_envelope(self, channel, method_frame, header_frame, body): + """ + Callback to be invoked by the Pika library whenever a new message `body` has + been received from RabbitMQ on the snapshot-envelope queue. + @param channel: pika.Channel The channel that was received on + @param method: pika.spec.Basic.Deliver The pika delivery method (e.g., ACK) + @param properties: pika.spec.BasicProperties Pika properties + @param body: bytes The received message + """ + msg = self.__decode(body, SnapshotEnvelopeDecoder) + if msg: + self.__provision("threatbus/snapshotenvelope", msg) + self._channel.basic_ack(delivery_tag=method_frame.delivery_tag) + + def __shutdown(self): + """ + Cancels the consumers and stops the RabbitMQ connection + """ + self._connection.ioloop.stop() + + def join(self, *args, **kwargs): + """ + Stops the RabbitMQ connection, disables automatic reconnection, and + forwards the join() call to the super class, i.e., the stop event will + be set, such that the semi-infinite run() method can exit. + """ + self._reconnect_delay = 0 + self.__shutdown() + super(RabbitMQConsumer, self).join(*args, **kwargs) + + def on_connection_open(self, _connection: pika.connection.Connection): + """ + Invoked as callback when opening a new pika.SelectConnection. + Issues opening of a channel. + @param _connection The opened connection + """ + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_connection_open_error( + self, _connection: pika.connection.Connection, err: Exception + ): + """ + Invoked as callback when a new pika.SelectConnection cannot be opened.
+ @param _connection The connection that failed to open + @param err The raised error + """ + self.logger.error(f"RabbitMQ consumer: connection failed to open {err}") + self.__shutdown() # will restart automatically + + def on_connection_closed( + self, _connection: pika.connection.Connection, reason: Exception + ): + """ + Invoked when the connection to RabbitMQ is closed unexpectedly. Tries to + reconnect. + @param _connection The closed connection + @param reason Exception representing reason for loss of connection + """ + self.logger.warning( + f"RabbitMQ consumer: connection closed unexpectedly. Reason: {reason}" + ) + self.__shutdown() # will restart automatically + + def on_channel_open(self, channel): + """ + Invoked as callback from connection.channel. See self.on_connection_open + @param channel The successfully opened channel + """ + self._channel = channel + self._channel.add_on_close_callback(self.on_channel_closed) + + intel_cb = partial( + self.on_exchange_declare_ok, + userdata=(self.intel_exchange, self.intel_queue), + ) + self._channel.exchange_declare( + exchange=self.intel_exchange, exchange_type="fanout", callback=intel_cb + ) + sighting_cb = partial( + self.on_exchange_declare_ok, + userdata=(self.sighting_exchange, self.sighting_queue), + ) + self._channel.exchange_declare( + exchange=self.sighting_exchange, + exchange_type="fanout", + callback=sighting_cb, + ) + snapshotrequest_cb = partial( + self.on_exchange_declare_ok, + userdata=(self.snapshot_request_exchange, self.snapshot_request_queue), + ) + self._channel.exchange_declare( + exchange=self.snapshot_request_exchange, + exchange_type="fanout", + callback=snapshotrequest_cb, + ) + snapshotenvelope_cb = partial( + self.on_exchange_declare_ok, + userdata=(self.snapshot_envelope_exchange, self.snapshot_envelope_queue), + ) + self._channel.exchange_declare( + exchange=self.snapshot_envelope_exchange, + exchange_type="fanout", + callback=snapshotenvelope_cb, + ) + + def on_channel_closed(self, channel: pika.channel.Channel, reason: Exception): + """ + Invoked when RabbitMQ closes the channel unexpectedly. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shut down the object. + @param channel The closed channel + @param reason The exception with which the channel was closed + """ + self.logger.warning(f"RabbitMQ consumer: channel closed unexpectedly: {reason}") + self.__shutdown() # will restart automatically + + def on_exchange_declare_ok(self, _frame, userdata: Tuple[str, str]): + """ + Invoked as callback from exchange_declare. See self.on_channel_open. + Issues declaration of a queue. + @param _frame Unused pika response + @param userdata A tuple of exchange_name and queue_name. The exchange with the given name was created, hence this method is invoked. The queue with the given name is declared next. + """ + cb = partial(self.on_queue_declare_ok, userdata=userdata) + self._channel.queue_declare(queue=userdata[1], callback=cb) + + def on_queue_declare_ok(self, _frame, userdata: Tuple[str, str]): + """ + Inspects the given userdata (exchange_name, queue_name) and binds the queue to the exchange. + @param _frame Unused pika response + @param userdata A tuple of exchange_name and queue_name. Both have been created, hence this method is invoked.
+ """ + cb = partial(self.on_queue_bind_ok, userdata=userdata[1]) + self._channel.queue_bind(exchange=userdata[0], queue=userdata[1], callback=cb) + + def on_queue_bind_ok(self, _frame, userdata: str): + """ + Inspects the given userdata (exchange_name, queue_name) and binds the queue to the exchange. + @param _frame Unused pika response + @param userdata The name of the bound queue queue_name + """ + callbacks_by_type = { + self.intel_queue: self.__provision_intel, + self.sighting_queue: self.__provision_sighting, + self.snapshot_request_queue: self.__provision_snapshot_request, + self.snapshot_envelope_queue: self.__provision_snapshot_envelope, + } + self.consumers.append( + self._channel.basic_consume(userdata, callbacks_by_type[userdata]) + ) + + def run(self): + """ + Starts a RabbitMQ connection in a semi-infinite loop that reconnects + automatically on failure. + """ + while self._running(): + self.logger.debug("RabbitMQ consumer: connecting...") + self._connection = pika.SelectConnection( + parameters=self.conn_params, + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + self._connection.ioloop.start() + + time.sleep(self._reconnect_delay) diff --git a/plugins/backbones/threatbus_rabbitmq/rabbitmq_publisher.py b/plugins/backbones/threatbus_rabbitmq/rabbitmq_publisher.py new file mode 100644 index 00000000..58164248 --- /dev/null +++ b/plugins/backbones/threatbus_rabbitmq/rabbitmq_publisher.py @@ -0,0 +1,149 @@ +import json +import pika +from logging import Logger +from multiprocessing import JoinableQueue +from queue import Empty +from retry import retry +import threatbus +from threatbus.data import ( + Intel, + Sighting, + SnapshotRequest, + SnapshotEnvelope, + IntelEncoder, + SightingEncoder, + SnapshotRequestEncoder, + SnapshotEnvelopeEncoder, +) +from threatbus_rabbitmq import get_exchange_name +from typing import Union + + +class RabbitMQPublisher(threatbus.StoppableWorker): + """ + Connects to RabbitMQ on the given host/port endpoint. Forwards all messages + from the `inq`, based on their type, to the appropriate RabbitMQ exchange. + """ + + def __init__( + self, + conn_params: pika.ConnectionParameters, + join_symbol: str, + inq: JoinableQueue, + logger: Logger, + ): + """ + @param conn_params Pika.ConnectionParameters to connect to RabbitMQ + @param join_symbol The symbol to use when determining queue and exchange names + @param inq A queue object to read messages from and publish them to RabbitMQ + @param logger A pre-configured Logger instance + """ + super(RabbitMQPublisher, self).__init__() + self.conn_params = conn_params + self.inq: JoinableQueue = inq + self.logger: Logger = logger + self.should_connect: bool = True + self._reconnect_delay: int = 5 + self._connection: Union[pika.BlockingConnection, None] = None + self._channel: Union[pika.channel.Channel, None] = None + + # create names and parameters + self.exchange_intel = get_exchange_name(join_symbol, "intel") + self.exchange_sighting = get_exchange_name(join_symbol, "sighting") + self.exchange_snapshot_request = get_exchange_name( + join_symbol, "snapshotrequest" + ) + self.exchange_snapshot_envelope = get_exchange_name( + join_symbol, "snapshotenvelope" + ) + + def join(self, *args, **kwargs): + """ + Stops the RabbitMQ connection, disables automatic reconnection, and + forwards the join() call to the super class, i.e., the stop event will + be set, such that the semi-infinite run() method can exit. 
+ """ + self.should_connect = False + if self._channel: + self._channel.close() + if self._connection: + self._connection.close() + super(RabbitMQPublisher, self).join(*args, **kwargs) + + @retry(delay=5) + def __connect(self): + if not self.should_connect: + return + try: + self.logger.debug("RabbitMQ publisher: connecting...") + self.connection = pika.BlockingConnection(self.conn_params) + except KeyboardInterrupt: + return + except Exception as e: + self.logger.error("RabbitMQ publisher: connection failed to open.") + raise Exception("Connection failed") from e + self.channel = self.connection.channel() + self.channel.exchange_declare( + exchange=self.exchange_intel, exchange_type="fanout" + ) + self.channel.exchange_declare( + exchange=self.exchange_sighting, exchange_type="fanout" + ) + self.channel.exchange_declare( + exchange=self.exchange_snapshot_request, exchange_type="fanout" + ) + self.channel.exchange_declare( + exchange=self.exchange_snapshot_envelope, exchange_type="fanout" + ) + + def run(self): + self.__connect() + while self._running(): + try: + msg = self.inq.get(block=True, timeout=1) + except Empty: + continue + exchange = None + encoded_msg = None + try: + if type(msg) == Intel: + exchange = self.exchange_intel + encoded_msg = json.dumps(msg, cls=IntelEncoder) + elif type(msg) == Sighting: + exchange = self.exchange_sighting + encoded_msg = json.dumps(msg, cls=SightingEncoder) + elif type(msg) == SnapshotRequest: + exchange = self.exchange_snapshot_request + encoded_msg = json.dumps(msg, cls=SnapshotRequestEncoder) + elif type(msg) == SnapshotEnvelope: + exchange = self.exchange_snapshot_envelope + encoded_msg = json.dumps(msg, cls=SnapshotEnvelopeEncoder) + else: + self.logger.warn( + f"RabbitMQ publisher: discarding message with unknown type: {msg}" + ) + self.inq.task_done() + continue + except Exception as e: + self.logger.warn( + f"RabbitMQ publisher: discarding unparsable message {msg}: {e}" + ) + self.inq.task_done() + continue + try: + self.channel.basic_publish( + exchange=exchange, routing_key="", body=encoded_msg + ) + self.inq.task_done() + self.logger.debug( + f"RabbitMQ publisher: forwarded message to RabbitMQ: {msg}" + ) + except Exception as e: + self.logger.error(f"RabbitMQ publisher: failed to publish: {e}") + if msg: + self.inq.put(msg) + self.inq.task_done() + self.logger.error( + f"RabbitMQ publisher: pushing back message {msg}: {e}" + ) + self.__connect() diff --git a/plugins/backbones/threatbus_rabbitmq/test_plugin.py b/plugins/backbones/threatbus_rabbitmq/test_helpers.py similarity index 89% rename from plugins/backbones/threatbus_rabbitmq/test_plugin.py rename to plugins/backbones/threatbus_rabbitmq/test_helpers.py index 9eec0ffd..0950f269 100644 --- a/plugins/backbones/threatbus_rabbitmq/test_plugin.py +++ b/plugins/backbones/threatbus_rabbitmq/test_helpers.py @@ -1,9 +1,6 @@ import unittest -from threatbus_rabbitmq.plugin import ( - get_exchange_name, - get_queue_name, -) +from threatbus_rabbitmq import get_exchange_name, get_queue_name class TestNameCreation(unittest.TestCase): diff --git a/tests/integration/test_message_roundtrips.py b/tests/integration/test_message_roundtrips.py index 905cc6b9..49dad92c 100644 --- a/tests/integration/test_message_roundtrips.py +++ b/tests/integration/test_message_roundtrips.py @@ -3,7 +3,7 @@ import json import queue import threading -from threatbus import start +from threatbus import start as start_threatbus from threatbus.data import ( Intel, IntelData, @@ -27,12 +27,12 @@ def setUpClass(cls): 
super(TestMessageRoundtrip, cls).setUpClass() config = confuse.Configuration("threatbus") config.set_file("config_integration_test.yaml") - cls.threatbus = threading.Thread( - target=start, - args=(config,), - daemon=True, - ) - cls.threatbus.start() + cls.threatbus = start_threatbus(config) + time.sleep(1) + + @classmethod + def tearDownClass(cls): + cls.threatbus.stop() time.sleep(1) def test_zeek_plugin_message_roundtrip(self): @@ -47,6 +47,7 @@ def test_zeek_plugin_message_roundtrip(self): target=zeek_receiver.forward, args=(items, result_q), daemon=False ) rec.start() + time.sleep(1) zeek_sender.send_generic("threatbus/intel", items) time.sleep(1) self.assertEqual(result_q.qsize(), items) diff --git a/tests/integration/test_rabbitmq.py b/tests/integration/test_rabbitmq.py index b2174c74..9d161ae0 100644 --- a/tests/integration/test_rabbitmq.py +++ b/tests/integration/test_rabbitmq.py @@ -1,7 +1,7 @@ import confuse from datetime import datetime, timedelta +from multiprocessing import JoinableQueue from plugins.backbones.threatbus_rabbitmq import plugin -from queue import Queue from threatbus.data import ( Intel, IntelData, @@ -17,6 +17,42 @@ class TestRoundtrips(unittest.TestCase): + @classmethod + def setUpClass(cls): + # setup the backbone with a fan-in queue + config = confuse.Configuration("threatbus") + config["rabbitmq"].add({}) + config["rabbitmq"]["host"] = "localhost" + config["rabbitmq"]["port"] = 35672 + config["rabbitmq"]["username"] = "guest" + config["rabbitmq"]["password"] = "guest" + config["rabbitmq"]["vhost"] = "/" + config["rabbitmq"]["naming_join_pattern"] = "." + config["rabbitmq"]["queue"].add({}) + config["rabbitmq"]["queue"]["durable"] = False + config["rabbitmq"]["queue"]["auto_delete"] = True + config["rabbitmq"]["queue"]["exclusive"] = False + config["rabbitmq"]["queue"]["lazy"] = False + config["rabbitmq"]["queue"]["max_items"] = 10 + config["console"] = False + config["file"] = False + + cls.inq = JoinableQueue() + plugin.run(config, config, cls.inq) + + # subscribe this test case as consumer + cls.outq = JoinableQueue() + plugin.subscribe("threatbus/intel", cls.outq) + plugin.subscribe("threatbus/sighting", cls.outq) + plugin.subscribe("threatbus/snapshotrequest", cls.outq) + plugin.subscribe("threatbus/snapshotenvelope", cls.outq) + time.sleep(1) + + @classmethod + def tearDownClass(cls): + plugin.stop() + time.sleep(1) + def setUp(self): self.ts = datetime.now().astimezone() self.intel_id = "intel-42" @@ -76,35 +112,6 @@ def setUp(self): MessageType.SIGHTING, self.snapshot_id, self.sighting ) - # setup the backbone with a fan-in queue - config = confuse.Configuration("threatbus") - config["rabbitmq"].add({}) - config["rabbitmq"]["host"] = "localhost" - config["rabbitmq"]["port"] = 35672 - config["rabbitmq"]["username"] = "guest" - config["rabbitmq"]["password"] = "guest" - config["rabbitmq"]["vhost"] = "/" - config["rabbitmq"]["naming_join_pattern"] = "."
- config["rabbitmq"]["queue"].add({}) - config["rabbitmq"]["queue"]["durable"] = False - config["rabbitmq"]["queue"]["auto_delete"] = True - config["rabbitmq"]["queue"]["exclusive"] = False - config["rabbitmq"]["queue"]["lazy"] = False - config["rabbitmq"]["queue"]["max_items"] = 10 - config["console"] = False - config["file"] = False - - self.inq = Queue() - plugin.run(config, config, self.inq) - - # subscribe this test case as concumer - self.outq = Queue() - plugin.subscribe("threatbus/intel", self.outq) - plugin.subscribe("threatbus/sighting", self.outq) - plugin.subscribe("threatbus/snapshotrequest", self.outq) - plugin.subscribe("threatbus/snapshotenvelope", self.outq) - time.sleep(1) - def test_intel_message_roundtrip(self): """ Passes an Intel item to RabbitMQ and reads back the exact same item. diff --git a/tests/integration/test_zeek_app.py b/tests/integration/test_zeek_app.py index 05ca987a..b87983a5 100644 --- a/tests/integration/test_zeek_app.py +++ b/tests/integration/test_zeek_app.py @@ -5,7 +5,7 @@ import queue import subprocess import threading -from threatbus import start +from threatbus import start as start_threatbus import time import unittest @@ -63,12 +63,11 @@ class TestZeekSightingReports(unittest.TestCase): def setUp(self): config = confuse.Configuration("threatbus") config.set_file("config_integration_test.yaml") - self.threatbus = threading.Thread( - target=start, - args=(config,), - daemon=True, - ) - self.threatbus.start() + self.threatbus = start_threatbus(config) + time.sleep(1) + + def tearDown(self): + self.threatbus.stop() time.sleep(1) def test_intel_sighting_roundtrip(self): diff --git a/tests/integration/test_zmq_app_management.py b/tests/integration/test_zmq_app_management.py index 39d78055..e0f540e4 100644 --- a/tests/integration/test_zmq_app_management.py +++ b/tests/integration/test_zmq_app_management.py @@ -1,7 +1,6 @@ import confuse import json -import threading -from threatbus import start +from threatbus import start as start_threatbus import time import unittest import zmq @@ -37,12 +36,12 @@ def setUpClass(cls): super(TestMessageRoundtrip, cls).setUpClass() config = confuse.Configuration("threatbus") config.set_file("config_integration_test.yaml") - cls.threatbus = threading.Thread( - target=start, - args=(config,), - daemon=True, - ) - cls.threatbus.start() + cls.threatbus = start_threatbus(config) + time.sleep(1) + + @classmethod + def tearDownClass(cls): + cls.threatbus.stop() time.sleep(1) def setUp(self): diff --git a/tests/utils/rabbitmq_sender.py b/tests/utils/rabbitmq_sender.py new file mode 100644 index 00000000..eedc704e --- /dev/null +++ b/tests/utils/rabbitmq_sender.py @@ -0,0 +1,32 @@ +from datetime import datetime +import json +import pika +from threatbus.data import Intel, IntelData, IntelType, Operation, IntelEncoder + +## Dummy intel data +intel_id = "intel-42" +indicator = "6.6.6.6" +intel_type = IntelType.IPSRC +operation = Operation.ADD +intel_data = IntelData(indicator, intel_type, foo=23, more_args="MORE ARGS") +intel = Intel( + datetime.strptime("2020-11-02 17:00:00", "%Y-%m-%d %H:%M:%S"), + intel_id, + intel_data, + operation, +) + +intel_json = json.dumps(intel, cls=IntelEncoder) + +## rabbitmq +host = "localhost" +port = "5672" +vhost = "/" +credentials = pika.PlainCredentials("guest", "guest") +conn_params = pika.ConnectionParameters(host, port, vhost, credentials) + +connection = pika.BlockingConnection(conn_params) +channel = connection.channel() + +for i in range(100): + 
channel.basic_publish(exchange="threatbus.intel", routing_key="", body=intel_json) diff --git a/tests/utils/zmq_receiver.py b/tests/utils/zmq_receiver.py index 7cf9ee06..c16ef639 100644 --- a/tests/utils/zmq_receiver.py +++ b/tests/utils/zmq_receiver.py @@ -14,7 +14,7 @@ def send_manage_message(action, topic): context = zmq.Context() socket = context.socket(zmq.REQ) socket.setsockopt(zmq.LINGER, 0) - socket.connect(f"tcp://127.0.0.1:13370") + socket.connect("tcp://127.0.0.1:13370") socket.send_json(action) poller = zmq.Poller() poller.register(socket, zmq.POLLIN) @@ -55,7 +55,7 @@ def receive(n: int, topics: list): def forward(n: int, topics: list, q: Queue): """ - Receives exactly n messages via ZeroMQ and forwards them to the result queueu q + Receives exactly n messages via ZeroMQ and forwards them to the result queue @param n Items to receive @param topics List of topics to subscribe to @param q The queue to push received items to diff --git a/threatbus/__init__.py b/threatbus/__init__.py index c5ae64fa..4dee72fb 100644 --- a/threatbus/__init__.py +++ b/threatbus/__init__.py @@ -1,5 +1,6 @@ import pluggy from .threatbus import ThreatBus, start +from .stoppable_worker import StoppableWorker app = pluggy.HookimplMarker("threatbus.app") """Marker to be imported and used in app-plugins""" diff --git a/threatbus/appspecs.py b/threatbus/appspecs.py index 73c102c3..f0fbc727 100644 --- a/threatbus/appspecs.py +++ b/threatbus/appspecs.py @@ -1,6 +1,6 @@ import pluggy from confuse import Subview -from queue import Queue +from multiprocessing import JoinableQueue from typing import Callable from threatbus.data import SnapshotRequest @@ -11,7 +11,7 @@ def run( config: Subview, logging: Subview, - inq: Queue, + inq: JoinableQueue, subscribe_callback: Callable, unsubscribe_callback: Callable, ): @@ -28,7 +28,12 @@ def run( @hookspec -def snapshot(snapshot_request: SnapshotRequest, result_q: Queue): +def stop(): + """Stops all Threads that the plugin has started.""" + + +@hookspec +def snapshot(snapshot_request: SnapshotRequest, result_q: JoinableQueue): """ Perform a snapshot, based on the given `snapshot_request`. Snapshots are collected up to the requested earliest date. Results of the type diff --git a/threatbus/backbonespecs.py b/threatbus/backbonespecs.py index 56a5b3fa..0cbd8a69 100644 --- a/threatbus/backbonespecs.py +++ b/threatbus/backbonespecs.py @@ -1,12 +1,12 @@ from confuse import Subview import pluggy -from queue import Queue +from multiprocessing import JoinableQueue hookspec = pluggy.HookspecMarker("threatbus.backbone") @hookspec -def run(config: Subview, logging: Subview, inq: Queue): +def run(config: Subview, logging: Subview, inq: JoinableQueue): """Runs / starts a plugin spec with a configuration object @param config A configuration object for the app @param logging A configuration object for the logger @@ -15,7 +15,12 @@ def run(config: Subview, logging: Subview, inq: Queue): @hookspec -def subscribe(topic: str, q: Queue): +def stop(): + """Stops all Threads that the plugin has started.""" + + +@hookspec +def subscribe(topic: str, q: JoinableQueue): """Subscribes the given queue to the requested topic. 
@param topic Subscribe to this topic (string) @param q A queue object to forward all messages for the given topic @@ -23,7 +28,7 @@ @hookspec -def unsubscribe(topic: str, q: Queue): +def unsubscribe(topic: str, q: JoinableQueue): """Unubscribes the given queue from the requested topic @param topic Unsubscribe from this topic (string) @param q The queue object that was subscribed to the given topic diff --git a/threatbus/data.py b/threatbus/data.py index 5203f21e..8c30b107 100644 --- a/threatbus/data.py +++ b/threatbus/data.py @@ -4,6 +4,7 @@ from dateutil import parser from enum import auto, Enum, unique import json +from typing import Union @dataclass @@ -86,7 +87,9 @@ class IntelData(dict): The 'intel_type' is a threatbus.data.IntelType """ - def __init__(self, indicator: str or tuple, intel_type: IntelType, *args, **kw): + def __init__( + self, indicator: Union[str, tuple], intel_type: IntelType, *args, **kw + ): super(IntelData, self).__init__(*args, **kw) assert indicator, "Intel indicator must be set" assert ( @@ -115,7 +118,7 @@ class Sighting: ts: datetime intel: str context: dict - ioc: tuple or None + ioc: Union[tuple, None] @dataclass() @@ -128,7 +131,7 @@ class SnapshotEnvelope: snapshot_type: MessageType snapshot_id: str - body: Intel or Sighting + body: Union[Intel, Sighting] @dataclass diff --git a/threatbus/stoppable_worker.py b/threatbus/stoppable_worker.py new file mode 100644 index 00000000..c88b9846 --- /dev/null +++ b/threatbus/stoppable_worker.py @@ -0,0 +1,24 @@ +from threading import Event, Thread + + +class StoppableWorker(Thread): + """ + A threading.Thread with a dedicated method, called _running(), for checking + if the thread should continue running. Invoking its join() method changes + the return value of _running(). Use this method as the exit condition to model + semi-infinite loops.
+ """ + + def __init__(self): + super(StoppableWorker, self).__init__() + self._stop_event = Event() + + def __stop(self): + self._stop_event.set() + + def _running(self): + return not self._stop_event.is_set() + + def join(self, *args, **kwargs): + self.__stop() + super(StoppableWorker, self).join(*args, **kwargs) diff --git a/threatbus/subscriptions.py b/threatbus/subscriptions.py new file mode 100644 index 00000000..e69de29b diff --git a/threatbus/threatbus.py b/threatbus/threatbus.py index cc463f83..98be3eec 100644 --- a/threatbus/threatbus.py +++ b/threatbus/threatbus.py @@ -3,14 +3,16 @@ from datetime import timedelta from logging import Logger import pluggy -from queue import Queue -from threatbus import appspecs, backbonespecs, logger +from multiprocessing import JoinableQueue +from queue import Empty +import signal +from threatbus import appspecs, backbonespecs, logger, stoppable_worker from threatbus.data import MessageType, SnapshotRequest, SnapshotEnvelope from threading import Lock from uuid import uuid4 -class ThreatBus: +class ThreatBus(stoppable_worker.StoppableWorker): def __init__( self, backbones: pluggy.hooks._HookRelay, @@ -18,12 +20,13 @@ def __init__( logger: Logger, config: confuse.Subview, ): + super(ThreatBus, self).__init__() self.backbones = backbones self.apps = apps self.config = config self.logger = logger - self.inq = Queue() # fan-in everything, provisioned by backbone - self.snapshot_q = Queue() + self.inq = JoinableQueue() # fan-in everything, provisioned by backbone + self.snapshot_q = JoinableQueue() self.lock = Lock() self.snapshots = dict() @@ -33,15 +36,16 @@ def handle_snapshots(self): all implementing app plugins. Forwards envelopes to the requesting app or discards them accordingly. """ - while True: - msg = self.snapshot_q.get(block=True) + while self._running(): + try: + msg = self.snapshot_q.get(block=True, timeout=1) + except Empty: + continue if type(msg) is SnapshotRequest: self.logger.debug(f"Received SnapshotRequest: {msg}") self.apps.snapshot(snapshot_request=msg, result_q=self.inq) - elif type(msg) is SnapshotEnvelope: + elif type(msg) is SnapshotEnvelope and msg.snapshot_id in self.snapshots: self.logger.debug(f"Received SnapshotEnvelope: {msg}") - if msg.snapshot_id not in self.snapshots: - continue self.snapshots[msg.snapshot_id].put(msg.body) else: self.logger.warn( @@ -50,7 +54,7 @@ def handle_snapshots(self): self.snapshot_q.task_done() def request_snapshot( - self, topic: str, dst_q: Queue, snapshot_id: str, time_delta: timedelta + self, topic: str, dst_q: JoinableQueue, snapshot_id: str, time_delta: timedelta ): """ Create a new SnapshotRequest and push it to the inq, so that the @@ -81,7 +85,7 @@ def request_snapshot( self.inq.put(req) - def subscribe(self, topic: str, q: Queue, time_delta: timedelta = None): + def subscribe(self, topic: str, q: JoinableQueue, time_delta: timedelta = None): """ Accepts a new subscription for a given topic and queue pointer. Forwards that subscription to all managed backbones. @@ -99,7 +103,7 @@ def subscribe(self, topic: str, q: Queue, time_delta: timedelta = None): self.request_snapshot(topic, q, snapshot_id, time_delta) return snapshot_id - def unsubscribe(self, topic: str, q: Queue): + def unsubscribe(self, topic: str, q: JoinableQueue): """ Removes subscription for a given topic and queue pointer from all managed backbones. 
@@ -107,6 +111,24 @@ def unsubscribe(self, topic: str, q: Queue): assert isinstance(topic, str), "topic must be string" self.backbones.unsubscribe(topic=topic, q=q) + def stop(self): + """ + Stops all running threads and Threat Bus + """ + self.logger.info("Stopping plugins...") + self.backbones.stop() + self.apps.stop() + self.logger.info("Stopping Threat Bus...") + self.join() + + def stop_signal(self, signal, frame): + """ + Implements Python's signal.signal handler. + See https://docs.python.org/3/library/signal.html#signal.signal + Stops all running threads and Threat Bus + """ + self.stop() + def run(self): self.logger.info("Starting plugins...") logging = self.config["logging"] @@ -161,8 +183,10 @@ def start(config: confuse.Subview): ) backbones.unregister(name=unwanted_backbones) - bus = ThreatBus(backbones.hook, apps.hook, tb_logger, config) - bus.run() + bus_thread = ThreatBus(backbones.hook, apps.hook, tb_logger, config) + signal.signal(signal.SIGINT, bus_thread.stop_signal) + bus_thread.start() + return bus_thread def main():
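Since `start()` now returns the running `ThreatBus` thread instead of blocking, embedding callers control the lifecycle themselves, mirroring what the integration tests above do. A minimal sketch, assuming a valid configuration file at a hypothetical path:

```python
import time
import confuse
from threatbus import start

config = confuse.Configuration("threatbus")
config.set_file("config.yaml")  # hypothetical config path

bus = start(config)  # starts plugins and returns the ThreatBus thread;
                     # SIGINT is already wired to a controlled shutdown
time.sleep(60)       # ... run for a while ...
bus.stop()           # stops backbone plugins first, then app plugins, then the bus itself
```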