Skip to content

Commit

Permalink
feat: add max length for encrypted string (#290)
Browse files Browse the repository at this point in the history
* feat: add max length for encrypted string

* fix: updated exceptions and added tests

* fix: skip mocks
  • Loading branch information
cofin authored Nov 15, 2024
1 parent 28c918e commit be8054e
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.7.3"
rev: "v0.7.4"
hooks:
# Run the linter.
- id: ruff
Expand Down
9 changes: 7 additions & 2 deletions advanced_alchemy/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import re
from contextlib import contextmanager
from typing import Any, Callable, Generator, TypedDict, Union
from typing import Any, Callable, Generator, TypedDict, Union, cast

from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError
from sqlalchemy.exc import InvalidRequestError as SQLAlchemyInvalidRequestError
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError, StatementError

from advanced_alchemy.utils.deprecation import deprecated

Expand Down Expand Up @@ -291,6 +291,7 @@ def wrap_sqlalchemy_exception(
"""
try:
yield

except MultipleResultsFound as exc:
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="multiple_rows", exc=exc)
Expand Down Expand Up @@ -318,6 +319,10 @@ def wrap_sqlalchemy_exception(
raise IntegrityError(detail=f"An integrity error occurred: {exc}") from exc
except SQLAlchemyInvalidRequestError as exc:
raise InvalidRequestError(detail="An invalid request was made.") from exc
except StatementError as exc:
raise IntegrityError(
detail=cast(str, getattr(exc.orig, "detail", "There was an issue processing the statement."))
) from exc
except SQLAlchemyError as exc:
if error_messages is not None:
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)
Expand Down
23 changes: 22 additions & 1 deletion advanced_alchemy/types/encrypted_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from sqlalchemy import String, Text, TypeDecorator
from sqlalchemy import func as sql_func

from advanced_alchemy.exceptions import IntegrityError

if TYPE_CHECKING:
from sqlalchemy.engine import Dialect

Expand Down Expand Up @@ -222,11 +224,13 @@ class EncryptedString(TypeDecorator[str]):
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only, as encrypted strings will be longer.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
Attributes:
key (str | bytes | Callable[[], str | bytes]): The encryption key.
backend (EncryptionBackend): The encryption backend instance.
length (int | None): The unencrypted string length.
"""

impl = String
Expand All @@ -236,18 +240,21 @@ def __init__(
self,
key: str | bytes | Callable[[], str | bytes] = os.urandom(32),
backend: type[EncryptionBackend] = FernetBackend,
length: int | None = None,
**kwargs: Any,
) -> None:
"""Initializes the EncryptedString TypeDecorator.
Args:
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only.
**kwargs (Any | None): Additional arguments passed to the underlying String type.
"""
super().__init__()
self.key = key
self.backend = backend()
self.length = length

@property
def python_type(self) -> type[str]:
Expand All @@ -261,32 +268,46 @@ def python_type(self) -> type[str]:
def load_dialect_impl(self, dialect: Dialect) -> Any:
"""Loads the appropriate dialect implementation based on the database dialect.
Note: The actual column length will be larger than the specified length due to encryption overhead.
For most encryption methods, the encrypted string will be approximately 1.35x longer than the original.
Args:
dialect (Dialect): The SQLAlchemy dialect.
Returns:
Any: The dialect-specific type descriptor.
"""
if dialect.name in {"mysql", "mariadb"}:
# For MySQL/MariaDB, always use Text to avoid length limitations
return dialect.type_descriptor(Text())
if dialect.name == "oracle":
# Oracle has a 4000-byte limit for VARCHAR2 (by default)
return dialect.type_descriptor(String(length=4000))
return dialect.type_descriptor(String())

def process_bind_param(self, value: Any, dialect: Dialect) -> str | None:
"""Processes the value before binding it to the SQL statement.
This method encrypts the value using the specified backend.
This method encrypts the value using the specified backend and validates length if specified.
Args:
value (Any): The value to process.
dialect (Dialect): The SQLAlchemy dialect.
Returns:
str | None: The encrypted value or None if the input is None.
Raises:
ValueError: If the value exceeds the specified length.
"""
if value is None:
return value

# Validate length if specified
if self.length is not None and len(str(value)) > self.length:
msg = f"Unencrypted value exceeds maximum unencrypted length of {self.length}"
raise IntegrityError(msg)

self.mount_vault()
return self.backend.encrypt(value)

Expand Down
4 changes: 4 additions & 0 deletions tests/fixtures/bigint/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,7 @@ class BigIntSecret(BigIntBase):
long_secret: Mapped[str] = mapped_column(
EncryptedText(key="super_secret"),
)
length_validated_secret: Mapped[str] = mapped_column(
EncryptedString(key="super_secret", length=10),
nullable=True,
)
4 changes: 4 additions & 0 deletions tests/fixtures/uuid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class UUIDSecret(UUIDv7Base):
long_secret: Mapped[str] = mapped_column(
EncryptedText(key="super_secret"),
)
length_validated_secret: Mapped[str] = mapped_column(
EncryptedString(key="super_secret", length=10),
nullable=True,
)


class UUIDModelWithFetchedValue(UUIDv6Base):
Expand Down
30 changes: 30 additions & 0 deletions tests/integration/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
ModelWithFetchedValueRepository = SQLAlchemyAsyncRepository[AnyModelWithFetchedValue]
ModelWithFetchedValueService = SQLAlchemyAsyncRepositoryService[AnyModelWithFetchedValue]


RawRecordData = List[Dict[str, Any]]

mock_engines = {"mock_async_engine", "mock_sync_engine"}
Expand Down Expand Up @@ -1937,6 +1938,35 @@ async def test_repo_encrypted_methods(
assert obj.long_secret == updated.long_secret


async def test_encrypted_string_length_validation(
request: FixtureRequest, secret_repo: SecretRepository, secret_model: SecretModel
) -> None:
"""Test that EncryptedString enforces length validation.
Args:
secret_repo: The secret repository
secret_model: The secret model class
"""
if any(fixture in request.fixturenames for fixture in ["mock_async_engine", "mock_sync_engine"]):
pytest.skip(
f"{SQLAlchemyAsyncMockRepository.__name__} does not works with client side validated encrypted strings lengths"
)
# Test valid length
valid_secret = "AAAAAAAAA"
secret = secret_model(secret="test", long_secret="test", length_validated_secret=valid_secret)
saved_secret = await maybe_async(secret_repo.add(secret))
assert saved_secret.length_validated_secret == valid_secret

# Test exceeding length
long_secret = "A" * 51 # Exceeds 50 character limit
with pytest.raises(IntegrityError) as exc_info:
secret = secret_model(secret="test", long_secret="test", length_validated_secret=long_secret)
await maybe_async(secret_repo.add(secret))

assert exc_info.value.__class__.__name__ == "IntegrityError"
assert "exceeds maximum unencrypted length" in str(exc_info.value.detail)


# service tests
async def test_service_filter_search(author_service: AuthorService) -> None:
existing_obj = await maybe_async(
Expand Down
Loading

0 comments on commit be8054e

Please sign in to comment.