Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert server certificate to chain if its not already there #109

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions lib/charms/observability_libs/v1/cert_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a"
LIBAPI = 1
LIBPATCH = 12
LIBPATCH = 13

VAULT_SECRET_LABEL = "cert-handler-private-vault"

Expand Down Expand Up @@ -584,9 +584,19 @@ def server_cert(self) -> Optional[str]:

@property
def chain(self) -> Optional[str]:
"""Return the ca chain bundled as a single PEM string."""
"""Return the entire chain bundled as a single PEM string. This includes, if available, the certificate, intermediate CAs, and the root CA.

If the server certificate is not set in the chain by the provider, we'll add it
to the top of the chain so that it could be used by a server.
"""
cert = self.get_cert()
return cert.chain_as_pem() if cert else None
if not cert:
return None
chain = cert.chain_as_pem()
if cert.certificate not in chain:
# add server cert to chain
chain = cert.certificate + "\n\n" + chain
return chain

def _on_certificate_expiring(
self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent]
Expand Down
96 changes: 96 additions & 0 deletions tests/scenario/test_cert_handler/test_cert_handler_v1.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime
import json
import socket
import sys
from contextlib import contextmanager
Expand All @@ -7,6 +9,8 @@
import pytest
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtensionOID
from ops import CharmBase
from scenario import Context, PeerRelation, Relation, State
Expand Down Expand Up @@ -43,6 +47,71 @@ def _mock_san(self):
return None


def generate_certificate_and_key():
"""Generate certificate and CA to use for tests."""
# Generate private key
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

# Generate CA certificate
ca_subject = issuer = x509.Name(
[
x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, "California"),
x509.NameAttribute(x509.NameOID.LOCALITY_NAME, "San Francisco"),
x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Example CA"),
x509.NameAttribute(x509.NameOID.COMMON_NAME, "example.com"),
]
)

ca_cert = (
x509.CertificateBuilder()
.subject_name(ca_subject)
.issuer_name(issuer)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.sign(private_key, hashes.SHA256())
)

# Generate server certificate
server_subject = x509.Name(
[
x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, "California"),
x509.NameAttribute(x509.NameOID.LOCALITY_NAME, "San Francisco"),
x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Example Server"),
x509.NameAttribute(x509.NameOID.COMMON_NAME, "server.example.com"),
]
)

server_cert = (
x509.CertificateBuilder()
.subject_name(server_subject)
.issuer_name(ca_subject)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=30))
.add_extension(
x509.SubjectAlternativeName([x509.DNSName("server.example.com")]), critical=False
)
.sign(private_key, hashes.SHA256())
)

# Convert to PEM format
ca_cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
server_cert_pem = server_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
private_key_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")

return ca_cert_pem, server_cert_pem, private_key_pem


def get_csr_obj(csr: str):
return x509.load_pem_x509_csr(csr.encode(), default_backend())

Expand Down Expand Up @@ -174,3 +243,30 @@ def test_csr_no_change(ctx, certificates):
csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(csr) == {socket.getfqdn()}
assert renew_patch.call_count == 0


def test_chain_contains_server_cert(ctx, certificates):
ca_cert_pem, server_cert_pem, _ = generate_certificate_and_key()

certificates = certificates.replace(
remote_app_data={
"certificates": json.dumps(
[
{
"certificate": server_cert_pem,
"ca": ca_cert_pem,
"chain": [ca_cert_pem],
"certificate_signing_request": "csr",
}
],
)
},
local_unit_data={
"certificate_signing_requests": json.dumps([{"certificate_signing_request": "csr"}])
},
)

with ctx.manager("update_status", State(leader=True, relations=[certificates])) as mgr:
mgr.run()
assert server_cert_pem in mgr.charm.ch.chain
assert x509.load_pem_x509_certificate(mgr.charm.ch.chain.encode(), default_backend())
Loading