From b12d1432b4f53243f0b76d0fa83d7411107a5ff5 Mon Sep 17 00:00:00 2001 From: Michael Dmitry <33381599+michaeldmitry@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:17:12 +0300 Subject: [PATCH] append server cert to chain (#109) --- .../observability_libs/v1/cert_handler.py | 16 +++- .../test_cert_handler/test_cert_handler_v1.py | 96 +++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/lib/charms/observability_libs/v1/cert_handler.py b/lib/charms/observability_libs/v1/cert_handler.py index 6e693ff..4a1940b 100644 --- a/lib/charms/observability_libs/v1/cert_handler.py +++ b/lib/charms/observability_libs/v1/cert_handler.py @@ -67,7 +67,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 1 -LIBPATCH = 12 +LIBPATCH = 13 VAULT_SECRET_LABEL = "cert-handler-private-vault" @@ -584,9 +584,19 @@ def server_cert(self) -> Optional[str]: @property def chain(self) -> Optional[str]: - """Return the ca chain bundled as a single PEM string.""" + """Return the entire chain bundled as a single PEM string. This includes, if available, the certificate, intermediate CAs, and the root CA. + + If the server certificate is not set in the chain by the provider, we'll add it + to the top of the chain so that it could be used by a server. + """ cert = self.get_cert() - return cert.chain_as_pem() if cert else None + if not cert: + return None + chain = cert.chain_as_pem() + if cert.certificate not in chain: + # add server cert to chain + chain = cert.certificate + "\n\n" + chain + return chain def _on_certificate_expiring( self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] diff --git a/tests/scenario/test_cert_handler/test_cert_handler_v1.py b/tests/scenario/test_cert_handler/test_cert_handler_v1.py index 589995f..78defb8 100644 --- a/tests/scenario/test_cert_handler/test_cert_handler_v1.py +++ b/tests/scenario/test_cert_handler/test_cert_handler_v1.py @@ -1,3 +1,5 @@ +import datetime +import json import socket import sys from contextlib import contextmanager @@ -7,6 +9,8 @@ import pytest from cryptography import x509 from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import ExtensionOID from ops import CharmBase from scenario import Context, PeerRelation, Relation, State @@ -43,6 +47,71 @@ def _mock_san(self): return None +def generate_certificate_and_key(): + """Generate certificate and CA to use for tests.""" + # Generate private key + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + # Generate CA certificate + ca_subject = issuer = x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, "California"), + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, "San Francisco"), + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Example CA"), + x509.NameAttribute(x509.NameOID.COMMON_NAME, "example.com"), + ] + ) + + ca_cert = ( + x509.CertificateBuilder() + .subject_name(ca_subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + .sign(private_key, hashes.SHA256()) + ) + + # Generate server certificate + server_subject = x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, "California"), + x509.NameAttribute(x509.NameOID.LOCALITY_NAME, "San Francisco"), + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Example Server"), + x509.NameAttribute(x509.NameOID.COMMON_NAME, "server.example.com"), + ] + ) + + server_cert = ( + x509.CertificateBuilder() + .subject_name(server_subject) + .issuer_name(ca_subject) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=30)) + .add_extension( + x509.SubjectAlternativeName([x509.DNSName("server.example.com")]), critical=False + ) + .sign(private_key, hashes.SHA256()) + ) + + # Convert to PEM format + ca_cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + server_cert_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + return ca_cert_pem, server_cert_pem, private_key_pem + + def get_csr_obj(csr: str): return x509.load_pem_x509_csr(csr.encode(), default_backend()) @@ -174,3 +243,30 @@ def test_csr_no_change(ctx, certificates): csr = get_csr_obj(charm.ch._csr) assert get_sans_from_csr(csr) == {socket.getfqdn()} assert renew_patch.call_count == 0 + + +def test_chain_contains_server_cert(ctx, certificates): + ca_cert_pem, server_cert_pem, _ = generate_certificate_and_key() + + certificates = certificates.replace( + remote_app_data={ + "certificates": json.dumps( + [ + { + "certificate": server_cert_pem, + "ca": ca_cert_pem, + "chain": [ca_cert_pem], + "certificate_signing_request": "csr", + } + ], + ) + }, + local_unit_data={ + "certificate_signing_requests": json.dumps([{"certificate_signing_request": "csr"}]) + }, + ) + + with ctx.manager("update_status", State(leader=True, relations=[certificates])) as mgr: + mgr.run() + assert server_cert_pem in mgr.charm.ch.chain + assert x509.load_pem_x509_certificate(mgr.charm.ch.chain.encode(), default_backend())