Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ add multiple primary keys support #31

Merged
merged 16 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
JoinConfig,
)

from ..endpoint.helper import _get_primary_key
from ..endpoint.helper import _get_primary_keys

ModelType = TypeVar("ModelType", bound=DeclarativeBase)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
Expand Down Expand Up @@ -540,13 +540,15 @@ async def count(
primary_filters = self._parse_filters(**kwargs)

if joins_config is not None:
primary_key = _get_primary_key(self.model)
if not primary_key:
primary_keys = [p.name for p in _get_primary_keys(self.model)]
if not any(primary_keys):
raise ValueError(
f"The model '{self.model.__name__}' does not have a primary key defined, which is required for counting with joins."
)

base_query = select(getattr(self.model, primary_key).label("distinct_id"))
to_select = [
getattr(self.model, pk).label(f"distinct_{pk}") for pk in primary_keys
]
base_query = select(*to_select)

for join in joins_config:
join_model = join.alias or join.model
Expand Down
91 changes: 68 additions & 23 deletions fastcrud/endpoint/endpoint_creator.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,58 @@
from typing import Type, TypeVar, Optional, Callable, Sequence, Union
import inspect
from typing import Dict, Type, TypeVar, Optional, Callable, Sequence, Union
from enum import Enum

from fastapi import Depends, Body, Query, APIRouter, params
from pydantic import BaseModel, ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from pydantic import BaseModel, ValidationError

from ..exceptions.http_exceptions import NotFoundException
from ..crud.fast_crud import FastCRUD
from ..exceptions.http_exceptions import DuplicateValueException
from .helper import CRUDMethods, _get_primary_key, _extract_unique_columns
from ..paginated.response import paginated_response
from ..exceptions.http_exceptions import DuplicateValueException, NotFoundException
from ..paginated.helper import compute_offset
from ..paginated.response import paginated_response
from .helper import (
CRUDMethods,
_extract_unique_columns,
_get_primary_keys,
_get_python_type,
)

CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel)
DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)


def apply_model_pk(**pkeys: Dict[str, type]):
"""
This decorator injects positional arguments into a fastCRUD endpoint.
It dynamically changes the endpoint signature and allows to use
multiple primary keys without defining them explicitly.
"""

def wrapper(endpoint):
signature = inspect.signature(endpoint)
parameters = [
p
for p in signature.parameters.values()
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
extra_positional_params = [
inspect.Parameter(
name=k, annotation=v, kind=inspect.Parameter.POSITIONAL_ONLY
)
for k, v in pkeys.items()
]

endpoint.__signature__ = signature.replace(
parameters=extra_positional_params + parameters
)
return endpoint

return wrapper


class EndpointCreator:
"""
A class to create and register CRUD endpoints for a FastAPI application.
Expand Down Expand Up @@ -152,7 +186,11 @@ def __init__(
updated_at_column: str = "updated_at",
endpoint_names: Optional[dict[str, str]] = None,
) -> None:
self.primary_key_name = _get_primary_key(model)
self._primary_keys = _get_primary_keys(model)
self._primary_keys_types = {
pk.name: _get_python_type(pk) for pk in self._primary_keys
}
self.primary_key_names = [pk.name for pk in self._primary_keys]
self.session = session
self.crud = crud or FastCRUD(
model=model,
Expand Down Expand Up @@ -208,8 +246,9 @@ async def endpoint(
def _read_item(self):
"""Creates an endpoint for reading a single item from the database."""

async def endpoint(id: int, db: AsyncSession = Depends(self.session)):
item = await self.crud.get(db, id=id)
@apply_model_pk(**self._primary_keys_types)
async def endpoint(db: AsyncSession = Depends(self.session), **pkeys):
item = await self.crud.get(db, **pkeys)
if not item:
raise NotFoundException(detail="Item not found")
return item
Expand Down Expand Up @@ -254,20 +293,22 @@ async def endpoint(
def _update_item(self):
"""Creates an endpoint for updating an existing item in the database."""

@apply_model_pk(**self._primary_keys_types)
async def endpoint(
id: int,
item: self.update_schema = Body(...), # type: ignore
db: AsyncSession = Depends(self.session),
**pkeys,
):
return await self.crud.update(db, item, id=id)
return await self.crud.update(db, item, **pkeys)

return endpoint

def _delete_item(self):
"""Creates an endpoint for deleting an item from the database."""

async def endpoint(id: int, db: AsyncSession = Depends(self.session)):
await self.crud.delete(db, id=id)
@apply_model_pk(**self._primary_keys_types)
async def endpoint(db: AsyncSession = Depends(self.session), **pkeys):
await self.crud.delete(db, **pkeys)
return {"message": "Item deleted successfully"}

return endpoint
Expand All @@ -281,8 +322,9 @@ def _db_delete(self):
async session to permanently delete the item from the database.
"""

async def endpoint(id: int, db: AsyncSession = Depends(self.session)):
await self.crud.db_delete(db, id=id)
@apply_model_pk(**self._primary_keys_types)
async def endpoint(db: AsyncSession = Depends(self.session), **pkeys):
await self.crud.db_delete(db, **pkeys)
return {"message": "Item permanently deleted from the database"}

return endpoint
Expand Down Expand Up @@ -396,6 +438,8 @@ def get_current_user(...):
if self.delete_schema:
delete_description = "Soft delete a"

_primary_keys_path_suffix = "/".join(f"{{{n}}}" for n in self.primary_key_names)

if ("create" in included_methods) and ("create" not in deleted_methods):
endpoint_name = self._get_endpoint_name("create")
self.router.add_api_route(
Expand All @@ -410,14 +454,15 @@ def get_current_user(...):

if ("read" in included_methods) and ("read" not in deleted_methods):
endpoint_name = self._get_endpoint_name("read")

self.router.add_api_route(
f"{self.path}/{endpoint_name}/{{{self.primary_key_name}}}",
f"{self.path}/{endpoint_name}/{_primary_keys_path_suffix}",
self._read_item(),
methods=["GET"],
include_in_schema=self.include_in_schema,
tags=self.tags,
dependencies=read_deps,
description=f"Read a single {self.model.__name__} row from the database by its primary key: {self.primary_key_name}.",
description=f"Read a single {self.model.__name__} row from the database by its primary keys: {self.primary_key_names}.",
)

if ("read_multi" in included_methods) and ("read_multi" not in deleted_methods):
Expand Down Expand Up @@ -449,25 +494,25 @@ def get_current_user(...):
if ("update" in included_methods) and ("update" not in deleted_methods):
endpoint_name = self._get_endpoint_name("update")
self.router.add_api_route(
f"{self.path}/{endpoint_name}/{{{self.primary_key_name}}}",
f"{self.path}/{endpoint_name}/{_primary_keys_path_suffix}",
self._update_item(),
methods=["PATCH"],
include_in_schema=self.include_in_schema,
tags=self.tags,
dependencies=update_deps,
description=f"Update an existing {self.model.__name__} row in the database by its primary key: {self.primary_key_name}.",
description=f"Update an existing {self.model.__name__} row in the database by its primary keys: {self.primary_key_names}.",
)

if ("delete" in included_methods) and ("delete" not in deleted_methods):
endpoint_name = self._get_endpoint_name("delete")
self.router.add_api_route(
f"{self.path}/{endpoint_name}/{{{self.primary_key_name}}}",
f"{self.path}/{endpoint_name}/{_primary_keys_path_suffix}",
self._delete_item(),
methods=["DELETE"],
include_in_schema=self.include_in_schema,
tags=self.tags,
dependencies=delete_deps,
description=f"{delete_description} {self.model.__name__} row from the database by its primary key: {self.primary_key_name}.",
description=f"{delete_description} {self.model.__name__} row from the database by its primary keys: {self.primary_key_names}.",
)

if (
Expand All @@ -477,13 +522,13 @@ def get_current_user(...):
):
endpoint_name = self._get_endpoint_name("db_delete")
self.router.add_api_route(
f"{self.path}/{endpoint_name}/{{{self.primary_key_name}}}",
f"{self.path}/{endpoint_name}/{_primary_keys_path_suffix}",
self._db_delete(),
methods=["DELETE"],
include_in_schema=self.include_in_schema,
tags=self.tags,
dependencies=db_delete_deps,
description=f"Permanently delete a {self.model.__name__} row from the database by its primary key: {self.primary_key_name}.",
description=f"Permanently delete a {self.model.__name__} row from the database by its primary keys: {self.primary_key_names}.",
)

def add_custom_route(
Expand Down
23 changes: 20 additions & 3 deletions fastcrud/endpoint/helper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Union, Annotated, Sequence
from typing import Optional, Union, Annotated, Sequence
from pydantic import BaseModel, Field, ValidationError
from pydantic.functional_validators import field_validator

from sqlalchemy import inspect
from sqlalchemy import Column, inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.elements import KeyedColumnElement

Expand Down Expand Up @@ -43,10 +43,27 @@ def check_valid_method(cls, values: Sequence[str]) -> Sequence[str]:


def _get_primary_key(model: type[DeclarativeBase]) -> Union[str, None]:
return _get_primary_keys(model)[0].name


def _get_primary_keys(model: type[DeclarativeBase]) -> Sequence[Column]:
"""Get the primary key of a SQLAlchemy model."""
inspector = inspect(model).mapper
primary_key_columns = inspector.primary_key
return primary_key_columns[0].name if primary_key_columns else None

return primary_key_columns


def _get_python_type(column: Column) -> Optional[type]:
try:
return column.type.python_type
except NotImplementedError:
if hasattr(column.type, "impl") and hasattr(column.type.impl, "python_type"): # type: ignore
return column.type.impl.python_type # type: ignore
else:
raise NotImplementedError(
f"The primary key column {column.name} uses a custom type without a defined `python_type` or suitable `impl` fallback."
) # this could just warn and return the object as well if it's not that necessary: # logging.warning(f"Column {column.name} lacks a python_type and a suitable impl fallback.") # return object


def _extract_unique_columns(
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ authors = ["Igor Benav <igor.magalhaes.r@gmail.com>"]
license = "MIT"
readme = "README.md"
repository = "https://github.com/igorbenav/fastcrud"
include = [
"LICENSE",
]
include = ["LICENSE"]

classifiers = [
"Development Status :: 4 - Beta",
Expand All @@ -22,7 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
"Framework :: FastAPI",
"Typing :: Typed"
"Typing :: Typed",
]

keywords = ["fastapi", "crud", "async", "sqlalchemy", "pydantic"]
Expand Down
Loading
Loading