From c6a87e83cd261ee612a496c5e6ad31bed4ce600b Mon Sep 17 00:00:00 2001 From: Elliana May Date: Fri, 13 Dec 2024 13:12:18 +0800 Subject: [PATCH 1/3] chore: add __all__ to module --- duckdb_engine/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index 435fa4e0..ae66f564 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -53,6 +53,15 @@ register_extension_types() +__all__ = [ + "Dialect", + "ConnectionWrapper", + "CursorWrapper", + "DBAPI", + "DuckDBEngineWarning", +] + + class DBAPI: paramstyle = "numeric_dollar" if sqlalchemy_version >= "2.0.0" else "qmark" apilevel = duckdb.apilevel From 51dbb186cff25c340c8395ef660cd6485923ec04 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Fri, 13 Dec 2024 13:12:36 +0800 Subject: [PATCH 2/3] feat: reexport postgresql insert function --- duckdb_engine/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index ae66f564..eba15b5c 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -18,7 +18,7 @@ import sqlalchemy from sqlalchemy import pool, select, sql, text, util from sqlalchemy import types as sqltypes -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import UUID, insert from sqlalchemy.dialects.postgresql.base import ( PGDialect, PGIdentifierPreparer, @@ -59,6 +59,7 @@ "CursorWrapper", "DBAPI", "DuckDBEngineWarning", + "insert", # reexport of sqlalchemy.dialects.postgresql.insert ] From d48e9d330a5aaf65a3632512a255f3ef5a1d1521 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Fri, 13 Dec 2024 13:22:49 +0800 Subject: [PATCH 3/3] chore: add test --- duckdb_engine/tests/test_basic.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py index ed6d0416..633a2d92 100644 --- a/duckdb_engine/tests/test_basic.py +++ b/duckdb_engine/tests/test_basic.py @@ -36,7 +36,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, relationship, sessionmaker -from .. import Dialect, supports_attach, supports_user_agent +from .. import Dialect, insert, supports_attach, supports_user_agent from .._supports import has_comment_support try: @@ -653,3 +653,29 @@ def test_reflection(engine: Engine) -> None: with engine.connect() as conn: conn.execute(text("CREATE TABLE tbl(col1 INTEGER)")) metadata.reflect(engine) + + +def test_upsert(session: Session) -> None: + class User(Base): + __tablename__ = "users" + + id = Column(Integer(), Sequence("id_seq"), primary_key=True) + name = Column(String, unique=True) + fullname = Column(String) + + Base.metadata.create_all(session.bind) + stmt = insert(User).values( + [ + {"name": "spongebob", "fullname": "Spongebob Squarepants"}, + {"name": "sandy", "fullname": "Sandy Cheeks"}, + {"name": "patrick", "fullname": "Patrick Star"}, + {"name": "squidward", "fullname": "Squidward Tentacles"}, + {"name": "ehkrabs", "fullname": "Eugene H. Krabs"}, + ] + ) + stmt = stmt.on_conflict_do_update( + index_elements=[User.name], set_=dict(fullname=stmt.excluded.fullname) + ) + session.execute(stmt) + + assert session.query(User).count() == 5