From 9ebba9738869d888a34092999208a289280bf147 Mon Sep 17 00:00:00 2001 From: Mia Altieri <32723809+MiaAltieri@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:11:52 +0200 Subject: [PATCH] [DPE-5236] update TLS lib for K8s mongos external clients (#478) ## Issue K8s mongos charm needs to be able to compare the current sans with the expected sans in order to update TLS certs when the k8s ip changes ## Solution Update lib to support the needed functionality --- lib/charms/mongodb/v1/mongodb_tls.py | 107 +++++++++++++++++----- lib/charms/mongodb/v1/shards_interface.py | 4 +- tests/unit/test_tls_lib.py | 43 ++++++++- 3 files changed, 129 insertions(+), 25 deletions(-) diff --git a/lib/charms/mongodb/v1/mongodb_tls.py b/lib/charms/mongodb/v1/mongodb_tls.py index d78df5d13..8f00bde7b 100644 --- a/lib/charms/mongodb/v1/mongodb_tls.py +++ b/lib/charms/mongodb/v1/mongodb_tls.py @@ -8,10 +8,11 @@ external relation. """ import base64 +import json import logging import re import socket -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, @@ -20,6 +21,8 @@ generate_csr, generate_private_key, ) +from cryptography import x509 +from cryptography.hazmat.backends import default_backend from ops.charm import ActionEvent, RelationBrokenEvent, RelationJoinedEvent from ops.framework import Object from ops.model import ActiveStatus, MaintenanceStatus, WaitingStatus @@ -28,7 +31,8 @@ UNIT_SCOPE = Config.Relations.UNIT_SCOPE Scopes = Config.Relations.Scopes - +SANS_DNS_KEY = "sans_dns" +SANS_IPS_KEY = "sans_ips" # The unique Charmhub library identifier, never change it LIBID = "e02a50f0795e4dd292f58e93b4f493dd" @@ -38,7 +42,9 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 3 +LIBPATCH = 4 + +WAIT_CERT_UPDATE = "wait-cert-updated" logger = logging.getLogger(__name__) @@ -99,17 +105,21 @@ def request_certificate( internal: bool, ): """Request TLS certificate.""" + if not self.charm.model.get_relation(Config.TLS.TLS_PEER_RELATION): + return + if param is None: key = generate_private_key() else: key = self._parse_tls_file(param) + sans = self.get_new_sans() csr = generate_csr( private_key=key, subject=self._get_subject_name(), organization=self._get_subject_name(), - sans=self._get_sans(), - sans_ip=[str(self.charm.model.get_binding(self.peer_relation).network.bind_address)], + sans=sans[SANS_DNS_KEY], + sans_ip=sans[SANS_IPS_KEY], ) self.set_tls_secret(internal, Config.TLS.SECRET_KEY_LABEL, key.decode("utf-8")) self.set_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL, csr.decode("utf-8")) @@ -118,9 +128,8 @@ def request_certificate( label = "int" if internal else "ext" self.charm.unit_peer_data[f"{label}_certs_subject"] = self._get_subject_name() self.charm.unit_peer_data[f"{label}_certs_subject"] = self._get_subject_name() - - if self.charm.model.get_relation(Config.TLS.TLS_PEER_RELATION): - self.certs.request_certificate_creation(certificate_signing_request=csr) + self.certs.request_certificate_creation(certificate_signing_request=csr) + self.set_waiting_for_cert_to_update(internal=internal, waiting=True) @staticmethod def _parse_tls_file(raw_content: str) -> bytes: @@ -158,16 +167,18 @@ def _on_tls_relation_joined(self, event: RelationJoinedEvent) -> None: def _on_tls_relation_broken(self, event: RelationBrokenEvent) -> None: """Disable TLS when TLS relation broken.""" - logger.debug("Disabling external and internal TLS for unit: %s", self.charm.unit.name) if not self.charm.db_initialised: logger.info("Deferring %s. db is not initialised.", str(type(event))) event.defer() return + if self.charm.upgrade_in_progress: logger.warning( "Disabling TLS is not supported during an upgrade. The charm may be in a broken, unrecoverable state." ) + logger.debug("Disabling external and internal TLS for unit: %s", self.charm.unit.name) + for internal in [True, False]: self.set_tls_secret(internal, Config.TLS.SECRET_CA_LABEL, None) self.set_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL, None) @@ -217,12 +228,13 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: ) self.set_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL, event.certificate) self.set_tls_secret(internal, Config.TLS.SECRET_CA_LABEL, event.ca) + self.set_waiting_for_cert_to_update(internal=internal, waiting=False) if self.charm.is_role(Config.Role.CONFIG_SERVER) and internal: self.charm.cluster.update_ca_secret(new_ca=event.ca) self.charm.config_server.update_ca_secret(new_ca=event.ca) - if self.waiting_for_certs(): + if self.waiting_for_both_certs(): logger.debug( "Defer till both internal and external TLS certificates available to avoid second restart." ) @@ -244,7 +256,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: # clear waiting status if db service is ready self.charm.status.set_and_share_status(ActiveStatus()) - def waiting_for_certs(self): + def waiting_for_both_certs(self): """Returns a boolean indicating whether additional certs are needed.""" if not self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CERT_LABEL): logger.debug("Waiting for internal certificate.") @@ -277,7 +289,6 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: == self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CERT_LABEL).rstrip() ): logger.debug("The internal TLS certificate expiring.") - internal = True else: logger.error("An unknown certificate expiring.") @@ -286,12 +297,13 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: logger.debug("Generating a new Certificate Signing Request.") key = self.get_tls_secret(internal, Config.TLS.SECRET_KEY_LABEL).encode("utf-8") old_csr = self.get_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL).encode("utf-8") + sans = self.get_new_sans() new_csr = generate_csr( private_key=key, subject=self._get_subject_name(), organization=self._get_subject_name(), - sans=self._get_sans(), - sans_ip=[str(self.charm.model.get_binding(self.peer_relation).network.bind_address)], + sans=sans[SANS_DNS_KEY], + sans_ip=sans[SANS_IPS_KEY], ) logger.debug("Requesting a certificate renewal.") @@ -302,20 +314,51 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: self.set_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL, new_csr.decode("utf-8")) - def _get_sans(self) -> List[str]: + def get_new_sans(self) -> Dict: """Create a list of DNS names for a MongoDB unit. Returns: A list representing the hostnames of the MongoDB unit. """ unit_id = self.charm.unit.name.split("/")[1] - return [ - f"{self.charm.app.name}-{unit_id}", - socket.getfqdn(), - f"{self.charm.app.name}-{unit_id}.{self.charm.app.name}-endpoints", - str(self.charm.model.get_binding(self.peer_relation).network.bind_address), - "localhost", - ] + + sans = { + SANS_DNS_KEY: [ + f"{self.charm.app.name}-{unit_id}", + socket.getfqdn(), + "localhost", + f"{self.charm.app.name}-{unit_id}.{self.charm.app.name}-endpoints", + ], + SANS_IPS_KEY: [ + str(self.charm.model.get_binding(self.peer_relation).network.bind_address) + ], + } + + if self.charm.is_role(Config.Role.MONGOS) and self.charm.is_external_client: + sans[SANS_IPS_KEY].append( + self.charm.get_ext_mongos_host(self.charm.unit, incl_port=False) + ) + + return sans + + def get_current_sans(self, internal: bool) -> List[str] | None: + """Gets the current SANs for the unit cert.""" + # if unit has no certificates do not proceed. + if not self.is_tls_enabled(internal=internal): + return + + pem_file = self.get_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL) + + try: + cert = x509.load_pem_x509_certificate(pem_file.encode(), default_backend()) + sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + sans_ip = [str(san) for san in sans.get_values_for_type(x509.IPAddress)] + sans_dns = [str(san) for san in sans.get_values_for_type(x509.DNSName)] + except x509.ExtensionNotFound: + sans_ip = [] + sans_dns = [] + + return {SANS_IPS_KEY: sorted(sans_ip), SANS_DNS_KEY: sorted(sans_dns)} def get_tls_files(self, internal: bool) -> Tuple[Optional[str], Optional[str]]: """Prepare TLS files in special MongoDB way. @@ -365,3 +408,23 @@ def _get_subject_name(self) -> str: return self.charm.get_config_server_name() or self.charm.app.name return self.charm.app.name + + def is_set_waiting_for_cert_to_update( + self, + internal=False, + ) -> bool: + """Returns True we are waiting for a cert to update.""" + scope = "int" if internal else "ext" + label_name = f"{scope}-{WAIT_CERT_UPDATE}" + + return json.loads(self.charm.unit_peer_data.get(label_name, "false")) + + def set_waiting_for_cert_to_update( + self, + waiting: bool, + internal: bool, + ) -> None: + """Sets a boolean indicator, for whether or not we are waiting for a cert to update.""" + scope = "int" if internal else "ext" + label_name = f"{scope}-{WAIT_CERT_UPDATE}" + self.charm.unit_peer_data[label_name] = json.dumps(waiting) diff --git a/lib/charms/mongodb/v1/shards_interface.py b/lib/charms/mongodb/v1/shards_interface.py index a0e60b55e..9be6c8a41 100644 --- a/lib/charms/mongodb/v1/shards_interface.py +++ b/lib/charms/mongodb/v1/shards_interface.py @@ -58,7 +58,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 9 +LIBPATCH = 10 KEYFILE_KEY = "key-file" HOSTS_KEY = "host" @@ -711,7 +711,7 @@ def _on_relation_changed(self, event): self.update_member_auth(event, (key_file_enabled, tls_enabled)) - if tls_enabled and self.charm.tls.waiting_for_certs(): + if tls_enabled and self.charm.tls.waiting_for_both_certs(): logger.info("Waiting for requested certs, before restarting and adding to cluster.") event.defer() return diff --git a/tests/unit/test_tls_lib.py b/tests/unit/test_tls_lib.py index 3a38fe1e6..bf908b749 100644 --- a/tests/unit/test_tls_lib.py +++ b/tests/unit/test_tls_lib.py @@ -4,6 +4,7 @@ from unittest import mock from unittest.mock import patch +from cryptography import x509 from ops.testing import Harness from parameterized import parameterized @@ -31,15 +32,18 @@ def setUp(self, *unused): @parameterized.expand([True, False]) @patch("charm.CrossAppVersionChecker.is_local_charm") + @patch("charm.MongoDBTLS.get_new_sans") @patch("charm.CrossAppVersionChecker.is_integrated_to_locally_built_charm") @patch("charms.mongodb.v0.set_status.get_charm_revision") @patch_network_get(private_address="1.1.1.1") - def test_set_tls_private_keys(self, leader, *unused): + def test_set_tls_private_keys(self, leader, get_new_sans, *unused): """Tests setting of TLS private key via the leader, ie both internal and external. Note: this implicitly tests: _request_certificate & _parse_tls_file """ + self.harness.add_relation("certificates", "certificates") # Tests for leader unit (ie internal certificates and external certificates) + get_new_sans.return_value = {"sans_dns": [""], "sans_ips": ["1.1.1.1"]} self.harness.set_leader(leader) action_event = mock.Mock() action_event.params = {} @@ -295,6 +299,42 @@ def test_external_certificate_broken_deferred(self, defer, *unused): defer.assert_called() + def test_get_new_sans_gives_node_port_for_mongos_k8s(self): + """Tests that get_new_sans only gets node port for external mongos K8s.""" + mock_get_ext_mongos_host = mock.Mock() + mock_get_ext_mongos_host.return_value = "node_port" + self.harness.charm.get_ext_mongos_host = mock_get_ext_mongos_host + for substrate in ["k8s", "vm"]: + for role in ["mongos", "config-server", "shard"]: + if role == "mongos" and substrate == "k8s": + continue + + assert "node-port" not in self.harness.charm.tls.get_new_sans()["sans_ips"] + + @patch("charm.MongoDBTLS.is_tls_enabled") + @patch("charms.mongodb.v1.mongodb_tls.x509.load_pem_x509_certificate") + def test_get_current_sans_returns_none(self, cert, is_tls_enabled): + """Tests the different scenarios that get_current_sans returns None. + + 1. get_current_sans returns None when TLS is not enabled. + 2. get_current_sans returns None if cert file is wrongly formatted. + """ + # case 1: get_current_sans returns None when TLS is not enabled. + is_tls_enabled.return_value = None + for internal in [True, False]: + self.assertEqual(self.harness.charm.tls.get_current_sans(internal), None) + + # case 2: error getting extension + is_tls_enabled.return_value = True + cert.side_effect = x509.ExtensionNotFound(msg="error-message", oid=1) + self.harness.charm.set_secret("unit", "ext-cert-secret", "unit-cert") + self.harness.charm.set_secret("unit", "int-cert-secret", "app-cert") + + for internal in [True, False]: + self.assertEqual( + self.harness.charm.tls.get_current_sans(internal), {"sans_ips": [], "sans_dns": []} + ) + # Helper functions def relate_to_tls_certificates_operator(self) -> int: """Relates the charm to the TLS certificates operator.""" @@ -331,6 +371,7 @@ def verify_internal_rsa_csr( """ int_rsa_key = self.harness.charm.get_secret("unit", "int-key-secret") int_csr = self.harness.charm.get_secret("unit", "int-csr-secret") + if specific_rsa: self.assertEqual(int_rsa_key, expected_rsa) else: