Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate ChaCha20Poly1305 AEAD to Rust #9399

Merged
merged 2 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 29 additions & 240 deletions src/cryptography/hazmat/backends/openssl/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,15 @@
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

_AEADTypes = typing.Union[AESCCM, AESGCM, ChaCha20Poly1305]


def _is_evp_aead_supported_cipher(
backend: Backend, cipher: _AEADTypes
) -> bool:
"""
Checks whether the given cipher is supported through
EVP_AEAD rather than the normal OpenSSL EVP_CIPHER API.
"""
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305

return backend._lib.Cryptography_HAS_EVP_AEAD and isinstance(
cipher, ChaCha20Poly1305
)
_AEADTypes = typing.Union[AESCCM, AESGCM]


def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool:
if _is_evp_aead_supported_cipher(backend, cipher):
return True
else:
cipher_name = _evp_cipher_cipher_name(cipher)
if backend._fips_enabled and cipher_name not in backend._fips_aead:
return False
return (
backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL
)
cipher_name = _evp_cipher_cipher_name(cipher)


def _aead_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
):
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_create_ctx(backend, cipher, key)
else:
return _evp_cipher_create_ctx(backend, cipher, key)
return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL


def _encrypt(
Expand All @@ -63,153 +31,24 @@ def _encrypt(
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
else:
return _evp_cipher_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)


def _decrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
else:
return _evp_cipher_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)


def _evp_aead_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
tag_len: int | None = None,
):
aead_cipher = _evp_aead_get_cipher(backend, cipher)
assert aead_cipher is not None
key_ptr = backend._ffi.from_buffer(key)
tag_len = (
backend._lib.EVP_AEAD_DEFAULT_TAG_LENGTH
if tag_len is None
else tag_len
)
ctx = backend._lib.Cryptography_EVP_AEAD_CTX_new(
aead_cipher, key_ptr, len(key), tag_len
)
backend.openssl_assert(ctx != backend._ffi.NULL)
ctx = backend._ffi.gc(ctx, backend._lib.EVP_AEAD_CTX_free)
return ctx


def _evp_aead_get_cipher(backend: Backend, cipher: _AEADTypes):
from cryptography.hazmat.primitives.ciphers.aead import (
ChaCha20Poly1305,
)

# Currently only ChaCha20-Poly1305 is supported using this API
assert isinstance(cipher, ChaCha20Poly1305)
return backend._lib.EVP_aead_chacha20_poly1305()


def _evp_aead_encrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any,
) -> bytes:
assert ctx is not None

aead_cipher = _evp_aead_get_cipher(backend, cipher)
assert aead_cipher is not None

out_len = backend._ffi.new("size_t *")
# max_out_len should be in_len plus the result of
# EVP_AEAD_max_overhead.
max_out_len = len(data) + backend._lib.EVP_AEAD_max_overhead(aead_cipher)
out_buf = backend._ffi.new("uint8_t[]", max_out_len)
data_ptr = backend._ffi.from_buffer(data)
nonce_ptr = backend._ffi.from_buffer(nonce)
aad = b"".join(associated_data)
aad_ptr = backend._ffi.from_buffer(aad)

res = backend._lib.EVP_AEAD_CTX_seal(
ctx,
out_buf,
out_len,
max_out_len,
nonce_ptr,
len(nonce),
data_ptr,
len(data),
aad_ptr,
len(aad),
return _evp_cipher_encrypt(
backend, cipher, nonce, data, associated_data, tag_length
)
backend.openssl_assert(res == 1)
encrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:]
return encrypted_data


def _evp_aead_decrypt(
def _decrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any,
) -> bytes:
if len(data) < tag_length:
raise InvalidTag

assert ctx is not None

out_len = backend._ffi.new("size_t *")
# max_out_len should at least in_len
max_out_len = len(data)
out_buf = backend._ffi.new("uint8_t[]", max_out_len)
data_ptr = backend._ffi.from_buffer(data)
nonce_ptr = backend._ffi.from_buffer(nonce)
aad = b"".join(associated_data)
aad_ptr = backend._ffi.from_buffer(aad)

res = backend._lib.EVP_AEAD_CTX_open(
ctx,
out_buf,
out_len,
max_out_len,
nonce_ptr,
len(nonce),
data_ptr,
len(data),
aad_ptr,
len(aad),
return _evp_cipher_decrypt(
backend, cipher, nonce, data, associated_data, tag_length
)

if res == 0:
backend._consume_errors()
raise InvalidTag

decrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:]
return decrypted_data


_ENCRYPT = 1
_DECRYPT = 0
Expand All @@ -219,12 +58,9 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

if isinstance(cipher, ChaCha20Poly1305):
return b"chacha20-poly1305"
elif isinstance(cipher, AESCCM):
if isinstance(cipher, AESCCM):
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
else:
assert isinstance(cipher, AESGCM)
Expand All @@ -237,29 +73,6 @@ def _evp_cipher(cipher_name: bytes, backend: Backend):
return evp_cipher


def _evp_cipher_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
):
ctx = backend._lib.EVP_CIPHER_CTX_new()
backend.openssl_assert(ctx != backend._ffi.NULL)
ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
cipher_name = _evp_cipher_cipher_name(cipher)
evp_cipher = _evp_cipher(cipher_name, backend)
key_ptr = backend._ffi.from_buffer(key)
res = backend._lib.EVP_CipherInit_ex(
ctx,
evp_cipher,
backend._ffi.NULL,
key_ptr,
backend._ffi.NULL,
0,
)
backend.openssl_assert(res != 0)
return ctx


def _evp_cipher_aead_setup(
backend: Backend,
cipher_name: bytes,
Expand Down Expand Up @@ -323,21 +136,6 @@ def _evp_cipher_set_tag(backend, ctx, tag: bytes) -> None:
backend.openssl_assert(res != 0)


def _evp_cipher_set_nonce_operation(
backend, ctx, nonce: bytes, operation: int
) -> None:
nonce_ptr = backend._ffi.from_buffer(nonce)
res = backend._lib.EVP_CipherInit_ex(
ctx,
backend._ffi.NULL,
backend._ffi.NULL,
backend._ffi.NULL,
nonce_ptr,
int(operation == _ENCRYPT),
)
backend.openssl_assert(res != 0)


def _evp_cipher_set_length(backend: Backend, ctx, data_len: int) -> None:
intptr = backend._ffi.new("int *")
res = backend._lib.EVP_CipherUpdate(
Expand Down Expand Up @@ -373,23 +171,19 @@ def _evp_cipher_encrypt(
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

if ctx is None:
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
None,
tag_length,
_ENCRYPT,
)
else:
_evp_cipher_set_nonce_operation(backend, ctx, nonce, _ENCRYPT)
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
None,
tag_length,
_ENCRYPT,
)

# CCM requires us to pass the length of the data before processing
# anything.
Expand Down Expand Up @@ -425,7 +219,6 @@ def _evp_cipher_decrypt(
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

Expand All @@ -434,20 +227,16 @@ def _evp_cipher_decrypt(

tag = data[-tag_length:]
data = data[:-tag_length]
if ctx is None:
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
tag,
tag_length,
_DECRYPT,
)
else:
_evp_cipher_set_nonce_operation(backend, ctx, nonce, _DECRYPT)
_evp_cipher_set_tag(backend, ctx, tag)
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
tag,
tag_length,
_DECRYPT,
)

# CCM requires us to pass the length of the data before processing
# anything.
Expand Down
17 changes: 0 additions & 17 deletions src/cryptography/hazmat/bindings/_rust/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import types
import typing

def check_pkcs7_padding(data: bytes) -> bool: ...
Expand All @@ -16,19 +15,3 @@ class ObjectIdentifier:
def _name(self) -> str: ...

T = typing.TypeVar("T")

class FixedPool(typing.Generic[T]):
def __init__(
self,
create: typing.Callable[[], T],
) -> None: ...
def acquire(self) -> PoolAcquisition[T]: ...

class PoolAcquisition(typing.Generic[T]):
def __enter__(self) -> T: ...
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None: ...
Loading
Loading