diff --git a/README.md b/README.md index 18bf361..911aeb8 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ - [Overview](#overview) - [Frontend broker](#frontend-broker) - [POST /api/v1/wg/key/exchange](#post-apiv1wgkeyexchange) + - [POST /api/v2/wg/key/exchange](#post-apiv2wgkeyexchange) - [Backend worker](#backend-worker) - [Installation](#installation) - [Configuration](#configuration) @@ -41,6 +42,7 @@ The frontend broker exposes the following API endpoints for use: ``` /api/v1/wg/key/exchange +/api/v2/wg/key/exchange ``` The listen address and port for the Flask server can be configured in `wgkex.yaml` under the `broker_listen` key: @@ -66,6 +68,35 @@ JSON POST'd to this endpoint should be in this format: The broker will validate the domain and public key, and if valid, will push the key onto the MQTT bus. + +#### POST /api/v2/wg/key/exchange + +JSON POST'd to this endpoint should be in this format: + +```json +{ + "domain": "CONFIGURED_DOMAIN", + "public_key": "PUBLIC_KEY" +} +``` + +The broker will validate the domain and public key, and if valid, will push the key onto the MQTT bus. +Additionally it chooses a worker (aka gateway, endpoint) that the client should connect to. +The response is JSON data containing the connection details for the chosen gateway: + +```json +{ + "Endpoint": { + "Address": "GATEWAY_ADDRESS", + "Port": "GATEWAY_WIREGUARD_PORT", + "AllowedIPs": [ + "GATEWAY_WIREGUARD_INTERFACE_ADDRESS" + ], + "PublicKey": "GATEWAY_PUBLIC_KEY" + } +} +``` + ### Backend worker The backend (worker) waits for new keys to appear on the MQTT message bus. Once a new key appears, the worker performs @@ -129,13 +160,25 @@ Worker: python3 -c 'from wgkex.worker.app import main; main()' ``` -## Client usage + +## Development + +### Unit tests + +The test can be run using `bazel test ... --test_output=all` or `python3 -m unittest discover -p '*_test.py'`. 
+ + ### Client The client can be used via CLI: ``` -$ wget -q -O- --post-data='{"domain": "ffmuc_welt","public_key": "o52Ge+Rpj4CUSitVag9mS7pSXUesNM0ESnvj/wwehkg="}' --header='Content-Type:application/json' 'http://127.0.0.1:5000/api/v1/wg/key/exchange' +$ wget -q -O- --post-data='{"domain": "ffmuc_welt","public_key": "o52Ge+Rpj4CUSitVag9mS7pSXUesNM0ESnvj/wwehkg="}' --header='Content-Type:application/json' 'http://127.0.0.1:5000/api/v2/wg/key/exchange' { + "Endpoint": { + "Address": "gw04.ext.ffmuc.net:40011", + "LinkAddress": "fe80::27c:16ff:fec0:6c74", + "PublicKey": "TszFS3oFRdhsJP3K0VOlklGMGYZy+oFCtlaghXJqW2g=" + }, "Message": "OK" } ``` @@ -146,7 +189,7 @@ Or via python: import requests key_data = {"domain": "ffmuc_welt","public_key": "o52Ge+Rpj4CUSitVag9mS7pSXUesNM0ESnvj/wwehkg="} broker_url = "http://127.0.0.1:5000" -push_key = requests.get(f'{broker_url}/api/v1/wg/key/exchange', json=key_data) +push_key = requests.post(f'{broker_url}/api/v2/wg/key/exchange', json=key_data) print(f'Key push was: {push_key.json().get("Message")}') ``` @@ -173,6 +216,13 @@ sudo ip link set wg-welt up sudo ip link set vx-welt up ``` +### MQTT topics + +- Publishing keys broker->worker: `wireguard/{domain}/{worker}` +- Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers` +- Publishing worker status: `wireguard-worker/{worker}/status` +- Publishing worker data: `wireguard-worker/{worker}/{domain}/data` + ## Contact [Freifunk Munich Mattermost](https://chat.ffmuc.net) diff --git a/requirements.txt b/requirements.txt index 97a41ba..1821412 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ waitress~=2.1.2 ipaddress~=1.0.23 mock~=5.1.0 coverage -paho-mqtt~=1.6.1 \ No newline at end of file +paho-mqtt~=1.6.1 diff --git a/wgkex.yaml.example b/wgkex.yaml.example index 7b82c71..30340fe 100644 --- a/wgkex.yaml.example +++ b/wgkex.yaml.example @@ -1,3 +1,4 @@ +# [broker] The domains that should be accepted by clients and for which 
matching WireGuard interfaces exist domains: - ffmuc_muc_cty - ffmuc_muc_nord @@ -6,6 +7,25 @@ domains: - ffmuc_muc_west - ffmuc_welt - ffwert_city +# [broker, worker] The prefix is trimmed from the domain name and replaced with 'wg-' and 'vx-' +# to calculate the WireGuard and VXLAN interface names +domain_prefixes: + - ffmuc_ + - ffdon_ + - ffwert_ +# [broker] The dict of workers mapping their hostname to their respective weight for weighted peer distribution +workers: + gw04.in.ffmuc.net: + weight: 30 + gw05.in.ffmuc.net: + weight: 30 + gw06.in.ffmuc.net: + weight: 20 + gw07.in.ffmuc.net: + weight: 20 +# [worker] The external hostname of this worker +externalName: gw04.ext.ffmuc.net +# [broker, worker] MQTT connection information mqtt: broker_url: broker.hivemq.com broker_port: 1883 @@ -13,13 +33,11 @@ mqtt: password: SECRET keepalive: 5 tls: False +# [broker] broker_listen: host: 0.0.0.0 port: 5000 -domain_prefixes: - - ffmuc_ - - ffdon_ - - ffwert_ +# [broker, worker] logging_config: formatters: standard: diff --git a/wgkex/broker/BUILD b/wgkex/broker/BUILD index 260fe45..780f340 100644 --- a/wgkex/broker/BUILD +++ b/wgkex/broker/BUILD @@ -1,6 +1,26 @@ load("@rules_python//python:defs.bzl", "py_binary", "py_test") load("@pip//:requirements.bzl", "requirement") +py_library( + name = "metrics", + srcs = ["metrics.py"], + visibility = ["//visibility:public"], + deps = [ + "//wgkex/common:mqtt", + "//wgkex/common:logger", + "//wgkex/config:config", + ], +) + +py_test( + name="metrics_test", + srcs=["metrics_test.py"], + deps = [ + "//wgkex/broker:metrics", + requirement("mock"), + ], +) + py_binary( name="app", srcs=["app.py"], @@ -11,5 +31,7 @@ py_binary( requirement("flask-mqtt"), requirement("waitress"), "//wgkex/config:config", + "//wgkex/common:mqtt", + ":metrics" ], ) diff --git a/wgkex/broker/app.py b/wgkex/broker/app.py index f01ec3f..1d753ff 100644 --- a/wgkex/broker/app.py +++ b/wgkex/broker/app.py @@ -1,22 +1,25 @@ #!/usr/bin/env python3 """wgkex 
broker""" -import re import dataclasses -import logging -from typing import Tuple, Any - -from flask import Flask -from flask import abort -from flask import jsonify -from flask import render_template -from flask import request +import json +import re +from typing import Dict, Tuple, Any + +import paho.mqtt.client as mqtt_client +from flask import Flask, render_template, request, Response from flask.app import Flask as Flask_app from flask_mqtt import Mqtt -import paho.mqtt.client as mqtt_client from waitress import serve from wgkex.config import config from wgkex.common import logger +from wgkex.common.utils import is_valid_domain +from wgkex.broker.metrics import WorkerMetricsCollection +from wgkex.common.mqtt import ( + CONNECTED_PEERS_METRIC, + TOPIC_WORKER_STATUS, + TOPIC_WORKER_WG_DATA, +) WG_PUBKEY_PATTERN = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$") @@ -43,7 +46,9 @@ def from_dict(cls, msg: dict) -> "KeyExchange": A KeyExchange object. """ public_key = is_valid_wg_pubkey(msg.get("public_key")) - domain = is_valid_domain(msg.get("domain")) + domain = str(msg.get("domain")) + if not is_valid_domain(domain): + raise ValueError(f"Domain {domain} not in configured domains.") return cls(public_key=public_key, domain=domain) @@ -54,8 +59,7 @@ def _fetch_app_config() -> Flask_app: A created Flask app. """ app = Flask(__name__) - # TODO(ruairi): Refactor load_config to return Dataclass. 
- mqtt_cfg = config.Config.from_dict(config.load_config()).mqtt + mqtt_cfg = config.get_config().mqtt app.config["MQTT_BROKER_URL"] = mqtt_cfg.broker_url app.config["MQTT_BROKER_PORT"] = mqtt_cfg.broker_port app.config["MQTT_USERNAME"] = mqtt_cfg.username @@ -67,34 +71,86 @@ def _fetch_app_config() -> Flask_app: app = _fetch_app_config() mqtt = Mqtt(app) +worker_metrics = WorkerMetricsCollection() +worker_data: Dict[Tuple[str, str], Dict] = {} @app.route("/", methods=["GET"]) -def index() -> None: +def index() -> str: """Returns main page""" return render_template("index.html") @app.route("/api/v1/wg/key/exchange", methods=["POST"]) -def wg_key_exchange() -> Tuple[str, int]: +def wg_api_v1_key_exchange() -> Tuple[Response | Dict, int]: """Retrieves a new key and validates. - Returns: Status message. """ try: data = KeyExchange.from_dict(request.get_json(force=True)) - except TypeError as ex: - return abort(400, jsonify({"error": {"message": str(ex)}})) + except Exception as ex: + return {"error": {"message": str(ex)}}, 400 + + key = data.public_key + domain = data.domain + # in case we want to decide here later we want to publish it only to dedicated gateways + gateway = "all" + logger.info(f"wg_api_v1_key_exchange: Domain: {domain}, Key:{key}") + + mqtt.publish(f"wireguard/{domain}/{gateway}", key) + return {"Message": "OK"}, 200 + + +@app.route("/api/v2/wg/key/exchange", methods=["POST"]) +def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]: + """Retrieves a new key, validates it and responds with a worker/gateway the client should connect to. + + Returns: + Status message, Endpoint with address/domain, port, public key and link address. 
+ """ + try: + data = KeyExchange.from_dict(request.get_json(force=True)) + except Exception as ex: + return {"error": {"message": str(ex)}}, 400 key = data.public_key domain = data.domain # in case we want to decide here later we want to publish it only to dedicated gateways gateway = "all" - logger.info(f"wg_key_exchange: Domain: {domain}, Key:{key}") + logger.info(f"wg_api_v2_key_exchange: Domain: {domain}, Key:{key}") mqtt.publish(f"wireguard/{domain}/{gateway}", key) - return jsonify({"Message": "OK"}), 200 + + best_worker, diff, current_peers = worker_metrics.get_best_worker(domain) + if best_worker is None: + logger.warning(f"No worker online for domain {domain}") + return { + "error": { + "message": "no gateway online for this domain, please check the domain value and try again later" + } + }, 400 + + worker_metrics.update( + best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1 + ) + logger.debug( + f"Chose worker {best_worker} with {current_peers} connected clients ({diff})" + ) + + w_data = worker_data.get((best_worker, domain), None) + if w_data is None: + logger.error(f"Couldn't get worker endpoint data for {best_worker}/{domain}") + return {"error": {"message": "could not get gateway data"}}, 500 + + endpoint = { + "Address": w_data.get("ExternalAddress"), + "Port": str(w_data.get("Port")), + "AllowedIPs": [w_data.get("LinkAddress")], + "PublicKey": w_data.get("PublicKey"), + } + + return {"Endpoint": endpoint}, 200 @mqtt.on_connect() @@ -108,7 +164,69 @@ def handle_mqtt_connect( app.config["MQTT_BROKER_URL"], app.config["MQTT_BROKER_PORT"] ) ) - # mqtt.subscribe("wireguard/#") + mqtt.subscribe("wireguard-metrics/#") + mqtt.subscribe(TOPIC_WORKER_STATUS.format(worker="+")) + mqtt.subscribe(TOPIC_WORKER_WG_DATA.format(worker="+", domain="+")) + + +@mqtt.on_topic("wireguard-metrics/#") +def handle_mqtt_message_metrics( + client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage +) -> None: + """Processes published metrics 
from workers.""" + logger.debug( + f"MQTT message received on {message.topic}: {message.payload.decode()}" + ) + _, domain, worker, metric = message.topic.split("/", 3) + if not is_valid_domain(domain): + logger.error(f"Domain {domain} not in configured domains") + return + + if not worker or not metric: + logger.error("Ignored MQTT message with empty worker or metrics label") + return + + data = int(message.payload) + + logger.info(f"Update worker metrics: {metric} on {worker}/{domain} = {data}") + worker_metrics.update(worker, domain, metric, data) + + +@mqtt.on_topic(TOPIC_WORKER_STATUS.format(worker="+")) +def handle_mqtt_message_status( + client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage +) -> None: + """Processes status messages from workers.""" + _, worker, _ = message.topic.split("/", 2) + + status = int(message.payload) + if status < 1: + logger.warning(f"Marking worker as offline: {worker}") + worker_metrics.set_offline(worker) + else: + logger.warning(f"Marking worker as online: {worker}") + worker_metrics.set_online(worker) + + +@mqtt.on_topic(TOPIC_WORKER_WG_DATA.format(worker="+", domain="+")) +def handle_mqtt_message_data( + client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage +) -> None: + """Processes data messages from workers. 
+ + Stores them in a local dict""" + _, worker, domain, _ = message.topic.split("/", 3) + if not is_valid_domain(domain): + logger.error(f"Domain {domain} not in configured domains.") + return + + data = json.loads(message.payload) + if not isinstance(data, dict): + logger.error("Invalid worker data received for %s/%s: %s", worker, domain, data) + return + + logger.info("Worker data received for %s/%s: %s", worker, domain, data) + worker_data[(worker, domain)] = data @mqtt.on_message() @@ -116,7 +234,6 @@ def handle_mqtt_message( client: mqtt_client.Client, userdata: bytes, message: mqtt_client.MQTTMessage ) -> None: """Prints message contents.""" - # TODO(ruairi): Clarify current usage of this function. logger.debug( f"MQTT message received on {message.topic}: {message.payload.decode()}" ) @@ -140,33 +257,26 @@ def is_valid_wg_pubkey(pubkey: str) -> str: return pubkey -def is_valid_domain(domain: str) -> str: - """Verifies if the domain is configured. - - Arguments: - domain: The domain to verify. - - Raises: - ValueError: If the domain is not configured. +def join_host_port(host: str, port: str) -> str: + """Concatenate a port string with a host string using a colon. + The host may be either a hostname, IPv4 or IPv6 address. + An IPv6 address as host will be automatically encapsulated in square brackets. Returns: - The domain. + The joined host:port string """ - # TODO(ruairi): Refactor to return bool. 
- if domain not in config.fetch_from_config("domains"): - raise ValueError( - f'Domains {domain} not in configured domains({config.fetch_from_config("domains")}) a valid domain' - ) - return domain + if host.find(":") >= 0: + return "[" + host + "]:" + port + return host + ":" + port if __name__ == "__main__": listen_host = None listen_port = None - listen_config = config.fetch_from_config("broker_listen") + listen_config = config.get_config().broker_listen if listen_config is not None: - listen_host = listen_config.get("host") - listen_port = listen_config.get("port") + listen_host = listen_config.host + listen_port = listen_config.port serve(app, host=listen_host, port=listen_port) diff --git a/wgkex/broker/metrics.py b/wgkex/broker/metrics.py new file mode 100644 index 0000000..a2e2893 --- /dev/null +++ b/wgkex/broker/metrics.py @@ -0,0 +1,120 @@ +import dataclasses +from operator import itemgetter +from typing import Any, Dict, Optional, Tuple + +from wgkex.config import config +from wgkex.common import logger +from wgkex.common.mqtt import CONNECTED_PEERS_METRIC + + +@dataclasses.dataclass +class WorkerMetrics: + """Metrics of a single worker""" + + worker: str + # domain -> [metric name -> metric data] + domain_data: Dict[str, Dict[str, Any]] = dataclasses.field(default_factory=dict) + online: bool = False + + def is_online(self, domain: str = "") -> bool: + if domain: + return ( + self.online + and self.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC, -1) >= 0 + ) + else: + return self.online + + def get_domain_metrics(self, domain: str) -> Dict[str, Any]: + return self.domain_data.get(domain, {}) + + def set_metric(self, domain: str, metric: str, value: Any) -> None: + if domain in self.domain_data: + self.domain_data[domain][metric] = value + else: + self.domain_data[domain] = {metric: value} + + +@dataclasses.dataclass +class WorkerMetricsCollection: + """A container for all worker metrics""" + + # worker -> WorkerMetrics + data: Dict[str, 
WorkerMetrics] = dataclasses.field(default_factory=dict) + + def get(self, worker: str) -> WorkerMetrics: + return self.data.get(worker, WorkerMetrics(worker=worker)) + + def set(self, worker: str, metrics: WorkerMetrics) -> None: + self.data[worker] = metrics + + def update(self, worker: str, domain: str, metric: str, value: Any) -> None: + if worker in self.data: + self.data[worker].set_metric(domain, metric, value) + else: + metrics = WorkerMetrics(worker) + metrics.set_metric(domain, metric, value) + self.data[worker] = metrics + + def set_online(self, worker: str) -> None: + if worker in self.data: + self.data[worker].online = True + else: + metrics = WorkerMetrics(worker) + metrics.online = True + self.data[worker] = metrics + + def set_offline(self, worker: str) -> None: + if worker in self.data: + self.data[worker].online = False + + def get_total_peers(self) -> int: + total = 0 + for worker in self.data: + worker_data = self.data.get(worker) + if not worker_data: + continue + for domain in worker_data.domain_data: + total += max( + worker_data.get_domain_metrics(domain).get( + CONNECTED_PEERS_METRIC, 0 + ), + 0, + ) + + return total + + def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]: + """Analyzes the metrics and determines the best worker that a new client should connect to. + The best worker is defined as the one with the most number of clients missing + to its should-be target value according to its weight. + + Returns: + A 3-tuple containing the worker name, difference to target peers, number of connected peers. + The worker name can be None if none is online. 
+ """ + # Map metrics to a list of (target diff, peer count, worker) tuples for online workers + + peers_worker_tuples = [] + total_peers = self.get_total_peers() + worker_cfg = config.get_config().workers + + for wm in self.data.values(): + if not wm.is_online(domain): + continue + + peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC) + rel_weight = worker_cfg.relative_worker_weight(wm.worker) + target = rel_weight * total_peers + diff = peers - target + logger.debug( + f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}" + ) + peers_worker_tuples.append((diff, peers, wm.worker)) + + peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0)) + + if len(peers_worker_tuples) > 0: + best = peers_worker_tuples[0] + return best[2], best[0], best[1] + return None, 0, 0 diff --git a/wgkex/broker/metrics_test.py b/wgkex/broker/metrics_test.py new file mode 100644 index 0000000..520e6a9 --- /dev/null +++ b/wgkex/broker/metrics_test.py @@ -0,0 +1,125 @@ +import unittest + +import mock +from wgkex.config import config +from wgkex.broker.metrics import WorkerMetricsCollection + + +class TestMetrics(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + # Give each test a placeholder config + test_config = config.Config.from_dict( + { + "domains": [], + "domain_prefixes": "", + "workers": {}, + "mqtt": {"broker_url": "", "username": "", "password": ""}, + } + ) + mocked_config = mock.create_autospec(spec=test_config, spec_set=True) + config._parsed_config = mocked_config + + @classmethod + def tearDownClass(cls) -> None: + config._parsed_config = None + + def test_set_online_matches_is_online(self): + """Verify set_online sets worker online and matches result of is_online.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_online("worker1") + + ret = worker_metrics.get("worker1").is_online() + self.assertTrue(ret) + + def test_set_offline_matches_is_online(self): + """Verify 
set_offline sets worker offline and matches negated result of is_online.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_offline("worker1") + + ret = worker_metrics.get("worker1").is_online() + self.assertFalse(ret) + + def test_unkown_is_offline(self): + """Verify an unknown worker is considered offline.""" + worker_metrics = WorkerMetricsCollection() + + ret = worker_metrics.get("worker1").is_online() + self.assertFalse(ret) + + def test_set_online_matches_is_online_domain(self): + """Verify set_online sets worker online and matches result of is_online with domain.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_online("worker1") + worker_metrics.update("worker1", "d", "connected_peers", 5) + + ret = worker_metrics.get("worker1").is_online("d") + self.assertTrue(ret) + + def test_set_online_matches_is_online_offline_domain(self): + """Verify worker is considered offline if connected_peers for domain is <0.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.set_online("worker1") + worker_metrics.update("worker1", "d", "connected_peers", -1) + + ret = worker_metrics.get("worker1").is_online("d") + self.assertFalse(ret) + + @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True) + def test_get_best_worker_returns_best(self, config_mock): + """Verify get_best_worker returns the worker with least connected clients for equally weighted workers.""" + test_config = mock.MagicMock(spec=config.Config) + test_config.workers = config.Workers.from_dict({}) + config_mock.return_value = test_config + + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "d", "connected_peers", 20) + worker_metrics.update("2", "d", "connected_peers", 19) + worker_metrics.set_online("1") + worker_metrics.set_online("2") + + (worker, diff, connected) = worker_metrics.get_best_worker("d") + self.assertEqual(worker, "2") + self.assertEqual(diff, -20) # 19-(1*(20+19)) + self.assertEqual(connected, 19) + + 
@mock.patch("wgkex.broker.metrics.config.get_config", autospec=True) + def test_get_best_worker_weighted_returns_best(self, config_mock): + """Verify get_best_worker returns the worker with least client differential for weighted workers.""" + test_config = mock.MagicMock(spec=config.Config) + test_config.workers = config.Workers.from_dict( + {"1": {"weight": 84}, "2": {"weight": 42}} + ) + config_mock.return_value = test_config + + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "d", "connected_peers", 21) + worker_metrics.update("2", "d", "connected_peers", 19) + worker_metrics.set_online("1") + worker_metrics.set_online("2") + + (worker, _, _) = worker_metrics.get_best_worker("d") + config_mock.assert_called() + self.assertEqual(worker, "1") + + def test_get_best_worker_no_worker_online_returns_none(self): + """Verify get_best_worker returns None if there is no online worker.""" + worker_metrics = WorkerMetricsCollection() + worker_metrics.update("1", "d", "connected_peers", 20) + worker_metrics.update("2", "d", "connected_peers", 19) + worker_metrics.set_offline("1") + worker_metrics.set_offline("2") + + (worker, _, _) = worker_metrics.get_best_worker("d") + self.assertIsNone(worker) + + def test_get_best_worker_no_worker_registered_returns_none(self): + """Verify get_best_worker returns None if no worker is registered.""" + worker_metrics = WorkerMetricsCollection() + + (worker, _, _) = worker_metrics.get_best_worker("d") + self.assertIsNone(worker) + + +if __name__ == "__main__": + unittest.main() diff --git a/wgkex/common/BUILD b/wgkex/common/BUILD index 4a12559..b203348 100644 --- a/wgkex/common/BUILD +++ b/wgkex/common/BUILD @@ -15,7 +15,8 @@ py_test( name = "utils_test", srcs = ["utils_test.py"], deps = [ - ":utils", + "//wgkex/common:utils", + "//wgkex/config:config", requirement("mock"), ], ) @@ -24,4 +25,10 @@ py_library( name = "logger", srcs = ["logger.py"], visibility = ["//visibility:public"] -) \ No newline at end of file 
+) + +py_library( + name = "mqtt", + srcs = ["mqtt.py"], + visibility = ["//visibility:public"] +) diff --git a/wgkex/common/mqtt.py b/wgkex/common/mqtt.py new file mode 100644 index 0000000..69bf15b --- /dev/null +++ b/wgkex/common/mqtt.py @@ -0,0 +1,6 @@ +"""Common MQTT constants like topic string templates""" + +TOPIC_WORKER_WG_DATA = "wireguard-worker/{worker}/{domain}/data" +TOPIC_WORKER_STATUS = "wireguard-worker/{worker}/status" +CONNECTED_PEERS_METRIC = "connected_peers" +TOPIC_CONNECTED_PEERS = "wireguard-metrics/{domain}/{worker}/" + CONNECTED_PEERS_METRIC diff --git a/wgkex/common/utils.py b/wgkex/common/utils.py index fecebef..8fa201c 100644 --- a/wgkex/common/utils.py +++ b/wgkex/common/utils.py @@ -2,6 +2,8 @@ import ipaddress import re +from wgkex.config import config + def mac2eui64(mac: str, prefix=None) -> str: """Converts a MAC address to an EUI64 identifier. @@ -35,3 +37,20 @@ def mac2eui64(mac: str, prefix=None) -> str: net = ipaddress.ip_network(prefix, strict=False) euil = int(f"0x{eui64:16}", 16) return f"{net[euil]}/{net.prefixlen}" + + +def is_valid_domain(domain: str) -> bool: + """Verifies if the domain is configured. + + Arguments: + domain: The domain to verify. + + Returns: + True if the domain is valid, False otherwise. 
+ """ + if not domain in config.get_config().domains: + return False + for prefix in config.get_config().domain_prefixes: + if domain.startswith(prefix): + return True + return False diff --git a/wgkex/common/utils_test.py b/wgkex/common/utils_test.py index e14b174..a0aa187 100644 --- a/wgkex/common/utils_test.py +++ b/wgkex/common/utils_test.py @@ -1,5 +1,5 @@ import unittest -import utils +from wgkex.common import utils class UtilsTest(unittest.TestCase): diff --git a/wgkex/config/BUILD b/wgkex/config/BUILD index 1ca5fb3..8167c58 100644 --- a/wgkex/config/BUILD +++ b/wgkex/config/BUILD @@ -16,7 +16,7 @@ py_test( name="config_test", srcs=["config_test.py"], deps=[ - ":config", + "//wgkex/config:config", requirement("mock"), ], ) diff --git a/wgkex/config/__init__.py b/wgkex/config/__init__.py index 1b48be8..9c9cace 100644 --- a/wgkex/config/__init__.py +++ b/wgkex/config/__init__.py @@ -1,3 +1,3 @@ -from wgkex.config.config import load_config +from wgkex.config.config import get_config -__all__ = ["load_config"] +__all__ = ["get_config"] diff --git a/wgkex/config/config.py b/wgkex/config/config.py index 0659d69..b0239e2 100644 --- a/wgkex/config/config.py +++ b/wgkex/config/config.py @@ -1,11 +1,11 @@ """Configuration handling class.""" +import dataclasses import logging import os import sys +from typing import Dict, Any, List, Optional + import yaml -from functools import lru_cache -from typing import Dict, Union, Any, List, Optional -import dataclasses class Error(Exception): @@ -20,9 +20,78 @@ class ConfigFileNotFoundError(Error): WG_CONFIG_DEFAULT_LOCATION = "/etc/wgkex.yaml" +@dataclasses.dataclass +class Worker: + """A representation of the values of the 'workers' dict in the configuration file. + + Attributes: + weight: The relative weight of a worker, defaults to 1. 
+ """ + + weight: int + + @classmethod + def from_dict(cls, worker_cfg: Dict[str, Any]) -> "Worker": + return cls( + weight=int(worker_cfg["weight"]) if worker_cfg["weight"] else 1, + ) + + +@dataclasses.dataclass +class Workers: + """A representation of the 'workers' key in the configuration file. + + Attributes: + total_weight: Calculated on init, the total weight of all configured workers. + """ + + total_weight: int + _workers: Dict[str, Worker] + + @classmethod + def from_dict(cls, workers_cfg: Dict[str, Dict[str, Any]]) -> "Workers": + d = {key: Worker.from_dict(value) for (key, value) in workers_cfg.items()} + + total = 0 + for worker in d.values(): + total += worker.weight + total = max(total, 1) + + return cls(total_weight=total, _workers=d) + + def get(self, worker: str) -> Optional[Worker]: + return self._workers.get(worker) + + def relative_worker_weight(self, worker_name: str) -> float: + worker = self.get(worker_name) + if worker is None: + return 1 / self.total_weight + return worker.weight / self.total_weight + + +@dataclasses.dataclass +class BrokerListen: + """A representation of the 'broker_listen' key in Configuration file. + + Attributes: + host: The listen address the broker should listen to for the HTTP API. + port: The port the broker should listen to for the HTTP API. + """ + + host: Optional[str] + port: Optional[int] + + @classmethod + def from_dict(cls, broker_listen: Dict[str, Any]) -> "BrokerListen": + return cls( + host=broker_listen.get("host"), + port=broker_listen.get("port"), + ) + + @dataclasses.dataclass class MQTT: - """A representation of MQTT key in Configuration file. + """A representation of the 'mqtt' key in Configuration file. Attributes: broker_url: The broker URL for MQTT to connect to. 
@@ -54,11 +123,9 @@ def from_dict(cls, mqtt_cfg: Dict[str, str]) -> "MQTT": broker_url=mqtt_cfg["broker_url"], username=mqtt_cfg["username"], password=mqtt_cfg["password"], - tls=mqtt_cfg["tls"] if mqtt_cfg["tls"] else False, - broker_port=int(mqtt_cfg["broker_port"]) - if mqtt_cfg["broker_port"] - else None, - keepalive=int(mqtt_cfg["keepalive"]) if mqtt_cfg["keepalive"] else None, + tls=bool(mqtt_cfg.get("tls", cls.tls)), + broker_port=int(mqtt_cfg.get("broker_port", cls.broker_port)), + keepalive=int(mqtt_cfg.get("keepalive", cls.keepalive)), ) @@ -68,59 +135,72 @@ class Config: Attributes: domains: The list of domains to listen for. + domain_prefixes: The list of prefixes to pre-pend to a given domain. mqtt: The MQTT configuration. - domain_prefixes: The list of prefixes to pre-pend to a given domain.""" + workers: The worker weights configuration (broker-only). + externalName: The publicly resolvable domain name or public IP address of this worker (worker-only). + """ + raw: Dict[str, Any] domains: List[str] - mqtt: MQTT domain_prefixes: List[str] + broker_listen: BrokerListen + mqtt: MQTT + workers: Workers + external_name: Optional[str] @classmethod - def from_dict(cls, cfg: Dict[str, str]) -> "Config": + def from_dict(cls, cfg: Dict[str, Any]) -> "Config": """Creates a Config object from a configuration file. Arguments: cfg: The configuration file as a dict. Returns: A Config object. 
""" + broker_listen = BrokerListen.from_dict(cfg.get("broker_listen", {})) mqtt_cfg = MQTT.from_dict(cfg["mqtt"]) + workers_cfg = Workers.from_dict(cfg.get("workers", {})) return cls( + raw=cfg, domains=cfg["domains"], - mqtt=mqtt_cfg, domain_prefixes=cfg["domain_prefixes"], + broker_listen=broker_listen, + mqtt=mqtt_cfg, + workers=workers_cfg, + external_name=cfg.get("externalName"), ) + def get(self, key: str) -> Any: + """Get the value of key from the raw dict representation of the config file""" + return self.raw.get(key) -@lru_cache(maxsize=10) -def fetch_from_config(key: str) -> Optional[Union[Dict[str, Any], List[str]]]: - """Fetches a specific key from configuration. - Arguments: - key: The named key to fetch. - Returns: - The config value associated with the key - """ - return load_config().get(key) +_parsed_config: Optional[Config] = None -def load_config() -> Dict[str, str]: - """Fetches and validates configuration file from disk. +def get_config() -> Config: + """Returns a parsed Config object. + Raises: + ConfigFileNotFoundError: If we could not find the configuration file on disk. Returns: - Linted configuration file. 
+ The Config representation of the config file """ - cfg_contents = fetch_config_from_disk() - try: - config = yaml.safe_load(cfg_contents) - except yaml.YAMLError as e: - print("Failed to load YAML file: %s", e) - sys.exit(1) - try: - _ = Config.from_dict(config) - return config - except (KeyError, TypeError) as e: - print("Failed to lint file: %s", e) - sys.exit(2) + global _parsed_config + if _parsed_config is None: + cfg_contents = fetch_config_from_disk() + try: + config = yaml.safe_load(cfg_contents) + except yaml.YAMLError as e: + print("Failed to load YAML file: %s" % e) + sys.exit(1) + try: + config = Config.from_dict(config) + except (KeyError, TypeError, AttributeError) as e: + print("Failed to lint file: %s" % e) + sys.exit(2) + _parsed_config = config + return _parsed_config def fetch_config_from_disk() -> str: diff --git a/wgkex/config/config_test.py b/wgkex/config/config_test.py index 3c33148..6e30eb3 100644 --- a/wgkex/config/config_test.py +++ b/wgkex/config/config_test.py @@ -1,9 +1,10 @@ """Tests for configuration handling class.""" import unittest import mock -import config import yaml +from wgkex.config import config + _VALID_CFG = ( "domain_prefixes:\n- ffmuc_\n- ffdon_\n- ffwert_\nlog_level: DEBUG\ndomains:\n- a\n- b\nmqtt:\n broker_port: 1883" "\n broker_url: mqtt://broker\n keepalive: 5\n password: pass\n tls: true\n username: user\n" @@ -16,18 +17,22 @@ class TestConfig(unittest.TestCase): + def tearDown(self) -> None: + config._parsed_config = None + return super().tearDown() + def test_load_config_success(self): """Test loads and lint config successfully.""" mock_open = mock.mock_open(read_data=_VALID_CFG) with mock.patch("builtins.open", mock_open): - self.assertDictEqual(yaml.safe_load(_VALID_CFG), config.load_config()) + self.assertDictEqual(yaml.safe_load(_VALID_CFG), config.get_config().raw) @mock.patch.object(config.sys, "exit", autospec=True) def test_load_config_fails_good_yaml_bad_format(self, exit_mock): """Test loads yaml 
successfully and fails lint.""" mock_open = mock.mock_open(read_data=_INVALID_LINT) with mock.patch("builtins.open", mock_open): - config.load_config() + config.get_config() exit_mock.assert_called_with(2) @mock.patch.object(config.sys, "exit", autospec=True) @@ -35,7 +40,7 @@ def test_load_config_fails_bad_yaml(self, exit_mock): """Test loads bad YAML.""" mock_open = mock.mock_open(read_data=_INVALID_CFG) with mock.patch("builtins.open", mock_open): - config.load_config() + config.get_config() exit_mock.assert_called_with(2) def test_fetch_config_from_disk_success(self): @@ -52,17 +57,17 @@ def test_fetch_config_from_disk_fails_file_not_found(self): with self.assertRaises(config.ConfigFileNotFoundError): config.fetch_config_from_disk() - def test_fetch_from_config_success(self): + def test_raw_get_success(self): """Test fetch key from configuration.""" mock_open = mock.mock_open(read_data=_VALID_CFG) with mock.patch("builtins.open", mock_open): - self.assertListEqual(["a", "b"], config.fetch_from_config("domains")) + self.assertListEqual(["a", "b"], config.get_config().raw.get("domains")) - def test_fetch_from_config_no_key_in_config(self): + def test_raw_get_no_key_in_config(self): """Test fetch non-existent key from configuration.""" mock_open = mock.mock_open(read_data=_VALID_CFG) with mock.patch("builtins.open", mock_open): - self.assertIsNone(config.fetch_from_config("key_does_not_exist")) + self.assertIsNone(config.get_config().raw.get("key_does_not_exist")) if __name__ == "__main__": diff --git a/wgkex/worker/BUILD b/wgkex/worker/BUILD index 80a82eb..b1d9b6d 100644 --- a/wgkex/worker/BUILD +++ b/wgkex/worker/BUILD @@ -21,8 +21,9 @@ py_test( name = "netlink_test", srcs = ["netlink_test.py"], deps = [ - ":netlink", + "//wgkex/worker:netlink", requirement("mock"), + requirement("pyroute2"), ], ) @@ -34,8 +35,9 @@ py_library( requirement("NetLink"), requirement("paho-mqtt"), requirement("pyroute2"), - "//wgkex/common:utils", "//wgkex/common:logger", + 
"//wgkex/common:mqtt", + "//wgkex/common:utils", "//wgkex/config:config", ":msg_queue", ":netlink", @@ -46,8 +48,8 @@ py_test( name = "mqtt_test", srcs = ["mqtt_test.py"], deps = [ - ":mqtt", - ":msg_queue", + "//wgkex/worker:mqtt", + "//wgkex/worker:msg_queue", requirement("mock"), ], ) @@ -67,8 +69,8 @@ py_test( name = "app_test", srcs = ["app_test.py"], deps = [ - ":app", - ":msg_queue", + "//wgkex/worker:app", + "//wgkex/worker:msg_queue", requirement("mock"), ], ) @@ -80,4 +82,4 @@ py_library( deps = [ "//wgkex/common:logger", ], -) \ No newline at end of file +) diff --git a/wgkex/worker/app.py b/wgkex/worker/app.py index 70aa8fc..432955c 100644 --- a/wgkex/worker/app.py +++ b/wgkex/worker/app.py @@ -1,13 +1,17 @@ """Initialises the MQTT worker.""" -import wgkex.config.config as config +import signal +import sys +import threading +import time +from typing import Text + +from wgkex.common import logger +from wgkex.common.utils import is_valid_domain +from wgkex.config import config from wgkex.worker import mqtt from wgkex.worker.msg_queue import watch_queue from wgkex.worker.netlink import wg_flush_stale_peers -import time -import threading -from wgkex.common import logger -from typing import List, Text _CLEANUP_TIME = 3600 @@ -28,22 +32,32 @@ class DomainsAreNotUnique(Error): """If non-unique domains exist in configuration file.""" +class InvalidDomain(Error): + """If the domains is invalid and is not listed in the configuration file.""" + + def flush_workers(domain: Text) -> None: """Calls peer flush every _CLEANUP_TIME interval.""" while True: - time.sleep(_CLEANUP_TIME) - logger.info(f"Running cleanup task for {domain}") - logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) + try: + time.sleep(_CLEANUP_TIME) + logger.info(f"Running cleanup task for {domain}") + logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) + except Exception as e: + # Don't crash the thread when an exception is encountered + logger.error(f"Exception 
during cleanup task for {domain}:") + logger.error(e) -def clean_up_worker(domains: List[Text]) -> None: +def clean_up_worker() -> None: """Wraps flush_workers in a thread for all given domains. Arguments: domains: list of domains. """ + domains = config.get_config().domains + prefixes = config.get_config().domain_prefixes logger.debug("Cleaning up the following domains: %s", domains) - prefixes = config.load_config().get("domain_prefixes") cleanup_counter = 0 # ToDo: do we need a check if every domain got gleaned? for prefix in prefixes: @@ -60,7 +74,9 @@ def clean_up_worker(domains: List[Text]) -> None: domain, ) continue - thread = threading.Thread(target=flush_workers, args=(cleaned_domain,)) + thread = threading.Thread( + target=flush_workers, args=(cleaned_domain,), daemon=True + ) thread.start() if cleanup_counter < len(domains): logger.error( @@ -89,8 +105,7 @@ def check_all_domains_unique(domains, prefixes): stripped_domain = domain.split(prefix)[1] if stripped_domain in unique_domains: logger.error( - "We have a non-unique domain here", - domain, + f"Domain {domain} is not unique after stripping the prefix" ) return False unique_domains.append(stripped_domain) @@ -104,15 +119,28 @@ def main(): DomainsNotInConfig: If no domains were found in configuration file. DomainsAreNotUnique: If there were non-unique domains after stripping prefix """ - domains = config.load_config().get("domains") - prefixes = config.load_config().get("domain_prefixes") + exit_event = threading.Event() + + def on_exit(sig_number, stack_frame) -> None: + logger.info("Shutting down...") + exit_event.set() + time.sleep(2) + sys.exit() + + signal.signal(signal.SIGINT, on_exit) + + domains = config.get_config().domains + prefixes = config.get_config().domain_prefixes if not domains: raise DomainsNotInConfig("Could not locate domains in configuration.") if not check_all_domains_unique(domains, prefixes): raise DomainsAreNotUnique("There are non-unique domains! 
Check config.") - clean_up_worker(domains) + for domain in domains: + if not is_valid_domain(domain): + raise InvalidDomain(f"Domain {domain} has invalid prefix.") + clean_up_worker() watch_queue() - mqtt.connect() + mqtt.connect(exit_event) if __name__ == "__main__": diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index 111590b..04cc6fb 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -1,7 +1,20 @@ """Unit tests for app.py""" +import threading +from time import sleep import unittest import mock -import app + +from wgkex.worker import app + + +def _get_config_mock(domains=None): + test_prefixes = ["_TEST_PREFIX_", "_TEST_PREFIX2_"] + config_mock = mock.MagicMock() + config_mock.domains = ( + domains if domains is not None else [f"{test_prefixes[1]}domain.one"] + ) + config_mock.domain_prefixes = test_prefixes + return config_mock class AppTest(unittest.TestCase): @@ -48,49 +61,57 @@ def test_unique_domains_not_list(self): with self.assertRaises(TypeError): app.check_all_domains_unique(test_domains, test_prefixes) - @mock.patch.object(app.config, "load_config") + @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) def test_main_success(self, connect_mock, config_mock): """Ensure we can execute main.""" connect_mock.return_value = None - test_prefixes = ["TEST_PREFIX_", "TEST_PREFIX2_"] - config_mock.return_value = dict( - domains=[f"{test_prefixes[1]}domain.one"], domain_prefixes=test_prefixes - ) - with mock.patch("app.flush_workers", return_value=None): + config_mock.return_value = _get_config_mock() + with mock.patch.object(app, "flush_workers", return_value=None): app.main() - connect_mock.assert_called_with() + connect_mock.assert_called() - @mock.patch.object(app.config, "load_config") + @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) def test_main_fails_no_domain(self, connect_mock, config_mock): """Ensure we fail when 
domains are not configured.""" - config_mock.return_value = dict(domains=None) + config_mock.return_value = _get_config_mock(domains=[]) connect_mock.return_value = None with self.assertRaises(app.DomainsNotInConfig): app.main() - @mock.patch.object(app.config, "load_config") + @mock.patch.object(app.config, "get_config") @mock.patch.object(app.mqtt, "connect", autospec=True) def test_main_fails_bad_domain(self, connect_mock, config_mock): """Ensure we fail when domains are badly formatted.""" - test_prefixes = ["TEST_PREFIX_", "TEST_PREFIX2_"] - config_mock.return_value = dict( - domains=[f"cant_split_domain"], domain_prefixes=test_prefixes - ) + config_mock.return_value = _get_config_mock(domains=["cant_split_domain"]) connect_mock.return_value = None - with mock.patch("app.flush_workers", return_value=None): + with self.assertRaises(app.InvalidDomain): app.main() - connect_mock.assert_called_with() + connect_mock.assert_not_called() - @mock.patch("time.sleep", side_effect=InterruptedError) - @mock.patch("app.wg_flush_stale_peers") - def test_flush_workers(self, flush_mock, sleep_mock): - """Ensure we fail when domains are badly formatted.""" - flush_mock.return_value = "" - # Infinite loop in flush_workers has no exit value, so test will generate one, and test for that. 
- with self.assertRaises(InterruptedError): - app.flush_workers("test_domain") + @mock.patch.object(app, "_CLEANUP_TIME", 0) + @mock.patch.object(app, "wg_flush_stale_peers") + def test_flush_workers_doesnt_throw(self, wg_flush_mock): + """Ensure the flush_workers thread doesn't throw and exit if it encounters an exception.""" + wg_flush_mock.side_effect = AttributeError( + "'NoneType' object has no attribute 'get'" + ) + + thread = threading.Thread( + target=app.flush_workers, args=("dummy_domain",), daemon=True + ) + thread.start() + + i = 0 + while i < 20 and not wg_flush_mock.called: + i += 1 + sleep(0.1) + + wg_flush_mock.assert_called() + # Assert that the thread hasn't crashed and is still running + self.assertTrue(thread.is_alive()) + # If Python would allow it without writing custom signalling, this would be the place to stop the thread again if __name__ == "__main__": diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index 1c1cf31..caf7011 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -1,58 +1,106 @@ #!/usr/bin/env python3 """Process messages from MQTT.""" -import paho.mqtt.client as mqtt - # TODO(ruairi): Deprecate __init__.py from config, as it masks namespace. -from wgkex.config.config import load_config -import socket +import json import re -from typing import Optional, Dict, Any, Union -from wgkex.common import logger -from wgkex.worker.msg_queue import q - - -def fetch_from_config(var: str) -> Optional[Union[Dict[str, str], str]]: - """Fetches values from configuration file. +import socket +import threading +from typing import Any, Optional - Arguments: - var: The variable to fetch from config. +import paho.mqtt.client as mqtt - Raises: - ValueError: If given key cannot be found in configuration. 
+from wgkex.common import logger +from wgkex.common.mqtt import ( + TOPIC_CONNECTED_PEERS, + TOPIC_WORKER_STATUS, + TOPIC_WORKER_WG_DATA, +) +from wgkex.config.config import get_config +from wgkex.worker.msg_queue import q +from wgkex.worker.netlink import ( + get_device_data, + link_handler, + get_connected_peers_count, + WireGuardClient, +) - Returns: - The given variable from configuration. - """ - config = load_config() - ret = config.get(var) - if not ret: - raise ValueError("Failed to get %s from configuration, failing", var) - return config.get(var) +_HOSTNAME = socket.gethostname() +_METRICS_SEND_INTERVAL = 60 -def connect() -> None: - """Connect to MQTT for the given domains. +def connect(exit_event: threading.Event) -> None: + """Connect to MQTT. Argument: - domains: The domains to connect to. + exit_event: A threading.Event that signals application shutdown. """ - base_config = fetch_from_config("mqtt") - broker_address = base_config.get("broker_url") - broker_port = base_config.get("broker_port") - broker_keepalive = base_config.get("keepalive") - # TODO(ruairi): Move the hostname to a global variable. 
- client = mqtt.Client(socket.gethostname()) + base_config = get_config().mqtt + broker_address = base_config.broker_url + broker_port = base_config.broker_port + broker_keepalive = base_config.keepalive + client = mqtt.Client(_HOSTNAME) + domains = get_config().domains + + # Register LWT to set worker status down when lossing connection + client.will_set(TOPIC_WORKER_STATUS.format(worker=_HOSTNAME), 0, qos=1, retain=True) # Register handlers client.on_connect = on_connect + client.on_disconnect = on_disconnect client.on_message = on_message + client.message_callback_add("wireguard/#", on_message_wireguard) logger.info("connecting to broker %s", broker_address) client.connect(broker_address, port=broker_port, keepalive=broker_keepalive) + + # Start background threads that should not be restarted on reconnect + + # Mark worker as offline on graceful shutdown, after exit_event is set + def mark_offline_on_exit(exit_event: threading.Event): + exit_event.wait() + if client.is_connected(): + logger.info("Marking worker as down") + client.publish( + TOPIC_WORKER_STATUS.format(worker=_HOSTNAME), 0, qos=1, retain=True + ) + + mark_offline_on_exit_thread = threading.Thread( + target=mark_offline_on_exit, args=(exit_event,) + ) + mark_offline_on_exit_thread.start() + + for domain in domains: + # Schedule repeated metrics publishing + metrics_thread = threading.Thread( + target=publish_metrics_loop, args=(exit_event, client, domain) + ) + metrics_thread.start() + client.loop_forever() +def on_disconnect(client: mqtt.Client, userdata: Any, rc): + """Handles MQTT disconnect and logs the event + + Expected signature for MQTT v3.1.1 and v3.1 is: + disconnect_callback(client, userdata, rc) + + and for MQTT v5.0: + disconnect_callback(client, userdata, reasonCode, properties) + + Arguments: + client: the client instance for this callback + userdata: the private user data as set in Client() or userdata_set() + rc: the disconnection result + The rc parameter indicates the 
disconnection state. If + MQTT_ERR_SUCCESS (0), the callback was called in response to + a disconnect() call. If any other value the disconnection + was unexpected, such as might be caused by a network error. + """ + logger.debug("Disconnected with result code " + str(rc)) + + # The callback for when the client receives a CONNACK response from the server. def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: """Handles MQTT connect and subscribes to topics on connect @@ -64,17 +112,61 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None: rc: The MQTT rc. """ logger.debug("Connected with result code " + str(rc)) - domains = load_config().get("domains") + domains = get_config().domains + + own_external_name = ( + get_config().external_name + if get_config().external_name is not None + else _HOSTNAME + ) - # Subscribing in on_connect() means that if we lose the connection and - # reconnect then subscriptions will be renewed. for domain in domains: + # Subscribing in on_connect() means that if we lose the connection and + # reconnect then subscriptions will be renewed. topic = f"wireguard/{domain}/+" logger.info(f"Subscribing to topic {topic}") client.subscribe(topic) + # Publish worker data (WG pubkeys, ports, local addresses) + iface = wg_interface_name(domain) + if iface: + (port, public_key, link_address) = get_device_data(iface) + data = { + "ExternalAddress": own_external_name, + "Port": port, + "PublicKey": public_key, + "LinkAddress": link_address, + } + client.publish( + TOPIC_WORKER_WG_DATA.format(worker=_HOSTNAME, domain=domain), + json.dumps(data), + qos=1, + retain=True, + ) + else: + logger.error( + f"Could not get interface name for domain {domain}. 
Skipping worker data publication" + ) + + # Mark worker as online + client.publish(TOPIC_WORKER_STATUS.format(worker=_HOSTNAME), 1, qos=1, retain=True) + def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> None: + """Fallback handler for MQTT messages that do not match any other registered handler. + + Arguments: + client: the client instance for this callback. + userdata: the private user data. + message: The MQTT message. + """ + logger.info("Got unhandled message on %s from MQTT", message.topic) + return + + +def on_message_wireguard( + client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage +) -> None: """Processes messages from MQTT and forwards them to netlink. Arguments: @@ -83,16 +175,17 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> message: The MQTT message. """ # TODO(ruairi): Check bounds and raise exception here. - logger.debug("Got message %s from MTQQ", message) - domain_prefixes = load_config().get("domain_prefixes") + logger.debug("Got message on %s from MQTT", message.topic) + + domain_prefixes = get_config().domain_prefixes domain = None for domain_prefix in domain_prefixes: - domain = re.search(r"/.*" + domain_prefix + "(\w+)/", message.topic) + domain = re.search(r".*/" + domain_prefix + r"(\w+)/", message.topic) if domain: break if not domain: raise ValueError( - "Could not find a match for %s on %s", repr(domain_prefixes), message.topic + f"Could not find a match for {domain_prefixes} on {message.topic}" ) # this will not work, if we have non-unique prefix stripped domains domain = domain.group(1) @@ -101,3 +194,55 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) -> f"Received create message for key {str(message.payload.decode('utf-8'))} on domain {domain} adding to queue" ) q.put((domain, str(message.payload.decode("utf-8")))) + + +def publish_metrics_loop( + exit_event: threading.Event, client: mqtt.Client, domain: str +) -> None: + 
"""Continuously send metrics every METRICS_SEND_INTERVAL seconds for this gateway and the given domain.""" + logger.info("Scheduling metrics task for %s, ", domain) + + topic = TOPIC_CONNECTED_PEERS.format(domain=domain, worker=_HOSTNAME) + + while not exit_event.is_set(): + publish_metrics(client, topic, domain) + # This drifts slightly over time, doesn't matter for us + exit_event.wait(_METRICS_SEND_INTERVAL) + + # Set peers metric to -1 to mark worker as offline + # Use QoS 1 (at least once) to make sure the broker notices + client.publish(topic, -1, qos=1, retain=True) + + +def publish_metrics(client: mqtt.Client, topic: str, domain: str) -> None: + """Publish metrics for this gateway and the given domain. + + The metrics currently only consist of the number of connected peers. + """ + logger.debug(f"Publishing metrics for domain {domain}") + iface = wg_interface_name(domain) + if not iface: + logger.error( + f"Could not get interface name for domain {domain}. Skipping metrics publication" + ) + return + peer_count = get_connected_peers_count(iface) + + # Publish metrics, retain it at MQTT broker so restarted wgkex broker has metrics right away + client.publish(topic, peer_count, retain=True) + + +def wg_interface_name(domain: str) -> Optional[str]: + """Calculates the WireGuard interface name for a domain""" + domain_prefixes = get_config().domain_prefixes + cleaned_domain = None + for prefix in domain_prefixes: + try: + cleaned_domain = domain.split(prefix)[1] + except IndexError: + continue + break + if not cleaned_domain: + raise ValueError(f"Could not find a match for {domain_prefixes} on {domain}") + # this will not work, if we have non-unique prefix stripped domains + return f"wg-{cleaned_domain}" diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index 8e2fcbf..8bd6672 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -1,89 +1,130 @@ """Unit tests for mqtt.py""" +import socket +import threading import unittest 
+from time import sleep + import mock -import mqtt -import msg_queue +import paho.mqtt.client +from wgkex.common.mqtt import TOPIC_CONNECTED_PEERS +from wgkex.worker import mqtt + + +def _get_config_mock(domains=None, mqtt=None): + test_prefixes = ["_ffmuc_", "_TEST_PREFIX2_"] + config_mock = mock.MagicMock() + config_mock.domains = ( + domains if domains is not None else [f"{test_prefixes[0]}domain.one"] + ) + config_mock.domain_prefixes = test_prefixes + if mqtt: + config_mock.mqtt = mqtt + return config_mock -class MQTTTest(unittest.TestCase): - @mock.patch.object(mqtt, "load_config") - def test_fetch_from_config_success(self, config_mock): - """Ensure we can fetch a value from config.""" - config_mock.return_value = dict(key="value") - self.assertEqual("value", mqtt.fetch_from_config("key")) - - @mock.patch.object(mqtt, "load_config") - def test_fetch_from_config_fails_no_key(self, config_mock): - """Tests we fail with ValueError for missing key in config.""" - config_mock.return_value = dict(key="value") - with self.assertRaises(ValueError): - mqtt.fetch_from_config("does_not_exist") +class MQTTTest(unittest.TestCase): @mock.patch.object(mqtt.mqtt, "Client") @mock.patch.object(mqtt.socket, "gethostname") - @mock.patch.object(mqtt, "load_config") + @mock.patch.object(mqtt, "get_config") def test_connect_success(self, config_mock, hostname_mock, mqtt_mock): """Tests successful connection to MQTT server.""" hostname_mock.return_value = "hostname" - config_mock.return_value = dict(mqtt={"broker_url": "some_url"}) - mqtt.connect() + config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "some_url" + config_mqtt_mock.broker_port = 1833 + config_mqtt_mock.keepalive = False + config_mock.return_value = _get_config_mock(mqtt=config_mqtt_mock) + ee = threading.Event() + mqtt.connect(ee) + ee.set() mqtt_mock.assert_has_calls( - [mock.call().connect("some_url", port=None, keepalive=None)], + [mock.call().connect("some_url", port=1833, keepalive=False)], 
any_order=True, ) @mock.patch.object(mqtt.mqtt, "Client") - @mock.patch.object(mqtt, "load_config") + @mock.patch.object(mqtt, "get_config") def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock): """Tests failure for connect - ValueError.""" mqtt_mock.side_effect = ValueError("barf") - config_mock.return_value = dict(mqtt={"broker_url": "some_url"}) + config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "some_url" + config_mock.return_value = _get_config_mock(mqtt=config_mqtt_mock) with self.assertRaises(ValueError): - mqtt.connect() + mqtt.connect(threading.Event()) + @mock.patch.object(mqtt, "get_config") + @mock.patch.object(mqtt, "get_connected_peers_count") + def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): + config_mock.return_value = _get_config_mock() + conn_peers_mock.return_value = 20 + mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client) -""" @mock.patch.object(msg_queue, "link_handler") - @mock.patch.object(mqtt, "load_config") - def test_on_message_success(self, config_mock, link_mock): - config_mock.return_value = {"domain_prefix": "_ffmuc_"} - link_mock.return_value = dict(WireGuard="result") + ee = threading.Event() + thread = threading.Thread( + target=mqtt.publish_metrics_loop, + args=(ee, mqtt_client, "_ffmuc_domain.one"), + ) + thread.start() + + i = 0 + while i < 20 and not mqtt_client.publish.called: + i += 1 + sleep(0.1) + + conn_peers_mock.assert_called_with("wg-domain.one") + mqtt_client.publish.assert_called_with( + TOPIC_CONNECTED_PEERS.format( + domain="_ffmuc_domain.one", worker=socket.gethostname() + ), + 20, + retain=True, + ) + + ee.set() + + i = 0 + while i < 20 and thread.is_alive(): + i += 1 + sleep(0.1) + + self.assertFalse(thread.is_alive()) + + @mock.patch.object(mqtt, "get_config") + def test_on_message_wireguard_success(self, config_mock): + # Tests on_message for success. 
+ config_mock.return_value = _get_config_mock() mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") - mqtt_msg.topic = "/_ffmuc_domain1/" + mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway" mqtt_msg.payload = b"PUB_KEY" - mqtt.on_message(None, None, mqtt_msg) - link_mock.assert_has_calls( - [ - mock.call( - msg_queue.WireGuardClient( - public_key="PUB_KEY", domain="domain1", remove=False - ) - ) - ], - any_order=True, - ) + mqtt.on_message_wireguard(None, None, mqtt_msg) + self.assertTrue(mqtt.q.qsize() > 0) + item = mqtt.q.get_nowait() + self.assertEqual(item, ("domain1", "PUB_KEY")) + - @mock.patch.object(msg_queue, "link_handler") - @mock.patch.object(mqtt, "load_config") - def test_on_message_fails_no_domain(self, config_mock, link_mock): - config_mock.return_value = { - "domain_prefix": "ffmuc_", - "log_level": "DEBUG", - "domains": ["a", "b"], - "mqtt": { - "broker_port": 1883, - "broker_url": "mqtt://broker", - "keepalive": 5, - "password": "pass", - "tls": True, - "username": "user", - }, - } +""" @mock.patch.object(msg_queue, "link_handler") + @mock.patch.object(mqtt, "get_config") + def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock): + # Tests on_message for failure to parse domain. 
+ config_mqtt_mock = mock.MagicMock() + config_mqtt_mock.broker_url = "mqtt://broker" + config_mqtt_mock.broker_port = 1883 + config_mqtt_mock.keepalive = 5 + config_mqtt_mock.password = "pass" + config_mqtt_mock.tls = True + config_mqtt_mock.username = "user" + config_mock.return_value = _get_config_mock( + domains=["a", "b"], mqtt=config_mqtt_mock + ) link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") - mqtt_msg.topic = "bad_domain_match" + mqtt_msg.topic = "wireguard/bad_domain_match" with self.assertRaises(ValueError): - mqtt.on_message(None, None, mqtt_msg) - """ + mqtt.on_message_wireguard(None, None, mqtt_msg) +""" + if __name__ == "__main__": unittest.main() diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index d4f0656..366d430 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -1,12 +1,15 @@ """Functions related to netlink manipulation for Wireguard, IPRoute and FDB on Linux.""" +# See https://docs.pyroute2.org/iproute.html for a documentation of the layout of netlink responses import hashlib import re from dataclasses import dataclass from datetime import datetime from datetime import timedelta from textwrap import wrap -from typing import Dict, List +from typing import Any, Dict, List, Tuple + import pyroute2 +import pyroute2.netlink from wgkex.common.utils import mac2eui64 from wgkex.common import logger @@ -69,13 +72,12 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]: stale_clients = [ stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain) ] - logger.debug("Found stale clients: %s", stale_clients) - logger.info("Searching for stale WireGuard clients.") + logger.debug("Found %s stale clients: %s", len(stale_clients), stale_clients) stale_wireguard_clients = [ WireGuardClient(public_key=stale_client, domain=domain, remove=True) for stale_client in stale_clients ] - logger.debug("Found stable WireGuard clients: %s", 
stale_wireguard_clients) + logger.debug("Found stale WireGuard clients: %s", stale_wireguard_clients) logger.info("Processing clients.") link_handled = [ link_handler(stale_client) for stale_client in stale_wireguard_clients @@ -191,18 +193,82 @@ def find_stale_wireguard_clients(wg_interface: str) -> List: "Starting search for stale wireguard peers for interface %s.", wg_interface ) with pyroute2.WireGuard() as wg: - all_clients = [] - peers_on_interface = wg.info(wg_interface) - logger.info("Got infos: %s.", peers_on_interface) - for peer in peers_on_interface: - clients = peer.get_attr("WGDEVICE_A_PEERS") - logger.info("Got clients: %s.", clients) - if clients: - all_clients.extend(clients) + all_peers = [] + msgs = wg.info(wg_interface) + logger.debug("Got infos for stale peers: %s.", msgs) + for msg in msgs: + peers = msg.get_attr("WGDEVICE_A_PEERS") + logger.debug("Got clients: %s.", peers) + if peers: + all_peers.extend(peers) ret = [ - client.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") - for client in all_clients - if client.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int()) - < three_hrs_in_secs + peer.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") + for peer in all_peers + if (hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")) is not None + and hshk_time.get("tv_sec", int()) < three_hrs_in_secs ] return ret + + +def get_connected_peers_count(wg_interface: str) -> int: + """Fetches and returns the number of connected peers, i.e. which had recent handshakes. + + Arguments: + wg_interface: The WireGuard interface to query. + + Returns: + # The number of peers which have recently seen a handshake. 
+ """ + three_mins_ago_in_secs = int((datetime.now() - timedelta(minutes=3)).timestamp()) + logger.info("Counting connected wireguard peers for interface %s.", wg_interface) + with pyroute2.WireGuard() as wg: + msgs = wg.info(wg_interface) + logger.debug("Got infos for connected peers: %s.", msgs) + count = 0 + for msg in msgs: + peers = msg.get_attr("WGDEVICE_A_PEERS") + logger.debug("Got clients: %s.", peers) + if peers: + for peer in peers: + if ( + peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get( + "tv_sec", int() + ) + > three_mins_ago_in_secs + ): + count += 1 + + return count + + +def get_device_data(wg_interface: str) -> Tuple[int, str, str]: + """Returns the listening port, public key and local IP address. + + Arguments: + wg_interface: The WireGuard interface to query. + + Returns: + The listening port, public key, and local IP address of the WireGuard interface. + """ + logger.info("Reading data from interface %s.", wg_interface) + with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb: + msgs = wg.info(wg_interface) + logger.debug("Got infos for interface data: %s.", msgs) + if len(msgs) > 1: + logger.warning( + "Got multiple messages from netlink, expected one. Using only first one."
+ ) + info: pyroute2.netlink.nla = msgs[0] + + port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT")) + public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii") + link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address") + + logger.debug( + "Interface data: port '%s', public key '%s', link address '%s'", + port, + public_key, + link_address, + ) + + return (port, public_key, link_address) diff --git a/wgkex/worker/netlink_test.py b/wgkex/worker/netlink_test.py index aeb4ff3..0874005 100644 --- a/wgkex/worker/netlink_test.py +++ b/wgkex/worker/netlink_test.py @@ -11,9 +11,12 @@ sys.modules["pyroute2"] = mock.MagicMock() sys.modules["pyroute2.WireGuard"] = mock.MagicMock() sys.modules["pyroute2.IPRoute"] = mock.MagicMock() +sys.modules["pyroute2.NDB"] = mock.MagicMock() +sys.modules["pyroute2.netlink"] = mock.MagicMock() from pyroute2 import WireGuard from pyroute2 import IPRoute -import netlink + +from wgkex.worker import netlink _WG_CLIENT_ADD = netlink.WireGuardClient( public_key="public_key", domain="add", remove=False @@ -23,15 +26,31 @@ ) -def _get_wg_mock(key_name, stale_time): - pm = mock.Mock() - pm.get_attr.side_effect = [{"tv_sec": stale_time}, key_name.encode()] +def _get_peer_mock(public_key, last_handshake_time): + def peer_get_attr(attr: str): + if attr == "WGPEER_A_LAST_HANDSHAKE_TIME": + return {"tv_sec": last_handshake_time} + if attr == "WGPEER_A_PUBLIC_KEY": + return public_key.encode() + peer_mock = mock.Mock() - peer_mock.get_attr.side_effect = [[pm]] + peer_mock.get_attr.side_effect = peer_get_attr + return peer_mock + + +def _get_wg_mock(public_key, last_handshake_time): + peer_mock = _get_peer_mock(public_key, last_handshake_time) + + def msg_get_attr(attr: str): + if attr == "WGDEVICE_A_PEERS": + return [peer_mock] + + msg_mock = mock.Mock() + msg_mock.get_attr.side_effect = msg_get_attr wg_instance = WireGuard() wg_info_mock = wg_instance.__enter__.return_value wg_info_mock.set.return_value = {"WireGuard": "set"} -
wg_info_mock.info.return_value = [peer_mock] + wg_info_mock.info.return_value = [msg_mock] return wg_info_mock @@ -185,6 +204,47 @@ def test_wg_flush_stale_peers_stale_success(self): "del", dst="fe80::281:16ff:fe49:395e/128", oif=mock.ANY ) + def test_get_connected_peers_count_success(self): + """Tests getting the correct number of connected peers for an interface.""" + peers = [] + for i in range(10): + peer_mock = _get_peer_mock( + "TEST_KEY", + int((datetime.now() - timedelta(minutes=i, seconds=5)).timestamp()), + ) + peers.append(peer_mock) + + def msg_get_attr(attr: str): + if attr == "WGDEVICE_A_PEERS": + return peers + + msg_mock = mock.Mock() + msg_mock.get_attr.side_effect = msg_get_attr + + wg_instance = WireGuard() + wg_info_mock = wg_instance.__enter__.return_value + wg_info_mock.info.return_value = [msg_mock] + + ret = netlink.get_connected_peers_count("wg-welt") + self.assertEqual(ret, 3) + + def test_get_device_data_success(self): + def msg_get_attr(attr: str): + if attr == "WGDEVICE_A_LISTEN_PORT": + return 51820 + if attr == "WGDEVICE_A_PUBLIC_KEY": + return "TEST_PUBLIC_KEY".encode("ascii") + + msg_mock = mock.Mock() + msg_mock.get_attr.side_effect = msg_get_attr + + wg_instance = WireGuard() + wg_info_mock = wg_instance.__enter__.return_value + wg_info_mock.info.return_value = [msg_mock] + + ret = netlink.get_device_data("wg-welt") + self.assertTupleEqual(ret, (51820, "TEST_PUBLIC_KEY", mock.ANY)) + if __name__ == "__main__": unittest.main()