Some fixes for the loadbalancing changes #121

Merged · 1 commit · Jan 23, 2024
2 changes: 1 addition & 1 deletion wgkex.yaml.example
@@ -1,4 +1,4 @@
-# [broker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist
+# [broker, worker] The domains that should be accepted by clients and for which matching WireGuard interfaces exist
 domains:
   - ffmuc_muc_cty
   - ffmuc_muc_nord
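With `domains` now read by the worker as well as the broker, both roles can validate against the same list. A minimal sketch of loading that shared key, assuming only that wgkex.yaml is parsed with PyYAML; `load_domains` is an illustrative helper, not wgkex's actual config API:

    # Illustrative only: reads the shared `domains` key both roles now rely on.
    import yaml

    def load_domains(path: str = "wgkex.yaml") -> list:
        with open(path) as f:
            return yaml.safe_load(f).get("domains", [])

    # A broker or worker can then reject unknown domains up front:
    # if domain not in load_domains(): ...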
13 changes: 10 additions & 3 deletions wgkex/broker/app.py
@@ -131,8 +131,15 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]:
             }
         }, 400

+    # Update number of peers locally to interpolate data between MQTT updates from the worker
+    # TODO fix data race
+    current_peers_domain = (
+        worker_metrics.get(best_worker)
+        .get_domain_metrics(domain)
+        .get(CONNECTED_PEERS_METRIC, 0)
+    )
     worker_metrics.update(
-        best_worker, domain, CONNECTED_PEERS_METRIC, current_peers + 1
+        best_worker, domain, CONNECTED_PEERS_METRIC, current_peers_domain + 1
     )
     logger.debug(
         f"Chose worker {best_worker} with {current_peers} connected clients ({diff})"
@@ -200,10 +207,10 @@ def handle_mqtt_message_status(
     _, worker, _ = message.topic.split("/", 2)

     status = int(message.payload)
-    if status < 1:
+    if status < 1 and worker_metrics.get(worker).is_online():
         logger.warning(f"Marking worker as offline: {worker}")
         worker_metrics.set_offline(worker)
-    else:
+    elif status >= 1 and not worker_metrics.get(worker).is_online():
         logger.warning(f"Marking worker as online: {worker}")
         worker_metrics.set_online(worker)

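The interpolation added above is a read-modify-write on shared metrics, and the in-diff TODO flags it as a data race under concurrent key exchanges. A minimal sketch of one way to close it, assuming a module-level `threading.Lock` that is not part of this PR:

    # Sketch only: serialize the racy read-modify-write the TODO points at.
    import threading

    metrics_lock = threading.Lock()  # hypothetical module-level lock

    def increment_peer_count(worker_metrics, worker, domain, metric):
        with metrics_lock:
            current = (
                worker_metrics.get(worker).get_domain_metrics(domain).get(metric, 0)
            )
            worker_metrics.update(worker, domain, metric, current + 1)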
25 changes: 20 additions & 5 deletions wgkex/broker/metrics.py
@@ -34,10 +34,23 @@ def set_metric(self, domain: str, metric: str, value: Any) -> None:
         else:
             self.domain_data[domain] = {metric: value}

+    def get_peer_count(self) -> int:
+        """Returns the sum of connected peers on this worker over all domains"""
+        total = 0
+        for data in self.domain_data.values():
+            total += max(
+                data.get(CONNECTED_PEERS_METRIC, 0),
+                0,
+            )
+
+        return total
+
+
 @dataclasses.dataclass
 class WorkerMetricsCollection:
-    """A container for all worker metrics"""
+    """A container for all worker metrics
+    # TODO make threadsafe / fix data races
+    """

     # worker -> WorkerMetrics
     data: Dict[str, WorkerMetrics] = dataclasses.field(default_factory=dict)
@@ -68,7 +81,8 @@ def set_offline(self, worker: str) -> None:
         if worker in self.data:
             self.data[worker].online = False

-    def get_total_peers(self) -> int:
+    def get_total_peer_count(self) -> int:
+        """Returns the sum of connected peers over all workers and domains"""
         total = 0
         for worker in self.data:
             worker_data = self.data.get(worker)
@@ -96,22 +110,23 @@ def get_best_worker(self, domain: str) -> Tuple[Optional[str], int, int]:
         # Map metrics to a list of (target diff, peer count, worker) tuples for online workers

         peers_worker_tuples = []
-        total_peers = self.get_total_peers()
+        total_peers = self.get_total_peer_count()
         worker_cfg = config.get_config().workers

         for wm in self.data.values():
             if not wm.is_online(domain):
                 continue

-            peers = wm.get_domain_metrics(domain).get(CONNECTED_PEERS_METRIC)
+            peers = wm.get_peer_count()
             rel_weight = worker_cfg.relative_worker_weight(wm.worker)
             target = rel_weight * total_peers
             diff = peers - target
             logger.debug(
-                f"Worker {wm.worker}: rel weight {rel_weight}, target {target} (total {total_peers}), diff {diff}"
+                f"Worker candidate {wm.worker}: current {peers}, target {target} (total {total_peers}, rel weight {rel_weight}), diff {diff}"
             )
             peers_worker_tuples.append((diff, peers, wm.worker))

+        # Sort by diff (ascending), workers with most peers missing to target are sorted first
         peers_worker_tuples = sorted(peers_worker_tuples, key=itemgetter(0))

         if len(peers_worker_tuples) > 0:
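The behavioral change in `get_best_worker`: candidates are now ranked by `get_peer_count()`, the worker's peer total across all domains, instead of the per-domain count, so a worker that is busy on one domain but idle overall wins. A standalone sketch of that ranking using the numbers from the new test below, with every relative weight assumed to be 1:

    # Standalone re-creation of the ranking; weights assumed to default to 1.
    from operator import itemgetter

    workers = {
        "1": {"domain1": 25, "domain2": 5},   # 30 peers overall
        "2": {"domain1": 20, "domain2": 20},  # 40 peers overall
    }
    total_peers = sum(sum(d.values()) for d in workers.values())  # 70

    candidates = []
    for name, domains in workers.items():
        peers = sum(domains.values())   # overall count, not per-domain
        diff = peers - 1 * total_peers  # rel_weight assumed 1
        candidates.append((diff, peers, name))

    # Worker "1" wins despite having more domain1 peers (25 vs 20),
    # because it carries fewer peers overall (30 vs 40).
    print(sorted(candidates, key=itemgetter(0))[0])  # (-40, 30, '1')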
20 changes: 20 additions & 0 deletions wgkex/broker/metrics_test.py
@@ -83,6 +83,26 @@ def test_get_best_worker_returns_best(self, config_mock):
         self.assertEqual(diff, -20) # 19-(1*(20+19))
         self.assertEqual(connected, 19)

+    @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True)
+    def test_get_best_worker_returns_best_imbalanced_domains(self, config_mock):
+        """Verify get_best_worker returns the worker with overall least connected clients even if it has more clients on this domain."""
+        test_config = mock.MagicMock(spec=config.Config)
+        test_config.workers = config.Workers.from_dict({})
+        config_mock.return_value = test_config
+
+        worker_metrics = WorkerMetricsCollection()
+        worker_metrics.update("1", "domain1", "connected_peers", 25)
+        worker_metrics.update("1", "domain2", "connected_peers", 5)
+        worker_metrics.update("2", "domain1", "connected_peers", 20)
+        worker_metrics.update("2", "domain2", "connected_peers", 20)
+        worker_metrics.set_online("1")
+        worker_metrics.set_online("2")
+
+        (worker, diff, connected) = worker_metrics.get_best_worker("domain1")
+        self.assertEqual(worker, "1")
+        self.assertEqual(diff, -40) # 30-(1*(25+5+20+20))
+        self.assertEqual(connected, 30)
+
     @mock.patch("wgkex.broker.metrics.config.get_config", autospec=True)
     def test_get_best_worker_weighted_returns_best(self, config_mock):
         """Verify get_best_worker returns the worker with least client differential for weighted workers."""
2 changes: 1 addition & 1 deletion wgkex/worker/app_test.py
@@ -90,7 +90,7 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock):
             app.main()
         connect_mock.assert_not_called()

-    @mock.patch.object(app, "_CLEANUP_TIME", 0)
+    @mock.patch.object(app, "_CLEANUP_TIME", 1)
     @mock.patch.object(app, "wg_flush_stale_peers")
     def test_flush_workers_doesnt_throw(self, wg_flush_mock):
         """Ensure the flush_workers thread doesn't throw and exit if it encounters an exception."""
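Patching `_CLEANUP_TIME` to 1 rather than 0 presumably stops the flush thread from spinning in a zero-delay loop while the test checks that it survives an exception. A sketch of the loop shape this implies; only the names `flush_workers`, `wg_flush_stale_peers`, and `_CLEANUP_TIME` come from the diff, the body is assumed:

    # Assumed shape of the loop under test; the body is an illustration.
    import logging
    import time

    _CLEANUP_TIME = 1  # patched value; 0 would make this a busy loop

    def wg_flush_stale_peers(domain: str) -> None:  # stub for the real helper
        raise RuntimeError("boom")

    def flush_workers(domain: str) -> None:
        while True:
            try:
                wg_flush_stale_peers(domain)  # patched to raise in the test
            except Exception as e:
                logging.warning("Exception in flush_workers, continuing: %s", e)
            time.sleep(_CLEANUP_TIME)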
1 change: 1 addition & 0 deletions wgkex/worker/mqtt.py
@@ -127,6 +127,7 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None:
         logger.info(f"Subscribing to topic {topic}")
         client.subscribe(topic)

+    for domain in domains:
         # Publish worker data (WG pubkeys, ports, local addresses)
         iface = wg_interface_name(domain)
         if iface:
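The added loop fixes per-domain publishing: without it, the publish block ran only once, using whatever value `domain` last held, so worker data went out for at most one domain. A sketch of the per-domain publish step, with the topic and payload shapes taken from the expectations in mqtt_test.py below; the helper itself is illustrative:

    # Illustrative helper matching the topic/payload pinned by the tests below.
    import json
    import socket

    def publish_worker_data(client, domain, port, public_key, link_address):
        hostname = socket.gethostname()
        payload = json.dumps({
            "ExternalAddress": hostname,
            "Port": port,
            "PublicKey": public_key,
            "LinkAddress": link_address,
        })
        client.publish(
            f"wireguard-worker/{hostname}/{domain}/data",
            payload, qos=1, retain=True,
        )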
36 changes: 36 additions & 0 deletions wgkex/worker/mqtt_test.py
@@ -54,6 +54,42 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock):
         with self.assertRaises(ValueError):
             mqtt.connect(threading.Event())

+    @mock.patch.object(mqtt.mqtt, "Client")
+    @mock.patch.object(mqtt, "get_config")
+    @mock.patch.object(mqtt, "get_device_data")
+    def test_on_connect_subscribes(
+        self, get_device_data_mock, config_mock, mqtt_client_mock
+    ):
+        """Test that the on_connect callback correctly subscribes to all domains and pushes device data"""
+        config_mqtt_mock = mock.MagicMock()
+        config_mqtt_mock.broker_url = "some_url"
+        config_mqtt_mock.broker_port = 1833
+        config_mqtt_mock.keepalive = False
+        config = _get_config_mock(mqtt=config_mqtt_mock)
+        config.external_name = None
+        config_mock.return_value = config
+        get_device_data_mock.return_value = (51820, "456asdf=", "fe80::1")
+
+        hostname = socket.gethostname()
+
+        mqtt.on_connect(mqtt.mqtt.Client(), None, None, 0)
+
+        mqtt_client_mock.assert_has_calls(
+            [
+                mock.call().subscribe("wireguard/_ffmuc_domain.one/+"),
+                mock.call().publish(
+                    f"wireguard-worker/{hostname}/_ffmuc_domain.one/data",
+                    '{"ExternalAddress": "%s", "Port": 51820, "PublicKey": "456asdf=", "LinkAddress": "fe80::1"}'
+                    % hostname,
+                    qos=1,
+                    retain=True,
+                ),
+                mock.call().publish(
+                    f"wireguard-worker/{hostname}/status", 1, qos=1, retain=True
+                ),
+            ]
+        )
+
     @mock.patch.object(mqtt, "get_config")
     @mock.patch.object(mqtt, "get_connected_peers_count")
     def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock):
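The new test also pins down the worker-to-broker MQTT contract: subscriptions on `wireguard/<domain>/+`, retained per-domain data on `wireguard-worker/<hostname>/<domain>/data`, and liveness on `wireguard-worker/<hostname>/status`. A minimal broker-side consumer of that contract; the handler is illustrative, not the broker's actual code:

    # Illustrative consumer for the topics the test pins down.
    import json

    def handle_worker_message(topic: str, payload: bytes) -> None:
        parts = topic.split("/")
        if parts[0] == "wireguard-worker" and parts[-1] == "data":
            _, worker, domain, _ = topic.split("/", 3)
            data = json.loads(payload)
            print(worker, domain, data["Port"], data["PublicKey"])
        elif parts[0] == "wireguard-worker" and parts[-1] == "status":
            _, worker, _ = topic.split("/", 2)
            print(worker, "online" if int(payload) >= 1 else "offline")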
17 changes: 10 additions & 7 deletions wgkex/worker/netlink.py
@@ -231,11 +231,10 @@ def get_connected_peers_count(wg_interface: str) -> int:
     if peers:
         for peer in peers:
             if (
-                peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get(
-                    "tv_sec", int()
-                )
-                > three_mins_ago_in_secs
-            ):
+                hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")
+            ) is not None and hshk_time.get(
+                "tv_sec", int()
+            ) > three_mins_ago_in_secs:
                 count += 1

     return count
@@ -251,7 +250,7 @@ def get_device_data(wg_interface: str) -> Tuple[int, str, str]:
     # The listening port, public key, and local IP address of the WireGuard interface.
     """
     logger.info("Reading data from interface %s.", wg_interface)
-    with pyroute2.WireGuard() as wg, pyroute2.NDB() as ndb:
+    with pyroute2.WireGuard() as wg, pyroute2.IPRoute() as ipr:
         msgs = wg.info(wg_interface)
         logger.debug("Got infos for interface data: %s.", msgs)
         if len(msgs) > 1:
@@ -262,7 +261,11 @@

         port = int(info.get_attr("WGDEVICE_A_LISTEN_PORT"))
         public_key = info.get_attr("WGDEVICE_A_PUBLIC_KEY").decode("ascii")
-        link_address = ndb.interfaces[wg_interface].ipaddr[0].get("address")
+
+        # Get link address using IPRoute
+        ipr_link = ipr.link_lookup(ifname=wg_interface)[0]
+        msgs = ipr.get_addr(index=ipr_link)
+        link_address = msgs[0].get_attr("IFA_ADDRESS")

         logger.debug(
             "Interface data: port '%s', public key '%s', link address '%s",
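Two separate fixes land in netlink.py: the walrus-operator guard tolerates peers whose `WGPEER_A_LAST_HANDSHAKE_TIME` attribute is missing (calling `.get()` on the old expression would raise if `get_attr` returned None), and `pyroute2.NDB` is swapped for `pyroute2.IPRoute` to read the link address. A standalone sketch of the guard pattern, with a stub peer standing in for pyroute2's message object:

    # Standalone re-creation of the None-guard; the stub mimics get_attr
    # returning None when a peer has never completed a handshake.
    class StubPeer:
        def __init__(self, attrs):
            self._attrs = attrs

        def get_attr(self, name):
            return self._attrs.get(name)  # None when absent, like pyroute2

    def count_recent(peers, cutoff_secs):
        count = 0
        for peer in peers:
            if (
                hshk := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")
            ) is not None and hshk.get("tv_sec", 0) > cutoff_secs:
                count += 1
        return count

    peers = [StubPeer({}), StubPeer({"WGPEER_A_LAST_HANDSHAKE_TIME": {"tv_sec": 100}})]
    print(count_recent(peers, 50))  # 1; the handshake-less peer no longer raises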