Merge pull request #121 from DasSkelett/loadbalancing-fixes
Some fixes for the loadbalancing changes
DasSkelett authored Jan 23, 2024
2 parents 11213a5 + 908efa4 commit b7d6e16
Showing 8 changed files with 99 additions and 17 deletions.
2 changes: 1 addition & 1 deletion wgkex.yaml.example
@@ -1,4 +1,4 @@
-# [broker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist
+# [broker, worker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist
 domains:
 - ffmuc_muc_cty
 - ffmuc_muc_nord

13 changes: 10 additions & 3 deletions wgkex/broker/app.py
@@ -131,8 +131,15 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]:
             }
         }, 400
 
+    # Update number of peers locally to interpolate data between MQTT updates from the worker
+    # TODO fix data race
+    current_peers_domain = (
+        worker_metrics.get(best_worker)
+        .get_domain_metrics(domain)
+        .get(CONNECTED_PEERS_METRIC, 0)
+    )
     worker_metrics.update(
-        best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1
+        best_worker, domain, CONNECTED_PEERS_METRIC, current_peers_domain + 1
     )
     logger.debug(
         f"Chose worker {best_worker} with {current_peers} connected clients ({diff})"
@@ -200,10 +207,10 @@ def handle_mqtt_message_status(
     _, worker, _ = message.topic.split("/", 2)
 
     status = int(message.payload)
-    if status < 1:
+    if status < 1 and worker_metrics.get(worker).is_online():
         logger.warning(f"Marking worker as offline: {worker}")
         worker_metrics.set_offline(worker)
-    else:
+    elif status >= 1 and not worker_metrics.get(worker).is_online():
         logger.warning(f"Marking worker as online: {worker}")
         worker_metrics.set_online(worker)
 
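Note on the status handler change above: a worker's state now only flips on an actual transition, so a stream of identical status messages no longer re-logs and re-writes the same state. A minimal sketch of the pattern outside the MQTT context (hypothetical helper, not part of the codebase):

    def apply_status(status: int, currently_online: bool) -> tuple[bool, bool]:
        """Returns (new_online, changed); callers log and update only when changed."""
        if status < 1 and currently_online:
            return False, True   # worker went offline
        if status >= 1 and not currently_online:
            return True, True    # worker came (back) online
        return currently_online, False  # no transition, nothing to do
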
25 changes: 20 additions & 5 deletions wgkex/broker/metrics.py
@@ -34,10 +34,23 @@ def set_metric(self, domain: str, metric: str, value: Any) -> None:
         else:
             self.domain_data[domain] = {metric: value}
 
+    def get_peer_count(self) -> int:
+        """Returns the sum of connected peers on this worker over all domains"""
+        total = 0
+        for data in self.domain_data.values():
+            total += max(
+                data.get(CONNECTED_PEERS_METRIC, 0),
+                0,
+            )
+
+        return total
+
 
 @dataclasses.dataclass
 class WorkerMetricsCollection:
-    """A container for all worker metrics"""
+    """A container for all worker metrics
+    # TODO make threadsafe / fix data races
+    """
 
     # worker -> WorkerMetrics
     data: Dict[str, WorkerMetrics] = dataclasses.field(default_factory=dict)
@@ -68,7 +81,8 @@ def set_offline(self, worker: str) -> None:
         if worker in self.data:
             self.data[worker].online = False
 
-    def get_total_peers(self) -> int:
+    def get_total_peer_count(self) -> int:
+        """Returns the sum of connected peers over all workers and domains"""
         total = 0
         for worker in self.data:
             worker_data = self.data.get(worker)
@@ -96,22 +110,23 @@ def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]:
         # Map metrics to a list of (target diff, peer count, worker) tuples for online workers
 
         peers_worker_tuples = []
-        total_peers = self.get_total_peers()
+        total_peers = self.get_total_peer_count()
         worker_cfg = config.get_config().workers
 
         for wm in self.data.values():
             if not wm.is_online(domain):
                 continue
 
-            peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC)
+            peers = wm.get_peer_count()
             rel_weight = worker_cfg.relative_worker_weight(wm.worker)
             target = rel_weight * total_peers
             diff = peers - target
             logger.debug(
-                f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}"
+                f"Worker candidate {wm.worker}: current {peers}, target {target} (total {total_peers}, rel weight {rel_weight}), diff {diff}"
             )
             peers_worker_tuples.append((diff, peers, wm.worker))
 
+        # Sort by diff (ascending), workers with most peers missing to target are sorted first
         peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0))
 
         if len(peers_worker_tuples) > 0:

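The net effect of the metrics.py changes: candidate workers are now ranked by their total peer count across all domains (get_peer_count) rather than by the per-domain count, measured against a weight-scaled target. A self-contained sketch of that selection logic (simplified names are assumptions; in the real code the weights come from config.get_config().workers):

    from operator import itemgetter
    from typing import Dict, Optional, Tuple

    def pick_best_worker(
        workers: Dict[str, Dict[str, int]], weights: Dict[str, float]
    ) -> Optional[Tuple[str, float, int]]:
        """workers maps worker -> {domain: connected_peers}."""
        total = sum(max(n, 0) for d in workers.values() for n in d.values())
        candidates = []
        for name, domains in workers.items():
            peers = sum(max(n, 0) for n in domains.values())  # all domains, not just one
            diff = peers - weights.get(name, 1.0) * total     # distance to weighted target
            candidates.append((diff, peers, name))
        if not candidates:
            return None
        # Smallest diff first: the worker furthest below its target wins
        diff, peers, name = sorted(candidates, key=itemgetter(0))[0]
        return name, diff, peers
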
20 changes: 20 additions & 0 deletions wgkex/broker/metrics_test.py
@@ -83,6 +83,26 @@ def test_get_best_worker_returns_best(self, config_mock):
         self.assertEqual(diff, -20)  # 19-(1*(20+19))
         self.assertEqual(connected, 19)
 
+    @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True)
+    def test_get_best_worker_returns_best_imbalanced_domains(self, config_mock):
+        """Verify get_best_worker returns the worker with overall least connected clients even if it has more clients on this domain."""
+        test_config = mock.MagicMock(spec=config.Config)
+        test_config.workers = config.Workers.from_dict({})
+        config_mock.return_value = test_config
+
+        worker_metrics = WorkerMetricsCollection()
+        worker_metrics.update("1", "domain1", "connected_peers", 25)
+        worker_metrics.update("1", "domain2", "connected_peers", 5)
+        worker_metrics.update("2", "domain1", "connected_peers", 20)
+        worker_metrics.update("2", "domain2", "connected_peers", 20)
+        worker_metrics.set_online("1")
+        worker_metrics.set_online("2")
+
+        (worker, diff, connected) = worker_metrics.get_best_worker("domain1")
+        self.assertEqual(worker, "1")
+        self.assertEqual(diff, -40)  # 30-(1*(25+5+20+20))
+        self.assertEqual(connected, 30)
+
     @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True)
     def test_get_best_worker_weighted_returns_best(self, config_mock):
         """Verify get_best_worker returns the worker with least client differential for weighted workers."""

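As a sanity check on the expected values in the new test (assuming the default relative weight of 1): worker 1 holds 25 + 5 = 30 peers in total and worker 2 holds 20 + 20 = 40, so the grand total is 70. Worker 1's diff is 30 - 1 * 70 = -40 and worker 2's is 40 - 70 = -30, so worker 1 is chosen for domain1 even though it has more peers on that particular domain (25 vs 20).
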
2 changes: 1 addition & 1 deletion wgkex/worker/app_test.py
@@ -90,7 +90,7 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock):
             app.main()
         connect_mock.assert_not_called()
 
-    @mock.patch.object(app, "_CLEANUP_TIME", 0)
+    @mock.patch.object(app, "_CLEANUP_TIME", 1)
     @mock.patch.object(app, "wg_flush_stale_peers")
     def test_flush_workers_doesnt_throw(self, wg_flush_mock):
         """Ensure the flush_workers thread doesn't throw and exit if it encounters an exception."""

1 change: 1 addition & 0 deletions wgkex/worker/mqtt.py
@@ -127,6 +127,7 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None:
         logger.info(f"Subscribing to topic {topic}")
         client.subscribe(topic)
 
+    for domain in domains:
         # Publish worker data (WG pubkeys, ports, local addresses)
         iface = wg_interface_name(domain)
         if iface:

36 changes: 36 additions & 0 deletions wgkex/worker/mqtt_test.py
@@ -54,6 +54,42 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock):
         with self.assertRaises(ValueError):
             mqtt.connect(threading.Event())
 
+    @mock.patch.object(mqtt.mqtt, "Client")
+    @mock.patch.object(mqtt, "get_config")
+    @mock.patch.object(mqtt, "get_device_data")
+    def test_on_connect_subscribes(
+        self, get_device_data_mock, config_mock, mqtt_client_mock
+    ):
+        """Test that the on_connect callback correctly subscribes to all domains and pushes device data"""
+        config_mqtt_mock = mock.MagicMock()
+        config_mqtt_mock.broker_url = "some_url"
+        config_mqtt_mock.broker_port = 1833
+        config_mqtt_mock.keepalive = False
+        config = _get_config_mock(mqtt=config_mqtt_mock)
+        config.external_name = None
+        config_mock.return_value = config
+        get_device_data_mock.return_value = (51820, "456asdf=", "fe80::1")
+
+        hostname = socket.gethostname()
+
+        mqtt.on_connect(mqtt.mqtt.Client(), None, None, 0)
+
+        mqtt_client_mock.assert_has_calls(
+            [
+                mock.call().subscribe("wireguard/_ffmuc_domain.one/+"),
+                mock.call().publish(
+                    f"wireguard-worker/{hostname}/_ffmuc_domain.one/data",
+                    '{"ExternalAddress": "%s", "Port": 51820, "PublicKey": "456asdf=", "LinkAddress": "fe80::1"}'
+                    % hostname,
+                    qos=1,
+                    retain=True,
+                ),
+                mock.call().publish(
+                    f"wireguard-worker/{hostname}/status", 1, qos=1, retain=True
+                ),
+            ]
+        )
+
     @mock.patch.object(mqtt, "get_config")
     @mock.patch.object(mqtt, "get_connected_peers_count")
     def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock):

17 changes: 10 additions & 7 deletions wgkex/worker/netlink.py
@@ -231,11 +231,10 @@ def get_connected_peers_count(wg_interface: str) -> int:
     if peers:
         for peer in peers:
             if (
-                peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get(
-                    "tv_sec", int()
-                )
-                > three_mins_ago_in_secs
-            ):
+                hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")
+            ) is not None and hshk_time.get(
+                "tv_sec", int()
+            ) > three_mins_ago_in_secs:
                 count += 1
 
     return count
@@ -251,7 +250,7 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]:
         # The listening port, public key, and local IP address of the WireGuard interface.
     """
     logger.info("Reading data from interface %s.", wg_interface)
-    with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb:
+    with pyroute2.WireGuard() as wg, pyroute2.IPRoute() as ipr:
         msgs = wg.info(wg_interface)
         logger.debug("Got infos for interface data: %s.", msgs)
         if len(msgs) > 1:
@@ -262,7 +261,11 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]:
 
         port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT"))
         public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii")
-        link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address")
+
+        # Get link address using IPRoute
+        ipr_link = ipr.link_lookup(ifname=wg_interface)[0]
+        msgs = ipr.get_addr(index=ipr_link)
+        link_address = msgs[0].get_attr("IFA_ADDRESS")
 
         logger.debug(
             "Interface data: port '%s', public key '%s', link address '%s",

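The reworked condition in get_connected_peers_count also fixes a latent crash: pyroute2's get_attr returns None when an attribute is missing, so a peer without a WGPEER_A_LAST_HANDSHAKE_TIME attribute would make the old chained .get() call raise AttributeError. The assignment-expression guard generalizes to any optional attribute; a minimal sketch with a hypothetical peer object (not part of the codebase):

    # Hypothetical stand-in for a pyroute2 peer message; get_attr may return None.
    def is_recently_seen(peer, cutoff_secs: int) -> bool:
        # Count the peer only if the handshake attribute exists and is recent enough.
        return (
            hshk := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")
        ) is not None and hshk.get("tv_sec", 0) > cutoff_secs
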
