Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add upsert_multi #119

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/advanced/crud.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,29 @@ items = await crud_items.get_multi(db=db, limit=None)

To facilitate complex data relationships, `get_joined` and `get_multi_joined` can be configured to handle joins with multiple models. This is achieved using the `joins_config` parameter, where you can specify a list of `JoinConfig` instances, each representing a distinct join configuration.

## Upserting multiple records using `upsert_multi`

FastCRUD provides an `upsert_multi` method to efficiently upsert multiple records in a single operation. This method is particularly useful when you need to insert new records or update existing ones based on a unique constraint.

```python
from fastcrud import FastCRUD

from .models.item import Item
from .schemas.item import ItemCreateSchema
from .database import session as db

crud_items = FastCRUD(Item)
items = await crud_items.upsert_multi(
db=db,
instances=[
ItemCreateSchema(price=9.99),
],
schema_to_select=ItemSchema,
return_as_model=True,
)
# this will return the upserted data in the form of ItemSchema
```

#### Example: Joining `User`, `Tier`, and `Department` Models

Consider a scenario where you want to retrieve users along with their associated tier and department information. Here's how you can achieve this using `get_multi_joined`.
Expand Down
126 changes: 126 additions & 0 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from pydantic import BaseModel, ValidationError
from sqlalchemy import (
Insert,
Result,
and_,
select,
update,
delete,
Expand All @@ -21,6 +23,7 @@
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import BinaryExpression, ColumnElement
from sqlalchemy.sql.selectable import Select
from sqlalchemy.dialects import postgresql, sqlite, mysql

from fastcrud.types import (
CreateSchemaType,
Expand Down Expand Up @@ -582,6 +585,129 @@ async def upsert(

return db_instance

async def upsert_multi(
self,
db: AsyncSession,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
**kwargs: Any,
) -> Optional[Dict[str, Any]]:
"""
Upsert multiple records in the database. The underlying implementation varies based on the database dialect.

Args:
db: The database session to use for the operation.
instances: A list of Pydantic schemas representing the instances to upsert.
return_columns: Optional list of column names to return after the upsert operation.
schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True.
return_as_model: If True, returns data as instances of the specified Pydantic model.
**kwargs: Filters to identify the record(s) to update on conflict, supporting advanced comparison operators for refined querying.

Returns:
The updated record(s) as a dictionary or Pydantic model instance or None, depending on the value of `return_as_model` and `return_columns`.

Raises:
ValueError: If the MySQL dialect is used with filters, return_columns, schema_to_select, or return_as_model.
NotImplementedError: If the database dialect is not supported for upsert multi.
"""
filters = self._parse_filters(**kwargs)

if db.bind.dialect.name == "postgresql":
statement, params = await self._upsert_multi_postgresql(instances, filters)
elif db.bind.dialect.name == "sqlite":
statement, params = await self._upsert_multi_sqlite(instances, filters)
elif db.bind.dialect.name in ["mysql", "mariadb"]:
if filters:
raise ValueError(
"MySQL does not support filtering on insert operations."
)
if return_columns or schema_to_select or return_as_model:
raise ValueError(
"MySQL does not support the returning clause for insert operations."
)
statement, params = await self._upsert_multi_mysql(instances)
else:
raise NotImplementedError(
f"Upsert multi is not implemented for {db.bind.dialect.name}"
)

if return_as_model:
# All columns are returned to ensure the model can be constructed
return_columns = self.model_col_names

if return_columns:
statement = statement.returning(*[column(name) for name in return_columns])
db_row = await db.execute(statement, params)
return self._as_multi_response(
db_row,
schema_to_select=schema_to_select,
return_as_model=return_as_model,
)

await db.execute(statement, params)
return None

async def _upsert_multi_postgresql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
filters: list[ColumnElement],
) -> tuple[Insert, list[dict]]:
statement = postgresql.insert(self.model)
statement = statement.on_conflict_do_update(
index_elements=self._primary_keys,
set_={
column.name: getattr(statement.excluded, column.name)
for column in self.model.__table__.columns
if not column.primary_key and not column.unique
},
where=and_(*filters) if filters else None,
)
params = [
self.model(**instance.model_dump()).__dict__ for instance in instances
]
return statement, params

async def _upsert_multi_sqlite(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
filters: list[ColumnElement],
) -> tuple[Insert, list[dict]]:
statement = sqlite.insert(self.model)
statement = statement.on_conflict_do_update(
index_elements=self._primary_keys,
set_={
column.name: getattr(statement.excluded, column.name)
for column in self.model.__table__.columns
if not column.primary_key and not column.unique
},
where=and_(*filters) if filters else None,
)
params = [
self.model(**instance.model_dump()).__dict__ for instance in instances
]
return statement, params

async def _upsert_multi_mysql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
) -> tuple[Insert, list[dict]]:
statement = mysql.insert(self.model)
statement = statement.on_duplicate_key_update(
{
column.name: getattr(statement.inserted, column.name)
for column in self.model.__table__.columns
if not column.primary_key
and not column.unique
and column.name != self.deleted_at_column
}
)
params = [
self.model(**instance.model_dump()).__dict__ for instance in instances
]
return statement, params

async def exists(self, db: AsyncSession, **kwargs: Any) -> bool:
"""
Checks if any records exist that match the given filter conditions.
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,16 @@ sqlmodel = "^0.0.14"
mypy = "^1.9.0"
ruff = "^0.3.4"
coverage = "^7.4.4"
testcontainers = "^4.7.1"
psycopg = "^3.2.1"
aiomysql = "^0.2.0"
cryptography = "^36.0.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
markers = [
"dialect(name): mark test to run only on specific SQL dialect",
]
Loading
Loading