Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds supports to generate a CA certificate #40

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _on_certificates_relation_joined(self, event: RelationJoinedEvent):
certificate = "my certificate"
ca = "my CA certificate"
chain = ["certificate 1", "certificate 2"]
self.certificate_transfer.set_certificate(certificate=certificate, ca=ca, chain=chain)
self.certificate_transfer.set_certificate(certificate=certificate, ca=ca, chain=chain, relation_id=event.relation.id)


if __name__ == "__main__":
Expand All @@ -54,6 +54,7 @@ def _on_certificates_relation_joined(self, event: RelationJoinedEvent):

from lib.charms.certificate_transfer_interface.v0.certificate_transfer import (
CertificateAvailableEvent,
CertificateRemovedEvent,
CertificateTransferRequires,
)

Expand All @@ -65,11 +66,18 @@ def __init__(self, *args):
self.framework.observe(
self.certificate_transfer.on.certificate_available, self._on_certificate_available
)
self.framework.observe(
self.certificate_transfer.on.certificate_removed, self._on_certificate_removed
)

def _on_certificate_available(self, event: CertificateAvailableEvent):
print(event.certificate)
print(event.ca)
print(event.chain)
print(event.relation_id)

def _on_certificate_removed(self, event: CertificateRemovedEvent):
print(event.relation_id)


if __name__ == "__main__":
Expand All @@ -87,10 +95,10 @@ def _on_certificate_available(self, event: CertificateAvailableEvent):

import json
import logging
from typing import List, Optional
from typing import List

from jsonschema import exceptions, validate # type: ignore[import]
from ops.charm import CharmBase, CharmEvents, RelationChangedEvent
from jsonschema import exceptions, validate # type: ignore[import-untyped]
from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent
from ops.framework import EventBase, EventSource, Handle, Object

# The unique Charmhub library identifier, never change it
Expand All @@ -101,7 +109,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent):

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 1
LIBPATCH = 5

PYDEPS = ["jsonschema"]

Expand Down Expand Up @@ -161,25 +169,45 @@ def __init__(
certificate: str,
ca: str,
chain: List[str],
relation_id: int,
):
super().__init__(handle)
self.certificate = certificate
self.ca = ca
self.chain = chain
self.relation_id = relation_id

def snapshot(self) -> dict:
"""Return snapshot."""
return {
"certificate": self.certificate,
"ca": self.ca,
"chain": self.chain,
"relation_id": self.relation_id,
}

def restore(self, snapshot: dict):
"""Restores snapshot."""
self.certificate = snapshot["certificate"]
self.ca = snapshot["ca"]
self.chain = snapshot["chain"]
self.relation_id = snapshot["relation_id"]


class CertificateRemovedEvent(EventBase):
"""Charm Event triggered when a TLS certificate is removed."""

def __init__(self, handle: Handle, relation_id: int):
super().__init__(handle)
self.relation_id = relation_id

def snapshot(self) -> dict:
"""Return snapshot."""
return {"relation_id": self.relation_id}

def restore(self, snapshot: dict):
"""Restores snapshot."""
self.relation_id = snapshot["relation_id"]


def _load_relation_data(raw_relation_data: dict) -> dict:
Expand All @@ -204,6 +232,7 @@ class CertificateTransferRequirerCharmEvents(CharmEvents):
"""List of events that the Certificate Transfer requirer charm can leverage."""

certificate_available = EventSource(CertificateAvailableEvent)
certificate_removed = EventSource(CertificateRemovedEvent)


class CertificateTransferProvides(Object):
Expand All @@ -219,7 +248,7 @@ def set_certificate(
certificate: str,
ca: str,
chain: List[str],
relation_id: Optional[int] = None,
relation_id: int,
) -> None:
"""Add certificates to relation data.

Expand All @@ -245,7 +274,7 @@ def set_certificate(
relation.data[self.model.unit]["ca"] = ca
relation.data[self.model.unit]["chain"] = json.dumps(chain)

def remove_certificate(self, relation_id: Optional[int] = None) -> None:
def remove_certificate(self, relation_id: int) -> None:
"""Remove a given certificate from relation data.

Args:
Expand Down Expand Up @@ -303,6 +332,9 @@ def __init__(
self.framework.observe(
charm.on[relationship_name].relation_changed, self._on_relation_changed
)
self.framework.observe(
charm.on[relationship_name].relation_broken, self._on_relation_broken
)

@staticmethod
def _relation_data_is_valid(relation_data: dict) -> bool:
Expand Down Expand Up @@ -343,4 +375,16 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
certificate=remote_unit_relation_data.get("certificate"),
ca=remote_unit_relation_data.get("ca"),
chain=remote_unit_relation_data.get("chain"),
relation_id=event.relation.id,
)

def _on_relation_broken(self, event: RelationBrokenEvent) -> None:
"""Handler triggered on relation broken event.

Args:
event: Juju event

Returns:
None
"""
self.on.certificate_removed.emit(relation_id=event.relation.id)
102 changes: 81 additions & 21 deletions lib/charms/tls_certificates_interface/v2/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.serialization import pkcs12
from cryptography.x509.extensions import Extension, ExtensionNotFound
from jsonschema import exceptions, validate # type: ignore[import]
from jsonschema import exceptions, validate # type: ignore[import-untyped]
from ops.charm import (
CharmBase,
CharmEvents,
Expand All @@ -308,7 +308,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 16
LIBPATCH = 19

PYDEPS = ["cryptography", "jsonschema"]

Expand All @@ -335,7 +335,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven
"type": "array",
"items": {
"type": "object",
"properties": {"certificate_signing_request": {"type": "string"}},
"properties": {
"certificate_signing_request": {"type": "string"},
"ca": {"type": "boolean"},
},
"required": ["certificate_signing_request"],
},
}
Expand Down Expand Up @@ -536,22 +539,31 @@ def restore(self, snapshot: dict):
class CertificateCreationRequestEvent(EventBase):
"""Charm Event triggered when a TLS certificate is required."""

def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int):
def __init__(
self,
handle: Handle,
certificate_signing_request: str,
relation_id: int,
is_ca: bool = False,
):
super().__init__(handle)
self.certificate_signing_request = certificate_signing_request
self.relation_id = relation_id
self.is_ca = is_ca

def snapshot(self) -> dict:
"""Returns snapshot."""
return {
"certificate_signing_request": self.certificate_signing_request,
"relation_id": self.relation_id,
"is_ca": self.is_ca,
}

def restore(self, snapshot: dict):
"""Restores snapshot."""
self.certificate_signing_request = snapshot["certificate_signing_request"]
self.relation_id = snapshot["relation_id"]
self.is_ca = snapshot["is_ca"]


class CertificateRevocationRequestEvent(EventBase):
Expand Down Expand Up @@ -685,6 +697,7 @@ def generate_certificate(
ca_key_password: Optional[bytes] = None,
validity: int = 365,
alt_names: Optional[List[str]] = None,
is_ca: bool = False,
) -> bytes:
"""Generates a TLS certificate based on a CSR.

Expand All @@ -695,6 +708,7 @@ def generate_certificate(
ca_key_password: CA private key password
validity (int): Certificate validity (in days)
alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR
is_ca (bool): Whether the certificate is a CA certificate

Returns:
bytes: Certificate
Expand Down Expand Up @@ -726,7 +740,6 @@ def generate_certificate(
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False
)
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False)
)

extensions_list = csr_object.extensions
Expand Down Expand Up @@ -758,6 +771,29 @@ def generate_certificate(
critical=extension.critical,
)

if is_ca:
certificate_builder = certificate_builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True
)
certificate_builder = certificate_builder.add_extension(
x509.KeyUsage(
digital_signature=False,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
else:
certificate_builder = certificate_builder.add_extension(
x509.BasicConstraints(ca=False, path_length=None), critical=False
)

certificate_builder._version = x509.Version.v3
cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type]
return cert.public_bytes(serialization.Encoding.PEM)
Expand Down Expand Up @@ -1171,15 +1207,19 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
certificate_creation_request["certificate_signing_request"]
for certificate_creation_request in provider_certificates
]
requirer_unit_csrs = [
certificate_creation_request["certificate_signing_request"]
requirer_unit_certificate_requests = [
{
"csr": certificate_creation_request["certificate_signing_request"],
"is_ca": certificate_creation_request.get("ca", False),
}
for certificate_creation_request in requirer_csrs
]
for certificate_signing_request in requirer_unit_csrs:
if certificate_signing_request not in provider_csrs:
for certificate_request in requirer_unit_certificate_requests:
if certificate_request["csr"] not in provider_csrs:
self.on.certificate_creation_request.emit(
certificate_signing_request=certificate_signing_request,
certificate_signing_request=certificate_request["csr"],
relation_id=event.relation.id,
is_ca=certificate_request["is_ca"],
)
self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id)

Expand Down Expand Up @@ -1337,8 +1377,17 @@ def __init__(
self.framework.observe(charm.on.update_status, self._on_update_status)

@property
def _requirer_csrs(self) -> List[Dict[str, str]]:
"""Returns list of requirer's CSRs from relation data."""
def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]:
"""Returns list of requirer's CSRs from relation data.

Example:
[
{
"certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...",
"ca": false
}
]
"""
relation = self.model.get_relation(self.relationship_name)
if not relation:
raise RuntimeError(f"Relation {self.relationship_name} does not exist")
Expand All @@ -1361,11 +1410,12 @@ def _provider_certificates(self) -> List[Dict[str, str]]:
return []
return provider_relation_data.get("certificates", [])

def _add_requirer_csr(self, csr: str) -> None:
def _add_requirer_csr(self, csr: str, is_ca: bool) -> None:
"""Adds CSR to relation data.

Args:
csr (str): Certificate Signing Request
is_ca (bool): Whether the certificate is a CA certificate

Returns:
None
Expand All @@ -1376,7 +1426,10 @@ def _add_requirer_csr(self, csr: str) -> None:
f"Relation {self.relationship_name} does not exist - "
f"The certificate request can't be completed"
)
new_csr_dict = {"certificate_signing_request": csr}
new_csr_dict: Dict[str, Union[bool, str]] = {
"certificate_signing_request": csr,
"ca": is_ca,
}
if new_csr_dict in self._requirer_csrs:
logger.info("CSR already in relation data - Doing nothing")
return
Expand All @@ -1400,18 +1453,22 @@ def _remove_requirer_csr(self, csr: str) -> None:
f"The certificate request can't be completed"
)
requirer_csrs = copy.deepcopy(self._requirer_csrs)
csr_dict = {"certificate_signing_request": csr}
if csr_dict not in requirer_csrs:
logger.info("CSR not in relation data - Doing nothing")
if not requirer_csrs:
logger.info("No CSRs in relation data - Doing nothing")
return
requirer_csrs.remove(csr_dict)
for requirer_csr in requirer_csrs:
if requirer_csr["certificate_signing_request"] == csr:
requirer_csrs.remove(requirer_csr)
relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs)

def request_certificate_creation(self, certificate_signing_request: bytes) -> None:
def request_certificate_creation(
self, certificate_signing_request: bytes, is_ca: bool = False
) -> None:
"""Request TLS certificate to provider charm.

Args:
certificate_signing_request (bytes): Certificate Signing Request
is_ca (bool): Whether the certificate is a CA certificate

Returns:
None
Expand All @@ -1422,7 +1479,7 @@ def request_certificate_creation(self, certificate_signing_request: bytes) -> No
f"Relation {self.relationship_name} does not exist - "
f"The certificate request can't be completed"
)
self._add_requirer_csr(certificate_signing_request.decode().strip())
self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca)
logger.info("Certificate request sent to provider")

def request_certificate_revocation(self, certificate_signing_request: bytes) -> None:
Expand Down Expand Up @@ -1701,7 +1758,10 @@ def csr_matches_certificate(csr: str, cert: str) -> bool:
format=serialization.PublicFormat.SubjectPublicKeyInfo,
):
return False
if csr_object.subject != cert_object.subject:
if (
csr_object.public_key().public_numbers().n # type: ignore[union-attr]
!= cert_object.public_key().public_numbers().n # type: ignore[union-attr]
):
return False
except ValueError:
logger.warning("Could not load certificate or CSR.")
Expand Down
Loading
Loading