Skip to content

Commit

Permalink
assume algorithm can only be aes or rsa, refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvw committed Mar 19, 2024
1 parent 23471d7 commit fbf906c
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 94 deletions.
4 changes: 4 additions & 0 deletions django_crypto_fields/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,7 @@ class EncryptionLookupError(Exception):

class MalformedCiphertextError(Exception):
pass


class InvalidEncryptionAlgorithm(Exception):
pass
52 changes: 27 additions & 25 deletions django_crypto_fields/field_cryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@
SALT,
)
from .cryptor import Cryptor
from .exceptions import CipherError, EncryptionError, EncryptionKeyError
from .exceptions import (
CipherError,
EncryptionError,
EncryptionKeyError,
InvalidEncryptionAlgorithm,
)
from .keys import encryption_keys
from .utils import (
get_crypt_model_cls,
has_valid_value_or_raise,
is_valid_ciphertext_or_raise,
safe_decode,
safe_encode_utf8,
)

Expand All @@ -49,6 +55,7 @@ class FieldCryptor:

def __init__(self, algorithm: str, access_mode: str):
self._using = None
self._algorithm = None
self.algorithm = algorithm
self.access_mode = access_mode
self.aes_encryption_mode = AES_CIPHER.MODE_CBC
Expand All @@ -61,6 +68,18 @@ def __init__(self, algorithm: str, access_mode: str):
def __repr__(self) -> str:
return f"FieldCryptor(algorithm='{self.algorithm}', mode='{self.access_mode}')"

@property
def algorithm(self):
return self._algorithm

@algorithm.setter
def algorithm(self, value):
self._algorithm = value
if value not in [AES, RSA]:
raise InvalidEncryptionAlgorithm(
f"Invalid encryption algorithm. Expected 'aes' or 'rsa'. Got {value}"
)

@property
def salt_key(self):
attr = "_".join([SALT, self.access_mode, PRIVATE])
Expand Down Expand Up @@ -123,13 +142,6 @@ def decrypt(self, hash_with_prefix: str):
plaintext = self.cryptor.aes_decrypt(secret, self.access_mode)
elif self.algorithm == RSA:
plaintext = self.cryptor.rsa_decrypt(secret, self.access_mode)
else:
raise CipherError(
"Cannot determine algorithm for decryption."
" Valid options are {0}. Got {1}".format(
", ".join(list(self.keys.key_filenames)), self.algorithm
)
)
return plaintext

@property
Expand Down Expand Up @@ -171,10 +183,7 @@ def get_prep_value(self, value: str | bytes | None) -> str | bytes | None:
else:
ciphertext = self.encrypt(value)
value = ciphertext.split(CIPHER_PREFIX.encode(ENCODING))[0]
try:
value.decode()
except AttributeError:
pass
value = safe_decode(value)
return value

def get_ciphertext(self, value):
Expand All @@ -183,19 +192,12 @@ def get_ciphertext(self, value):
cipher = self.cryptor.aes_encrypt
elif self.algorithm == RSA:
cipher = self.cryptor.rsa_encrypt
try:
ciphertext = (
HASH_PREFIX.encode(ENCODING)
+ self.hash(value)
+ CIPHER_PREFIX.encode(ENCODING)
+ cipher(value, self.access_mode)
)
except AttributeError as e:
raise CipherError(
"Cannot determine cipher method. Unknown "
"encryption algorithm. Valid options are {0}. "
"Got {1} ({2})".format(", ".join(self.keys.key_filenames), self.algorithm, e)
)
ciphertext = (
HASH_PREFIX.encode(ENCODING)
+ self.hash(value)
+ CIPHER_PREFIX.encode(ENCODING)
+ cipher(value, self.access_mode)
)
return is_valid_ciphertext_or_raise(ciphertext, self.hash_size)

def get_hash(self, ciphertext: bytes) -> bytes | None:
Expand Down
48 changes: 28 additions & 20 deletions django_crypto_fields/fields/encrypted_decimal_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@ class EncryptedDecimalField(BaseRsaField):
description = "local-rsa encrypted field for 'IntegerField'"

def __init__(self, *args, **kwargs):
self.validate_max_digits(kwargs)
self.validate_decimal_places(kwargs)
decimal_decimal_places = int(kwargs.get("decimal_places"))
decimal_max_digits = int(kwargs.get("max_digits"))
del kwargs["decimal_places"]
del kwargs["max_digits"]
super().__init__(*args, **kwargs)
self.decimal_decimal_places = decimal_decimal_places
self.decimal_max_digits = decimal_max_digits

def to_string(self, value):
if isinstance(value, (str,)):
raise TypeError("Expected basestring. Got {0}".format(value))
return str(value)

def to_python(self, value):
"""Returns as integer"""
retval = super(EncryptedDecimalField, self).to_python(value)
if retval:
if not self.field_cryptor.is_encrypted(retval):
retval = Decimal(retval).to_eng_string()
return retval

@staticmethod
def validate_max_digits(kwargs):
if "max_digits" not in kwargs:
raise AttributeError(
"EncryptedDecimalField requires attribute 'max_digits. " "Got none"
Expand All @@ -19,6 +44,9 @@ def __init__(self, *args, **kwargs):
f"EncryptedDecimalField attribute 'max_digits must be an "
f'integer. Got {kwargs.get("max_digits")}'
)

@staticmethod
def validate_decimal_places(kwargs):
if "decimal_places" not in kwargs:
raise AttributeError(
"EncryptedDecimalField requires attribute 'decimal_places. " "Got none"
Expand All @@ -31,23 +59,3 @@ def __init__(self, *args, **kwargs):
f"EncryptedDecimalField attribute 'decimal_places must be an "
f'integer. Got {kwargs.get("decimal_places")}'
)
decimal_decimal_places = int(kwargs.get("decimal_places"))
decimal_max_digits = int(kwargs.get("max_digits"))
del kwargs["decimal_places"]
del kwargs["max_digits"]
super().__init__(*args, **kwargs)
self.decimal_decimal_places = decimal_decimal_places
self.decimal_max_digits = decimal_max_digits

def to_string(self, value):
if isinstance(value, (str,)):
raise TypeError("Expected basestring. Got {0}".format(value))
return str(value)

def to_python(self, value):
"""Returns as integer"""
retval = super(EncryptedDecimalField, self).to_python(value)
if retval:
if not self.field_cryptor.is_encrypted(retval):
retval = Decimal(retval).to_eng_string()
return retval
21 changes: 13 additions & 8 deletions django_crypto_fields/key_path/key_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ class KeyPath:
def __post_init__(self):
path = get_keypath_from_settings()
if not path:
if get_test_module_from_settings() in sys.argv:
path = mkdtemp()
else:
raise DjangoCryptoFieldsKeyPathError(
"Path may not be none. Production or debug systems must explicitly "
"set a valid path to the encryption keys. "
"See settings.DJANGO_CRYPTO_FIELDS_KEY_PATH."
)
path = self.create_folder_for_tests_or_raise()
elif not Path(path).exists():
raise DjangoCryptoFieldsKeyPathDoesNotExist(
"Path to encryption keys does not exist. "
Expand All @@ -59,3 +52,15 @@ def __post_init__(self):

def __str__(self) -> str:
return str(self.path)

@staticmethod
def create_folder_for_tests_or_raise() -> PurePath:
if get_test_module_from_settings() in sys.argv:
path = PurePath(mkdtemp())
else:
raise DjangoCryptoFieldsKeyPathError(
"Path may not be none. Production or debug systems must explicitly "
"set a valid path to the encryption keys. "
"See settings.DJANGO_CRYPTO_FIELDS_KEY_PATH."
)
return path
42 changes: 23 additions & 19 deletions django_crypto_fields/keys/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,7 @@ def initialize(self):
self.keys = deepcopy(self.template)
persist_key_path_or_raise()
if not key_files_exist(self.path, self.key_prefix):
if auto_create_keys := get_auto_create_keys_from_settings():
if not os.access(self.path, os.W_OK):
raise DjangoCryptoFieldsError(
"Cannot auto-create encryption keys. Folder is not writeable."
f"Got {self.path}"
)
write_msg(
self.verbose,
style.SUCCESS(f" * settings.AUTO_CREATE_KEYS={auto_create_keys}.\n"),
)
self.create()
else:
raise DjangoCryptoFieldsKeysDoNotExist(
f"Failed to find any encryption keys in path {self.path}. "
"If this is your first time loading "
"the project, set settings.AUTO_CREATE_KEYS=True and restart. "
"Make sure the folder is writeable."
)
self.create_new_keys_or_raise()
self.load_keys()
self.rsa_modes_supported = sorted([k for k in self.keys[RSA]])
self.aes_modes_supported = sorted([k for k in self.keys[AES]])
Expand All @@ -98,7 +81,28 @@ def reset_and_delete_keys(self, verbose: bool | None = None):
def get(self, k: str):
return self.keys.get(k)

def create(self) -> None:
def create_new_keys_or_raise(self):
"""Calls create after checking if allowed."""
if auto_create_keys := get_auto_create_keys_from_settings():
if not os.access(self.path, os.W_OK):
raise DjangoCryptoFieldsError(
"Cannot auto-create encryption keys. Folder is not writeable."
f"Got {self.path}"
)
write_msg(
self.verbose,
style.SUCCESS(f" * settings.AUTO_CREATE_KEYS={auto_create_keys}.\n"),
)
self._create()
else:
raise DjangoCryptoFieldsKeysDoNotExist(
f"Failed to find any encryption keys in path {self.path}. "
"If this is your first time loading "
"the project, set settings.AUTO_CREATE_KEYS=True and restart. "
"Make sure the folder is writeable."
)

def _create(self) -> None:
"""Generates RSA and AES keys as per `filenames`."""
if key_files_exist(self.path, self.key_prefix):
raise DjangoCryptoFieldsKeyAlreadyExist(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
path,date
/Users/erikvw/source/edc_source/django-crypto-fields/django_crypto_fields/tests/crypto_keys,2024-03-19 03:29:04.747127+00:00
/Users/erikvw/source/edc_source/django-crypto-fields/django_crypto_fields/tests/crypto_keys,2024-03-19 04:06:39.385684+00:00
2 changes: 1 addition & 1 deletion django_crypto_fields/tests/tests/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_invalid_production_path_raises(self):
def test_create_keys_does_not_overwrite_production_keys(self):
keys = Keys(verbose=False)
keys.reset()
self.assertRaises(DjangoCryptoFieldsKeyAlreadyExist, keys.create)
self.assertRaises(DjangoCryptoFieldsKeyAlreadyExist, keys.create_new_keys_or_raise)

@override_settings(
DEBUG=False,
Expand Down
36 changes: 16 additions & 20 deletions django_crypto_fields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def safe_encode_utf8(value) -> bytes:
return value


def safe_decode(value) -> bytes:
try:
value.decode()
except AttributeError:
pass
return value


def has_valid_hash_or_raise(ciphertext: bytes, hash_size: int) -> bool:
"""Verifies hash segment of ciphertext (bytes) and
raises an exception if not OK.
Expand Down Expand Up @@ -138,11 +146,11 @@ def has_valid_value_or_raise(
raise MalformedCiphertextError("Expected a value, got just the encryption prefix.")
has_valid_hash_or_raise(encoded_value, hash_size)
if has_secret:
is_valid_ciphertext_or_raise(encoded_value, hash_size)
is_valid_ciphertext_or_raise(encoded_value)
return value # note, is original passed value


def is_valid_ciphertext_or_raise(ciphertext: bytes, hash_size: int):
def is_valid_ciphertext_or_raise(ciphertext: bytes, hash_size: int | None = None):
"""Returns an unchanged ciphertext after verifying format cipher_prefix +
hash + cipher_prefix + secret.
"""
Expand All @@ -158,22 +166,10 @@ def is_valid_ciphertext_or_raise(ciphertext: bytes, hash_size: int):
raise MalformedCiphertextError(
f"Malformed ciphertext. Expected prefixes {CIPHER_PREFIX}"
)
try:
if ciphertext[: len(HASH_PREFIX)] != HASH_PREFIX.encode(ENCODING):
raise MalformedCiphertextError(
f"Malformed ciphertext. Expected hash prefix {HASH_PREFIX}"
)
if (
len(
ciphertext.split(HASH_PREFIX.encode(ENCODING))[1].split(
CIPHER_PREFIX.encode(ENCODING)
)[0]
)
!= hash_size
):
raise MalformedCiphertextError(
f"Malformed ciphertext. Expected hash size of {hash_size}."
)
except IndexError:
MalformedCiphertextError("Malformed ciphertext.")
if ciphertext[: len(HASH_PREFIX)] != HASH_PREFIX.encode(ENCODING):
raise MalformedCiphertextError(
f"Malformed ciphertext. Expected hash prefix {HASH_PREFIX}"
)
if hash_size is not None:
has_valid_hash_or_raise(ciphertext, hash_size)
return ciphertext

0 comments on commit fbf906c

Please sign in to comment.