From 2ea994a72c04740a287a1307799e8ed0a50aace0 Mon Sep 17 00:00:00 2001 From: guillaume Date: Thu, 2 Nov 2023 15:17:34 +0200 Subject: [PATCH] feat: Removes reliance on defer and handles certificate management in central method --- .../v2/tls_certificates.py | 49 +++++++++++++------ src/charm.py | 37 +++++++++++--- tests/unit/test_charm.py | 9 ++-- 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index 09a7443..469386a 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -298,7 +298,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import Relation, SecretNotFoundError +from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" @@ -308,7 +308,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 19 +LIBPATCH = 20 PYDEPS = ["cryptography", "jsonschema"] @@ -600,23 +600,26 @@ def restore(self, snapshot: dict): self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: """Loads relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data @@ -1257,12 +1260,24 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) self.remove_certificate(certificate=certificate["certificate"]) - def get_requirer_csrs_with_no_certs( + def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Filters the requirer's units csrs. + """Returns CSR's for which no certificate has been issued. - Keeps the ones for which no certificate was provided. + Example return: [ + { + "relation_id": 0, + "application_name": "tls-certificates-requirer", + "unit_name": "tls-certificates-requirer/0", + "unit_csrs": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "is_ca": false + } + ] + } + ] Args: relation_id (int): Relation id @@ -1279,6 +1294,7 @@ def get_requirer_csrs_with_no_certs( if not self.certificate_issued_for_csr( app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] csr=csr["certificate_signing_request"], # type: ignore[index] + relation_id=relation_id, ): csrs_without_certs.append(csr) if csrs_without_certs: @@ -1325,17 +1341,22 @@ def get_requirer_csrs( ) return unit_csr_mappings - def certificate_issued_for_csr(self, app_name: str, csr: str) -> bool: + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: """Checks whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. + relation_id (int): Relation ID Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates()[app_name] + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ + app_name + ] for issued_pair in issued_certificates_per_csr: if "csr" in issued_pair and issued_pair["csr"] == csr: return csr_matches_certificate(csr, issued_pair["certificate"]) diff --git a/src/charm.py b/src/charm.py index d7d5665..af29c4c 100755 --- a/src/charm.py +++ b/src/charm.py @@ -20,6 +20,7 @@ generate_certificate, generate_private_key, ) +from cryptography import x509 from ops.charm import ActionEvent, CharmBase, EventBase, RelationJoinedEvent from ops.main import main from ops.model import ActiveStatus, BlockedStatus, SecretNotFoundError @@ -31,6 +32,17 @@ SEND_CA_CERT_REL_NAME = "send-ca-cert" # Must match metadata +def certificate_has_common_name(certificate: bytes, common_name: str) -> bool: + """Returns whether the certificate has the given common name.""" + print(certificate) + loaded_certificate = x509.load_pem_x509_certificate(certificate) + certificate_common_name = loaded_certificate.subject.get_attributes_for_oid( + x509.oid.NameOID.COMMON_NAME + )[0].value + + return certificate_common_name == common_name + + class SelfSignedCertificatesCharm(CharmBase): """Main class to handle Juju events.""" @@ -38,12 +50,13 @@ def __init__(self, *args): """Observes config change and certificate request events.""" super().__init__(*args) self.tls_certificates = TLSCertificatesProvidesV2(self, "certificates") - self.framework.observe(self.on.config_changed, self._configure_ca) + self.framework.observe(self.on.update_status, self._configure) + self.framework.observe(self.on.config_changed, self._configure) + self.framework.observe(self.on.secret_expired, self._configure) self.framework.observe( self.tls_certificates.on.certificate_creation_request, self._on_certificate_creation_request, ) - self.framework.observe(self.on.secret_expired, self._configure_ca) self.framework.observe(self.on.get_ca_certificate_action, self._on_get_ca_certificate) self.framework.observe( self.on.get_issued_certificates_action, self._on_get_issued_certificates @@ -145,7 +158,7 @@ def _generate_root_certificate(self) -> None: ) logger.info("Root certificates generated and stored.") - def _configure_ca(self, event: EventBase) -> None: + def _configure(self, event: EventBase) -> None: """Validates configuration and generates root certificate. It will revoke the certificates signed by the previous root certificate. @@ -160,16 +173,26 @@ def _configure_ca(self, event: EventBase) -> None: f"The following configuration values are not valid: {invalid_configs}" ) return - self._generate_root_certificate() - self.tls_certificates.revoke_all_certificates() - logger.info("Revoked all previously issued certificates.") + if not self._root_certificate_is_stored or not self._root_certificate_matches_config(): + self._generate_root_certificate() + self.tls_certificates.revoke_all_certificates() + logger.info("Revoked all previously issued certificates.") self._send_ca_cert() self._process_outstanding_certificate_requests() self.unit.status = ActiveStatus() + def _root_certificate_matches_config(self) -> bool: + """Returns whether the stored root certificate matches with the config.""" + if not self._config_ca_common_name: + raise ValueError("CA common name should not be empty") + ca_certificate_secret = self.model.get_secret(label=CA_CERTIFICATES_SECRET_LABEL) + ca_certificate_secret_content = ca_certificate_secret.get_content() + ca = ca_certificate_secret_content["ca-certificate"].encode() + return certificate_has_common_name(certificate=ca, common_name=self._config_ca_common_name) + def _process_outstanding_certificate_requests(self) -> None: """Process outstanding certificate requests.""" - for relation in self.tls_certificates.get_requirer_csrs_with_no_certs(): + for relation in self.tls_certificates.get_outstanding_certificate_requests(): for request in relation["unit_csrs"]: self._generate_self_signed_certificate( csr=request["certificate_signing_request"], diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index b952966..b4fcbc4 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -108,7 +108,7 @@ def test_given_valid_config_when_config_changed_then_status_is_active( self.assertEqual(self.harness.model.unit.status, ActiveStatus()) @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV2.set_relation_certificate") - @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV2.get_requirer_csrs_with_no_certs") + @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV2.get_outstanding_certificate_requests") @patch("charm.generate_private_key") @patch("charm.generate_password") @patch("charm.generate_ca") @@ -119,7 +119,7 @@ def test_given_outstanding_certificate_requests_when_config_changed_then_request patch_generate_ca, patch_generate_password, patch_generate_private_key, - patch_get_requirer_csrs_with_no_certs, + patch_get_outstanding_certificate_requests, patch_set_relation_certificate, ): validity = 100 @@ -133,7 +133,7 @@ def test_given_outstanding_certificate_requests_when_config_changed_then_request patch_generate_ca.return_value = ca.encode() patch_generate_password.return_value = private_key_password patch_generate_private_key.return_value = private_key.encode() - patch_get_requirer_csrs_with_no_certs.return_value = [ + patch_get_outstanding_certificate_requests.return_value = [ { "relation_id": relation_id, "unit_csrs": [ @@ -270,6 +270,7 @@ def test_given_root_certificates_when_certificate_request_then_certificates_are_ certificate_signing_request=certificate_signing_request, ) + @patch("charm.certificate_has_common_name") @patch("charm.generate_private_key") @patch("charm.generate_password") @patch("charm.generate_ca") @@ -278,7 +279,9 @@ def test_given_initial_config_when_config_changed_then_stored_ca_common_name_use patch_generate_ca, patch_generate_password, patch_generate_private_key, + patch_certificate_has_common_name, ): + patch_certificate_has_common_name.return_value = False initial_common_name = "common-name-initial.com" new_common_name = "common-name-new.com" ca_certificate_1_string = "whatever CA certificate 1"