diff --git a/docs/advanced/crud.md b/docs/advanced/crud.md index 26e37ec..eba46f2 100644 --- a/docs/advanced/crud.md +++ b/docs/advanced/crud.md @@ -253,6 +253,29 @@ items = await crud_items.get_multi(db=db, limit=None) To facilitate complex data relationships, `get_joined` and `get_multi_joined` can be configured to handle joins with multiple models. This is achieved using the `joins_config` parameter, where you can specify a list of `JoinConfig` instances, each representing a distinct join configuration. +## Upserting multiple records using `upsert_multi` + +FastCRUD provides an `upsert_multi` method to efficiently upsert multiple records in a single operation. This method is particularly useful when you need to insert new records or update existing ones based on a unique constraint. + +```python +from fastcrud import FastCRUD + +from .models.item import Item +from .schemas.item import ItemCreateSchema, ItemSchema +from .database import session as db + +crud_items = FastCRUD(Item) +items = await crud_items.upsert_multi( + db=db, + instances=[ + ItemCreateSchema(price=9.99), + ], + schema_to_select=ItemSchema, + return_as_model=True, +) +# this will return the upserted data in the form of ItemSchema +``` + +#### Example: Joining `User`, `Tier`, and `Department` Models Consider a scenario where you want to retrieve users along with their associated tier and department information. Here's how you can achieve this using `get_multi_joined`. 
diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index bd37e3b..e3b5358 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -3,7 +3,9 @@ from pydantic import BaseModel, ValidationError from sqlalchemy import ( + Insert, Result, + and_, select, update, delete, @@ -21,6 +23,7 @@ from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.elements import BinaryExpression, ColumnElement from sqlalchemy.sql.selectable import Select +from sqlalchemy.dialects import postgresql, sqlite, mysql from fastcrud.types import ( CreateSchemaType, @@ -582,6 +585,129 @@ async def upsert( return db_instance + async def upsert_multi( + self, + db: AsyncSession, + instances: list[Union[UpdateSchemaType, CreateSchemaType]], + return_columns: Optional[list[str]] = None, + schema_to_select: Optional[type[BaseModel]] = None, + return_as_model: bool = False, + **kwargs: Any, + ) -> Optional[Dict[str, Any]]: + """ + Upsert multiple records in the database. The underlying implementation varies based on the database dialect. + + Args: + db: The database session to use for the operation. + instances: A list of Pydantic schemas representing the instances to upsert. + return_columns: Optional list of column names to return after the upsert operation. + schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True. + return_as_model: If True, returns data as instances of the specified Pydantic model. + **kwargs: Filters to identify the record(s) to update on conflict, supporting advanced comparison operators for refined querying. + + Returns: + The updated record(s) as a dictionary or Pydantic model instance or None, depending on the value of `return_as_model` and `return_columns`. + + Raises: + ValueError: If the MySQL dialect is used with filters, return_columns, schema_to_select, or return_as_model. + NotImplementedError: If the database dialect is not supported for upsert multi. 
+ """ + filters = self._parse_filters(**kwargs) + + if db.bind.dialect.name == "postgresql": + statement, params = await self._upsert_multi_postgresql(instances, filters) + elif db.bind.dialect.name == "sqlite": + statement, params = await self._upsert_multi_sqlite(instances, filters) + elif db.bind.dialect.name in ["mysql", "mariadb"]: + if filters: + raise ValueError( + "MySQL does not support filtering on insert operations." + ) + if return_columns or schema_to_select or return_as_model: + raise ValueError( + "MySQL does not support the returning clause for insert operations." + ) + statement, params = await self._upsert_multi_mysql(instances) + else: + raise NotImplementedError( + f"Upsert multi is not implemented for {db.bind.dialect.name}" + ) + + if return_as_model: + # All columns are returned to ensure the model can be constructed + return_columns = self.model_col_names + + if return_columns: + statement = statement.returning(*[column(name) for name in return_columns]) + db_row = await db.execute(statement, params) + return self._as_multi_response( + db_row, + schema_to_select=schema_to_select, + return_as_model=return_as_model, + ) + + await db.execute(statement, params) + return None + + async def _upsert_multi_postgresql( + self, + instances: list[Union[UpdateSchemaType, CreateSchemaType]], + filters: list[ColumnElement], + ) -> tuple[Insert, list[dict]]: + statement = postgresql.insert(self.model) + statement = statement.on_conflict_do_update( + index_elements=self._primary_keys, + set_={ + column.name: getattr(statement.excluded, column.name) + for column in self.model.__table__.columns + if not column.primary_key and not column.unique + }, + where=and_(*filters) if filters else None, + ) + params = [ + self.model(**instance.model_dump()).__dict__ for instance in instances + ] + return statement, params + + async def _upsert_multi_sqlite( + self, + instances: list[Union[UpdateSchemaType, CreateSchemaType]], + filters: list[ColumnElement], + ) -> 
tuple[Insert, list[dict]]: + statement = sqlite.insert(self.model) + statement = statement.on_conflict_do_update( + index_elements=self._primary_keys, + set_={ + column.name: getattr(statement.excluded, column.name) + for column in self.model.__table__.columns + if not column.primary_key and not column.unique + }, + where=and_(*filters) if filters else None, + ) + params = [ + self.model(**instance.model_dump()).__dict__ for instance in instances + ] + return statement, params + + async def _upsert_multi_mysql( + self, + instances: list[Union[UpdateSchemaType, CreateSchemaType]], + ) -> tuple[Insert, list[dict]]: + statement = mysql.insert(self.model) + statement = statement.on_duplicate_key_update( + { + column.name: getattr(statement.inserted, column.name) + for column in self.model.__table__.columns + if not column.primary_key + and not column.unique + and column.name != self.deleted_at_column + } + ) + params = [ + self.model(**instance.model_dump()).__dict__ for instance in instances + ] + return statement, params + async def exists(self, db: AsyncSession, **kwargs: Any) -> bool: """ Checks if any records exist that match the given filter conditions. 
diff --git a/pyproject.toml b/pyproject.toml index 3e73fb5..ca18513 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,16 @@ sqlmodel = "^0.0.14" mypy = "^1.9.0" ruff = "^0.3.4" coverage = "^7.4.4" +testcontainers = "^4.7.1" +psycopg = "^3.2.1" +aiomysql = "^0.2.0" +cryptography = "^36.0.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +markers = [ + "dialect(name): mark test to run only on specific SQL dialect", +] diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index 9221223..3377d41 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -1,16 +1,27 @@ from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from typing import Optional from datetime import datetime import pytest import pytest_asyncio -from sqlalchemy import Column, Integer, String, ForeignKey, Boolean, DateTime +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + Boolean, + DateTime, + make_url, +) from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker, DeclarativeBase, relationship from pydantic import BaseModel, ConfigDict from fastapi import FastAPI from fastapi.testclient import TestClient from sqlalchemy.sql import func +from testcontainers.postgres import PostgresContainer +from testcontainers.mysql import MySqlContainer from fastcrud.crud.fast_crud import FastCRUD from fastcrud.endpoint.crud_router import crud_router @@ -25,7 +36,7 @@ class MultiPkModel(Base): __tablename__ = "multi_pk" id = Column(Integer, primary_key=True) uuid = Column(String(32), primary_key=True) - name = Column(String, unique=True) + name = Column(String(32), unique=True) test_id = Column(Integer, ForeignKey("test.id")) test = relationship("ModelTest", back_populates="multi_pk") @@ -34,13 +45,13 @@ class CategoryModel(Base): __tablename__ = "category" tests = relationship("ModelTest", 
back_populates="category") id = Column(Integer, primary_key=True) - name = Column(String, unique=True) + name = Column(String(32), unique=True) class ModelTest(Base): __tablename__ = "test" id = Column(Integer, primary_key=True) - name = Column(String) + name = Column(String(32)) tier_id = Column(Integer, ForeignKey("tier.id")) category_id = Column( Integer, ForeignKey("category.id"), nullable=True, default=None @@ -55,7 +66,7 @@ class ModelTest(Base): class ModelTestWithTimestamp(Base): __tablename__ = "model_test_with_timestamp" id = Column(Integer, primary_key=True) - name = Column(String) + name = Column(String(32)) tier_id = Column(Integer, ForeignKey("tier.id")) category_id = Column( Integer, ForeignKey("category.id"), nullable=True, default=None @@ -70,7 +81,7 @@ class ModelTestWithTimestamp(Base): class TierModel(Base): __tablename__ = "tier" id = Column(Integer, primary_key=True) - name = Column(String, unique=True) + name = Column(String(32), unique=True) tests = relationship("ModelTest", back_populates="tier") @@ -87,8 +98,8 @@ class BookingModel(Base): class Project(Base): __tablename__ = "projects" id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) - description = Column(String) + name = Column(String(32), nullable=False) + description = Column(String(32)) participants = relationship( "Participant", secondary="projects_participants_association", @@ -99,8 +110,8 @@ class Project(Base): class Participant(Base): __tablename__ = "participants" id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) - role = Column(String) + name = Column(String(32), nullable=False) + role = Column(String(32)) projects = relationship( "Project", secondary="projects_participants_association", @@ -117,13 +128,13 @@ class ProjectsParticipantsAssociation(Base): class Card(Base): __tablename__ = "cards" id = Column(Integer, primary_key=True) - title = Column(String) + title = Column(String(32)) class Article(Base): __tablename__ 
= "articles" id = Column(Integer, primary_key=True) - title = Column(String) + title = Column(String(32)) card_id = Column(Integer, ForeignKey("cards.id")) card = relationship("Card", back_populates="articles") @@ -134,26 +145,26 @@ class Article(Base): class Client(Base): __tablename__ = "clients" id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) - contact = Column(String, nullable=False) - phone = Column(String, nullable=False) - email = Column(String, nullable=False) + name = Column(String(32), nullable=False) + contact = Column(String(32), nullable=False) + phone = Column(String(32), nullable=False) + email = Column(String(32), nullable=False) class Department(Base): __tablename__ = "departments" id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) + name = Column(String(32), nullable=False) class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) - username = Column(String, nullable=False, unique=True) - email = Column(String, nullable=False, unique=True) - phone = Column(String, nullable=True) - profile_image_url = Column(String, nullable=True) + name = Column(String(32), nullable=False) + username = Column(String(32), nullable=False, unique=True) + email = Column(String(32), nullable=False, unique=True) + phone = Column(String(32), nullable=True) + profile_image_url = Column(String(32), nullable=True) department_id = Column(Integer, ForeignKey("departments.id"), nullable=True) company_id = Column(Integer, ForeignKey("clients.id"), nullable=True) department = relationship("Department", backref="users") @@ -163,8 +174,8 @@ class User(Base): class Task(Base): __tablename__ = "tasks" id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) - description = Column(String, nullable=True) + name = Column(String(32), nullable=False) + description = Column(String(32), nullable=True) client_id = Column(Integer, 
ForeignKey("clients.id"), nullable=True) department_id = Column(Integer, ForeignKey("departments.id"), nullable=True) assignee_id = Column(Integer, ForeignKey("users.id"), nullable=True) @@ -273,13 +284,10 @@ class TaskRead(TaskReadSub): client: Optional[ClientRead] -async_engine = create_async_engine( - "sqlite+aiosqlite:///:memory:", echo=True, future=True -) - +@asynccontextmanager +async def _async_session(url: str) -> AsyncGenerator[AsyncSession]: + async_engine = create_async_engine(url, echo=True, future=True) -@pytest_asyncio.fixture(scope="function") -async def async_session() -> AsyncGenerator[AsyncSession]: session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) async with session() as s: @@ -294,6 +302,31 @@ async def async_session() -> AsyncGenerator[AsyncSession]: await async_engine.dispose() +@pytest_asyncio.fixture(scope="function") +async def async_session(request: pytest.FixtureRequest) -> AsyncGenerator[AsyncSession]: + dialect_marker = request.node.get_closest_marker("dialect") + dialect = dialect_marker.args[0] if dialect_marker else "sqlite" + if dialect == "postgresql": + with PostgresContainer(driver="psycopg") as pg: + async with _async_session( + url=pg.get_connection_url(host=pg.get_container_host_ip()) + ) as session: + yield session + elif dialect == "sqlite": + async with _async_session(url="sqlite+aiosqlite:///:memory:") as session: + yield session + elif dialect == "mysql": + with MySqlContainer() as mysql: + async with _async_session( + url=make_url(name_or_url=mysql.get_connection_url())._replace( + drivername="mysql+aiomysql" + ) + ) as session: + yield session + else: + raise ValueError(f"Unsupported dialect: {dialect}") + + @pytest.fixture(scope="function") def test_data() -> list[dict]: return [ diff --git a/tests/sqlalchemy/crud/test_upsert.py b/tests/sqlalchemy/crud/test_upsert.py index 7fdaba2..5cfedb3 100644 --- a/tests/sqlalchemy/crud/test_upsert.py +++ b/tests/sqlalchemy/crud/test_upsert.py @@ 
-1,6 +1,7 @@ import pytest from fastcrud.crud.fast_crud import FastCRUD +from tests.sqlalchemy.conftest import CategoryModel, ReadSchemaTest, TierModel @pytest.mark.asyncio @@ -14,3 +15,295 @@ async def test_upsert_successful(async_session, test_model, read_schema): updated_fetched_record = await crud.upsert(async_session, fetched_record) assert read_schema.model_validate(updated_fetched_record) == fetched_record + + +@pytest.mark.parametrize( + ["insert", "update"], + [ + pytest.param( + { + "kwargs": {}, + "expected_result": None, + }, + { + "kwargs": {}, + "expected_result": None, + }, + marks=pytest.mark.dialect("postgresql"), + id="postgresql-none", + ), + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + }, + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New name", + } + ] + }, + }, + marks=pytest.mark.dialect("postgresql"), + id="postgresql-dict", + ), + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + }, + { + "kwargs": { + "return_columns": ["id", "name"], + "name__match": "NewRecord", + }, + "expected_result": {"data": []}, + }, + marks=pytest.mark.dialect("postgresql"), + id="postgresql-dict-filtered", + ), + pytest.param( + { + "kwargs": { + "schema_to_select": ReadSchemaTest, + "return_as_model": True, + }, + "expected_result": { + "data": [ + ReadSchemaTest( + id=1, name="New Record", tier_id=1, category_id=1 + ) + ] + }, + }, + { + "kwargs": { + "schema_to_select": ReadSchemaTest, + "return_as_model": True, + }, + "expected_result": { + "data": [ + ReadSchemaTest(id=1, name="New name", tier_id=1, category_id=1) + ] + }, + }, + marks=pytest.mark.dialect("postgresql"), + id="postgresql-model", + ), + pytest.param( + { + "kwargs": {}, + "expected_result": None, + }, + { + "kwargs": 
{}, + "expected_result": None, + }, + marks=pytest.mark.dialect("sqlite"), + id="sqlite-none", + ), + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + }, + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New name", + } + ] + }, + }, + marks=pytest.mark.dialect("sqlite"), + id="sqlite-dict", + ), + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_result": { + "data": [ + { + "id": 1, + "name": "New Record", + } + ] + }, + }, + { + "kwargs": { + "return_columns": ["id", "name"], + "name__like": "NewRecord", + }, + "expected_result": {"data": []}, + }, + marks=pytest.mark.dialect("sqlite"), + id="sqlite-dict-filtered", + ), + pytest.param( + { + "kwargs": { + "schema_to_select": ReadSchemaTest, + "return_as_model": True, + }, + "expected_result": { + "data": [ + ReadSchemaTest( + id=1, name="New Record", tier_id=1, category_id=1 + ) + ] + }, + }, + { + "kwargs": { + "schema_to_select": ReadSchemaTest, + "return_as_model": True, + }, + "expected_result": { + "data": [ + ReadSchemaTest(id=1, name="New name", tier_id=1, category_id=1) + ] + }, + }, + marks=pytest.mark.dialect("sqlite"), + id="sqlite-model", + ), + pytest.param( + { + "kwargs": {}, + "expected_result": None, + }, + { + "kwargs": {}, + "expected_result": None, + }, + marks=pytest.mark.dialect("mysql"), + id="mysql-none", + ), + ], +) +@pytest.mark.asyncio +async def test_upsert_multi_successful( + async_session, + test_model, + read_schema, + test_data_tier, + test_data_category, + insert, + update, +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + crud = FastCRUD(test_model) + new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1) + 
fetched_records = await crud.upsert_multi( + async_session, [new_data], **insert["kwargs"] + ) + + assert fetched_records == insert["expected_result"] + + updated_new_data = new_data.model_copy(update={"name": "New name"}) + updated_fetched_records = await crud.upsert_multi( + async_session, [updated_new_data], **update["kwargs"] + ) + + assert updated_fetched_records == update["expected_result"] + + +@pytest.mark.parametrize( + ["insert"], + [ + pytest.param( + { + "kwargs": {"return_columns": ["id", "name"]}, + "expected_exception": { + "type": ValueError, + "match": r"MySQL does not support the returning clause for insert operations.", + }, + }, + marks=pytest.mark.dialect("mysql"), + id="mysql-dict", + ), + pytest.param( + { + "kwargs": { + "name__like": "NewRecord", + }, + "expected_exception": { + "type": ValueError, + "match": r"MySQL does not support filtering on insert operations.", + }, + }, + marks=pytest.mark.dialect("mysql"), + id="mysql-dict-filtered", + ), + pytest.param( + { + "kwargs": { + "schema_to_select": ReadSchemaTest, + "return_as_model": True, + }, + "expected_exception": { + "type": ValueError, + "match": r"MySQL does not support the returning clause for insert operations.", + }, + }, + marks=pytest.mark.dialect("mysql"), + id="mysql-model", + ), + ], +) +@pytest.mark.asyncio +async def test_upsert_multi_unsupported( + async_session, + test_model, + read_schema, + test_data_tier, + test_data_category, + insert, +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + for category_item in test_data_category: + async_session.add(CategoryModel(**category_item)) + await async_session.commit() + + crud = FastCRUD(test_model) + new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1) + with pytest.raises( + insert["expected_exception"]["type"], + match=insert["expected_exception"]["match"], + ): + await crud.upsert_multi(async_session, [new_data], **insert["kwargs"])