Skip to content

Commit

Permalink
[DPE-5236] update TLS lib for K8s mongos external clients (#478)
Browse files Browse the repository at this point in the history
## Issue
K8s mongos charm needs to be able to compare the current sans with the
expected sans in order to update TLS certs when the k8s ip changes

## Solution
Update lib to support the needed functionality
  • Loading branch information
MiaAltieri committed Sep 12, 2024
1 parent cf68980 commit 9ebba97
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 25 deletions.
107 changes: 85 additions & 22 deletions lib/charms/mongodb/v1/mongodb_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
external relation.
"""
import base64
import json
import logging
import re
import socket
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

from charms.tls_certificates_interface.v3.tls_certificates import (
CertificateAvailableEvent,
Expand All @@ -20,6 +21,8 @@
generate_csr,
generate_private_key,
)
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from ops.charm import ActionEvent, RelationBrokenEvent, RelationJoinedEvent
from ops.framework import Object
from ops.model import ActiveStatus, MaintenanceStatus, WaitingStatus
Expand All @@ -28,7 +31,8 @@

UNIT_SCOPE = Config.Relations.UNIT_SCOPE
Scopes = Config.Relations.Scopes

SANS_DNS_KEY = "sans_dns"
SANS_IPS_KEY = "sans_ips"

# The unique Charmhub library identifier, never change it
LIBID = "e02a50f0795e4dd292f58e93b4f493dd"
Expand All @@ -38,7 +42,9 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 3
LIBPATCH = 4

WAIT_CERT_UPDATE = "wait-cert-updated"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,17 +105,21 @@ def request_certificate(
internal: bool,
):
"""Request TLS certificate."""
if not self.charm.model.get_relation(Config.TLS.TLS_PEER_RELATION):
return

if param is None:
key = generate_private_key()
else:
key = self._parse_tls_file(param)

sans = self.get_new_sans()
csr = generate_csr(
private_key=key,
subject=self._get_subject_name(),
organization=self._get_subject_name(),
sans=self._get_sans(),
sans_ip=[str(self.charm.model.get_binding(self.peer_relation).network.bind_address)],
sans=sans[SANS_DNS_KEY],
sans_ip=sans[SANS_IPS_KEY],
)
self.set_tls_secret(internal, Config.TLS.SECRET_KEY_LABEL, key.decode("utf-8"))
self.set_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL, csr.decode("utf-8"))
Expand All @@ -118,9 +128,8 @@ def request_certificate(
label = "int" if internal else "ext"
self.charm.unit_peer_data[f"{label}_certs_subject"] = self._get_subject_name()
self.charm.unit_peer_data[f"{label}_certs_subject"] = self._get_subject_name()

if self.charm.model.get_relation(Config.TLS.TLS_PEER_RELATION):
self.certs.request_certificate_creation(certificate_signing_request=csr)
self.certs.request_certificate_creation(certificate_signing_request=csr)
self.set_waiting_for_cert_to_update(internal=internal, waiting=True)

@staticmethod
def _parse_tls_file(raw_content: str) -> bytes:
Expand Down Expand Up @@ -158,16 +167,18 @@ def _on_tls_relation_joined(self, event: RelationJoinedEvent) -> None:

def _on_tls_relation_broken(self, event: RelationBrokenEvent) -> None:
"""Disable TLS when TLS relation broken."""
logger.debug("Disabling external and internal TLS for unit: %s", self.charm.unit.name)
if not self.charm.db_initialised:
logger.info("Deferring %s. db is not initialised.", str(type(event)))
event.defer()
return

if self.charm.upgrade_in_progress:
logger.warning(
"Disabling TLS is not supported during an upgrade. The charm may be in a broken, unrecoverable state."
)

logger.debug("Disabling external and internal TLS for unit: %s", self.charm.unit.name)

for internal in [True, False]:
self.set_tls_secret(internal, Config.TLS.SECRET_CA_LABEL, None)
self.set_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL, None)
Expand Down Expand Up @@ -217,12 +228,13 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:
)
self.set_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL, event.certificate)
self.set_tls_secret(internal, Config.TLS.SECRET_CA_LABEL, event.ca)
self.set_waiting_for_cert_to_update(internal=internal, waiting=False)

if self.charm.is_role(Config.Role.CONFIG_SERVER) and internal:
self.charm.cluster.update_ca_secret(new_ca=event.ca)
self.charm.config_server.update_ca_secret(new_ca=event.ca)

if self.waiting_for_certs():
if self.waiting_for_both_certs():
logger.debug(
"Defer till both internal and external TLS certificates available to avoid second restart."
)
Expand All @@ -244,7 +256,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:
# clear waiting status if db service is ready
self.charm.status.set_and_share_status(ActiveStatus())

def waiting_for_certs(self):
def waiting_for_both_certs(self):
"""Returns a boolean indicating whether additional certs are needed."""
if not self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CERT_LABEL):
logger.debug("Waiting for internal certificate.")
Expand Down Expand Up @@ -277,7 +289,6 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
== self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CERT_LABEL).rstrip()
):
logger.debug("The internal TLS certificate expiring.")

internal = True
else:
logger.error("An unknown certificate expiring.")
Expand All @@ -286,12 +297,13 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
logger.debug("Generating a new Certificate Signing Request.")
key = self.get_tls_secret(internal, Config.TLS.SECRET_KEY_LABEL).encode("utf-8")
old_csr = self.get_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL).encode("utf-8")
sans = self.get_new_sans()
new_csr = generate_csr(
private_key=key,
subject=self._get_subject_name(),
organization=self._get_subject_name(),
sans=self._get_sans(),
sans_ip=[str(self.charm.model.get_binding(self.peer_relation).network.bind_address)],
sans=sans[SANS_DNS_KEY],
sans_ip=sans[SANS_IPS_KEY],
)
logger.debug("Requesting a certificate renewal.")

Expand All @@ -302,20 +314,51 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:

self.set_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL, new_csr.decode("utf-8"))

def _get_sans(self) -> List[str]:
def get_new_sans(self) -> Dict:
"""Create a list of DNS names for a MongoDB unit.
Returns:
A list representing the hostnames of the MongoDB unit.
"""
unit_id = self.charm.unit.name.split("/")[1]
return [
f"{self.charm.app.name}-{unit_id}",
socket.getfqdn(),
f"{self.charm.app.name}-{unit_id}.{self.charm.app.name}-endpoints",
str(self.charm.model.get_binding(self.peer_relation).network.bind_address),
"localhost",
]

sans = {
SANS_DNS_KEY: [
f"{self.charm.app.name}-{unit_id}",
socket.getfqdn(),
"localhost",
f"{self.charm.app.name}-{unit_id}.{self.charm.app.name}-endpoints",
],
SANS_IPS_KEY: [
str(self.charm.model.get_binding(self.peer_relation).network.bind_address)
],
}

if self.charm.is_role(Config.Role.MONGOS) and self.charm.is_external_client:
sans[SANS_IPS_KEY].append(
self.charm.get_ext_mongos_host(self.charm.unit, incl_port=False)
)

return sans

def get_current_sans(self, internal: bool) -> List[str] | None:
"""Gets the current SANs for the unit cert."""
# if unit has no certificates do not proceed.
if not self.is_tls_enabled(internal=internal):
return

pem_file = self.get_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL)

try:
cert = x509.load_pem_x509_certificate(pem_file.encode(), default_backend())
sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
sans_ip = [str(san) for san in sans.get_values_for_type(x509.IPAddress)]
sans_dns = [str(san) for san in sans.get_values_for_type(x509.DNSName)]
except x509.ExtensionNotFound:
sans_ip = []
sans_dns = []

return {SANS_IPS_KEY: sorted(sans_ip), SANS_DNS_KEY: sorted(sans_dns)}

def get_tls_files(self, internal: bool) -> Tuple[Optional[str], Optional[str]]:
"""Prepare TLS files in special MongoDB way.
Expand Down Expand Up @@ -365,3 +408,23 @@ def _get_subject_name(self) -> str:
return self.charm.get_config_server_name() or self.charm.app.name

return self.charm.app.name

def is_set_waiting_for_cert_to_update(
self,
internal=False,
) -> bool:
"""Returns True we are waiting for a cert to update."""
scope = "int" if internal else "ext"
label_name = f"{scope}-{WAIT_CERT_UPDATE}"

return json.loads(self.charm.unit_peer_data.get(label_name, "false"))

def set_waiting_for_cert_to_update(
self,
waiting: bool,
internal: bool,
) -> None:
"""Sets a boolean indicator, for whether or not we are waiting for a cert to update."""
scope = "int" if internal else "ext"
label_name = f"{scope}-{WAIT_CERT_UPDATE}"
self.charm.unit_peer_data[label_name] = json.dumps(waiting)
4 changes: 2 additions & 2 deletions lib/charms/mongodb/v1/shards_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 9
LIBPATCH = 10

KEYFILE_KEY = "key-file"
HOSTS_KEY = "host"
Expand Down Expand Up @@ -711,7 +711,7 @@ def _on_relation_changed(self, event):

self.update_member_auth(event, (key_file_enabled, tls_enabled))

if tls_enabled and self.charm.tls.waiting_for_certs():
if tls_enabled and self.charm.tls.waiting_for_both_certs():
logger.info("Waiting for requested certs, before restarting and adding to cluster.")
event.defer()
return
Expand Down
43 changes: 42 additions & 1 deletion tests/unit/test_tls_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import mock
from unittest.mock import patch

from cryptography import x509
from ops.testing import Harness
from parameterized import parameterized

Expand Down Expand Up @@ -31,15 +32,18 @@ def setUp(self, *unused):

@parameterized.expand([True, False])
@patch("charm.CrossAppVersionChecker.is_local_charm")
@patch("charm.MongoDBTLS.get_new_sans")
@patch("charm.CrossAppVersionChecker.is_integrated_to_locally_built_charm")
@patch("charms.mongodb.v0.set_status.get_charm_revision")
@patch_network_get(private_address="1.1.1.1")
def test_set_tls_private_keys(self, leader, *unused):
def test_set_tls_private_keys(self, leader, get_new_sans, *unused):
"""Tests setting of TLS private key via the leader, ie both internal and external.
Note: this implicitly tests: _request_certificate & _parse_tls_file
"""
self.harness.add_relation("certificates", "certificates")
# Tests for leader unit (ie internal certificates and external certificates)
get_new_sans.return_value = {"sans_dns": [""], "sans_ips": ["1.1.1.1"]}
self.harness.set_leader(leader)
action_event = mock.Mock()
action_event.params = {}
Expand Down Expand Up @@ -295,6 +299,42 @@ def test_external_certificate_broken_deferred(self, defer, *unused):

defer.assert_called()

def test_get_new_sans_gives_node_port_for_mongos_k8s(self):
"""Tests that get_new_sans only gets node port for external mongos K8s."""
mock_get_ext_mongos_host = mock.Mock()
mock_get_ext_mongos_host.return_value = "node_port"
self.harness.charm.get_ext_mongos_host = mock_get_ext_mongos_host
for substrate in ["k8s", "vm"]:
for role in ["mongos", "config-server", "shard"]:
if role == "mongos" and substrate == "k8s":
continue

assert "node-port" not in self.harness.charm.tls.get_new_sans()["sans_ips"]

@patch("charm.MongoDBTLS.is_tls_enabled")
@patch("charms.mongodb.v1.mongodb_tls.x509.load_pem_x509_certificate")
def test_get_current_sans_returns_none(self, cert, is_tls_enabled):
"""Tests the different scenarios that get_current_sans returns None.
1. get_current_sans returns None when TLS is not enabled.
2. get_current_sans returns None if cert file is wrongly formatted.
"""
# case 1: get_current_sans returns None when TLS is not enabled.
is_tls_enabled.return_value = None
for internal in [True, False]:
self.assertEqual(self.harness.charm.tls.get_current_sans(internal), None)

# case 2: error getting extension
is_tls_enabled.return_value = True
cert.side_effect = x509.ExtensionNotFound(msg="error-message", oid=1)
self.harness.charm.set_secret("unit", "ext-cert-secret", "unit-cert")
self.harness.charm.set_secret("unit", "int-cert-secret", "app-cert")

for internal in [True, False]:
self.assertEqual(
self.harness.charm.tls.get_current_sans(internal), {"sans_ips": [], "sans_dns": []}
)

# Helper functions
def relate_to_tls_certificates_operator(self) -> int:
"""Relates the charm to the TLS certificates operator."""
Expand Down Expand Up @@ -331,6 +371,7 @@ def verify_internal_rsa_csr(
"""
int_rsa_key = self.harness.charm.get_secret("unit", "int-key-secret")
int_csr = self.harness.charm.get_secret("unit", "int-csr-secret")

if specific_rsa:
self.assertEqual(int_rsa_key, expected_rsa)
else:
Expand Down

0 comments on commit 9ebba97

Please sign in to comment.