Skip to content

Commit

Permalink
feat: Implements option to request for a CA certificate
Browse files Browse the repository at this point in the history
tests: Improves unit tests

Update lib/charms/tls_certificates_interface/v2/tls_certificates.py

Co-authored-by: Ghislain Bourgeois <ghislain.bourgeois@canonical.com>

Update lib/charms/tls_certificates_interface/v2/tls_certificates.py

Co-authored-by: Ghislain Bourgeois <ghislain.bourgeois@canonical.com>

Update lib/charms/tls_certificates_interface/v2/tls_certificates.py

Co-authored-by: Ghislain Bourgeois <ghislain.bourgeois@canonical.com>

Update lib/charms/tls_certificates_interface/v2/tls_certificates.py

Co-authored-by: Ghislain Bourgeois <ghislain.bourgeois@canonical.com>

updates docstring
  • Loading branch information
gruyaume committed Oct 31, 2023
1 parent 522d963 commit a9dd1e6
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 21 deletions.
75 changes: 61 additions & 14 deletions lib/charms/tls_certificates_interface/v2/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 18
LIBPATCH = 19

PYDEPS = ["cryptography", "jsonschema"]

Expand All @@ -335,7 +335,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven
"type": "array",
"items": {
"type": "object",
"properties": {"certificate_signing_request": {"type": "string"}},
"properties": {
"certificate_signing_request": {"type": "string"},
"ca": {"type": "boolean"},
},
"required": ["certificate_signing_request"],
},
}
Expand Down Expand Up @@ -536,22 +539,31 @@ def restore(self, snapshot: dict):
class CertificateCreationRequestEvent(EventBase):
"""Charm Event triggered when a TLS certificate is required."""

def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int):
def __init__(
self,
handle: Handle,
certificate_signing_request: str,
relation_id: int,
is_ca: bool = False,
):
super().__init__(handle)
self.certificate_signing_request = certificate_signing_request
self.relation_id = relation_id
self.is_ca = is_ca

def snapshot(self) -> dict:
"""Returns snapshot."""
return {
"certificate_signing_request": self.certificate_signing_request,
"relation_id": self.relation_id,
"is_ca": self.is_ca,
}

def restore(self, snapshot: dict):
"""Restores snapshot."""
self.certificate_signing_request = snapshot["certificate_signing_request"]
self.relation_id = snapshot["relation_id"]
self.is_ca = snapshot["is_ca"]


class CertificateRevocationRequestEvent(EventBase):
Expand Down Expand Up @@ -685,6 +697,7 @@ def generate_certificate(
ca_key_password: Optional[bytes] = None,
validity: int = 365,
alt_names: Optional[List[str]] = None,
is_ca: bool = False,
) -> bytes:
"""Generates a TLS certificate based on a CSR.
Expand All @@ -695,6 +708,7 @@ def generate_certificate(
ca_key_password: CA private key password
validity (int): Certificate validity (in days)
alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR
is_ca (bool): Whether the certificate is a CA certificate
Returns:
bytes: Certificate
Expand Down Expand Up @@ -726,7 +740,6 @@ def generate_certificate(
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False
)
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False)
)

extensions_list = csr_object.extensions
Expand Down Expand Up @@ -758,6 +771,29 @@ def generate_certificate(
critical=extension.critical,
)

if is_ca:
certificate_builder = certificate_builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True
)
certificate_builder = certificate_builder.add_extension(
x509.KeyUsage(
digital_signature=False,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
else:
certificate_builder = certificate_builder.add_extension(
x509.BasicConstraints(ca=False, path_length=None), critical=False
)

certificate_builder._version = x509.Version.v3
cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type]
return cert.public_bytes(serialization.Encoding.PEM)
Expand Down Expand Up @@ -1171,15 +1207,19 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
certificate_creation_request["certificate_signing_request"]
for certificate_creation_request in provider_certificates
]
requirer_unit_csrs = [
certificate_creation_request["certificate_signing_request"]
requirer_unit_certificate_requests = [
{
"csr": certificate_creation_request["certificate_signing_request"],
"is_ca": certificate_creation_request.get("ca", False),
}
for certificate_creation_request in requirer_csrs
]
for certificate_signing_request in requirer_unit_csrs:
if certificate_signing_request not in provider_csrs:
for certificate_request in requirer_unit_certificate_requests:
if certificate_request["csr"] not in provider_csrs:
self.on.certificate_creation_request.emit(
certificate_signing_request=certificate_signing_request,
certificate_signing_request=certificate_request[0],
relation_id=event.relation.id,
is_ca=certificate_request["is_ca"],
)
self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id)

Expand Down Expand Up @@ -1337,7 +1377,7 @@ def __init__(
self.framework.observe(charm.on.update_status, self._on_update_status)

@property
def _requirer_csrs(self) -> List[Dict[str, str]]:
def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]:
"""Returns list of requirer's CSRs from relation data."""
relation = self.model.get_relation(self.relationship_name)
if not relation:
Expand All @@ -1361,11 +1401,12 @@ def _provider_certificates(self) -> List[Dict[str, str]]:
return []
return provider_relation_data.get("certificates", [])

def _add_requirer_csr(self, csr: str) -> None:
def _add_requirer_csr(self, csr: str, is_ca: bool) -> None:
"""Adds CSR to relation data.
Args:
csr (str): Certificate Signing Request
is_ca (bool): Whether the certificate is a CA certificate
Returns:
None
Expand All @@ -1376,7 +1417,10 @@ def _add_requirer_csr(self, csr: str) -> None:
f"Relation {self.relationship_name} does not exist - "
f"The certificate request can't be completed"
)
new_csr_dict = {"certificate_signing_request": csr}
new_csr_dict = {
"certificate_signing_request": csr,
"ca": is_ca,
}
if new_csr_dict in self._requirer_csrs:
logger.info("CSR already in relation data - Doing nothing")
return
Expand Down Expand Up @@ -1407,11 +1451,14 @@ def _remove_requirer_csr(self, csr: str) -> None:
requirer_csrs.remove(csr_dict)
relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs)

def request_certificate_creation(self, certificate_signing_request: bytes) -> None:
def request_certificate_creation(
self, certificate_signing_request: bytes, is_ca: bool = False
) -> None:
"""Request TLS certificate to provider charm.
Args:
certificate_signing_request (bytes): Certificate Signing Request
is_ca (bool): Whether the certificate is a CA certificate
Returns:
None
Expand All @@ -1422,7 +1469,7 @@ def request_certificate_creation(self, certificate_signing_request: bytes) -> No
f"Relation {self.relationship_name} does not exist - "
f"The certificate request can't be completed"
)
self._add_requirer_csr(certificate_signing_request.decode().strip())
self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca)
logger.info("Certificate request sent to provider")

def request_certificate_revocation(self, certificate_signing_request: bytes) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,40 @@ def test_given_ca_cert_with_subject_key_id_when_generate_certificate_then_certif
x509.SubjectKeyIdentifier
).value.key_identifier
)


def test_given_request_is_for_ca_certificate_when_generate_certificate_then_certificate_is_generated():
ca_private_key = generate_private_key()
ca = generate_ca(
private_key=ca_private_key,
subject="my.demo.ca",
)
server_private_key = generate_private_key()

server_csr = generate_csr(
private_key=server_private_key,
subject="10.10.10.10",
sans_dns=[],
sans_ip=["10.10.10.10"],
)

server_cert = generate_certificate(
csr=server_csr,
ca=ca,
ca_key=ca_private_key,
is_ca=True,
)

loaded_server_cert = x509.load_pem_x509_certificate(server_cert)

assert (
loaded_server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value.ca
is True
)
assert (
loaded_server_cert.extensions.get_extension_for_class(x509.KeyUsage).value.key_cert_sign
is True
)
assert (
loaded_server_cert.extensions.get_extension_for_class(x509.KeyUsage).value.crl_sign is True
)
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_given_csr_in_relation_data_when_relation_changed_then_certificate_creat
)

patch_certificate_creation_request.assert_has_calls(
[call().emit(certificate_signing_request=csr, relation_id=relation_id)]
[call().emit(certificate_signing_request=csr, relation_id=relation_id, is_ca=False)]
)

@patch(
Expand Down Expand Up @@ -776,6 +776,42 @@ def test_given_more_than_one_application_related_to_operator_when_csrs_are_added
self.assertEqual(call_args_list[1].args[0].certificate_signing_request, csr_2)
self.assertEqual(call_args_list[1].args[0].relation_id, relation_2_id)

@patch(
f"{LIB_DIR}.CertificatesProviderCharmEvents.certificate_creation_request",
new_callable=PropertyMock,
)
def test_given_requirer_unit_requests_ca_when_relation_changed_then_certificate_creation_request_is_emitted(
self, patch_certificate_creation_request
):
relation_id = self.create_certificates_relation_with_1_remote_unit()
self.harness.set_leader(is_leader=True)
csr = "whatever csr"
remote_unit_relation_data = {
"certificate_signing_requests": json.dumps(
[
{
"certificate_signing_request": csr,
"ca": True,
}
]
)
}
self.harness.update_relation_data(
relation_id=relation_id,
app_or_unit=self.remote_unit_name,
key_values=remote_unit_relation_data,
)

patch_certificate_creation_request.assert_has_calls(
[
call().emit(
certificate_signing_request=csr,
is_ca=True,
relation_id=relation_id,
)
]
)

def test_given_certificates_in_relation_data_when_revoke_all_certificates_then_no_certificates_are_present(
self,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_given_csr_when_request_certificate_creation_then_csr_is_sent_in_relatio
)

assert json.loads(unit_relation_data["certificate_signing_requests"]) == [
{"certificate_signing_request": csr.decode().strip()}
{"certificate_signing_request": csr.decode().strip(), "ca": False}
]

def test_given_relation_data_already_contains_csr_when_request_certificate_creation_then_csr_is_not_sent_again( # noqa: E501
Expand All @@ -92,7 +92,7 @@ def test_given_relation_data_already_contains_csr_when_request_certificate_creat
)
key_values = {
"certificate_signing_requests": json.dumps(
[{"certificate_signing_request": csr.decode().strip()}]
[{"certificate_signing_request": csr.decode().strip(), "ca": False}]
)
}
self.harness.update_relation_data(
Expand All @@ -110,7 +110,7 @@ def test_given_relation_data_already_contains_csr_when_request_certificate_creat
)

assert json.loads(unit_relation_data["certificate_signing_requests"]) == [
{"certificate_signing_request": csr.decode().strip()}
{"certificate_signing_request": csr.decode().strip(), "ca": False}
]

def test_given_different_csr_in_relation_data_when_request_certificate_creation_then_new_csr_is_added( # noqa: E501
Expand All @@ -133,7 +133,7 @@ def test_given_different_csr_in_relation_data_when_request_certificate_creation_
)
key_values = {
"certificate_signing_requests": json.dumps(
[{"certificate_signing_request": initial_csr.decode().strip()}]
[{"certificate_signing_request": initial_csr.decode().strip(), "ca": False}]
)
}
self.harness.update_relation_data(
Expand All @@ -151,14 +151,39 @@ def test_given_different_csr_in_relation_data_when_request_certificate_creation_
)

expected_client_cert_requests = [
{"certificate_signing_request": initial_csr.decode().strip()},
{"certificate_signing_request": new_csr.decode().strip()},
{"certificate_signing_request": initial_csr.decode().strip(), "ca": False},
{"certificate_signing_request": new_csr.decode().strip(), "ca": False},
]
self.assertEqual(
expected_client_cert_requests,
json.loads(unit_relation_data["certificate_signing_requests"]),
)

def test_given_wants_ca_when_request_certificate_creation_then_csr_and_ca_are_set_in_relation_data(
self,
):
relation_id = self.create_certificates_relation()
private_key_password = b"whatever"
private_key = generate_private_key_helper(password=private_key_password)
csr = generate_csr_helper(
private_key=private_key,
private_key_password=private_key_password,
common_name="whatever.com",
)

self.harness.charm.certificates.request_certificate_creation(
certificate_signing_request=csr,
is_ca=True,
)

unit_relation_data = self.harness.get_relation_data(
relation_id=relation_id, app_or_unit=self.harness.charm.unit
)

assert json.loads(unit_relation_data["certificate_signing_requests"]) == [
{"certificate_signing_request": csr.decode().strip(), "ca": True}
]

def test_given_no_relation_when_request_certificate_revocation_then_runtime_error_is_raised(
self,
):
Expand Down

0 comments on commit a9dd1e6

Please sign in to comment.