Skip to content

Commit

Permalink
fix: propagate kwargs through compilation nodes
Browse files Browse the repository at this point in the history
Fixes #1138

When we visit a struct or union, we are handed an IdentifierProvider. Before, we just used it at this first layer, but forgot to pass it down to lower compilation steps. Now, we pass it along. This was a problem when we encountered nested dtypes, eg struct<struct<>>
  • Loading branch information
NickCrews committed Oct 22, 2024
1 parent 80d4c0b commit 95d8cef
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
17 changes: 11 additions & 6 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def visit_struct(
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer)
return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer, **kw)


@compiles(Union, "duckdb") # type: ignore[misc]
Expand All @@ -243,21 +243,25 @@ def visit_union(
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
return "UNION" + struct_or_union(instance, compiler, identifier_preparer)
return "UNION" + struct_or_union(instance, compiler, identifier_preparer, **kw)

Check warning on line 246 in duckdb_engine/datatypes.py

View check run for this annotation

Codecov / codecov/patch

duckdb_engine/datatypes.py#L246

Added line #L246 was not covered by tests


def struct_or_union(
instance: typing.Union[Union, Struct],
compiler: PGTypeCompiler,
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
fields = instance.fields
if fields is None:
raise exc.CompileError(f"DuckDB {repr(instance)} type requires fields")
return "({})".format(
", ".join(
"{} {}".format(
identifier_preparer.quote_identifier(key), process_type(value, compiler)
identifier_preparer.quote_identifier(key),
process_type(
value, compiler, identifier_preparer=identifier_preparer, **kw
),
)
for key, value in fields.items()
)
Expand All @@ -267,13 +271,14 @@ def struct_or_union(
def process_type(
value: typing.Union[TypeEngine, Type[TypeEngine]],
compiler: PGTypeCompiler,
**kw: Any,
) -> str:
return compiler.process(type_api.to_instance(value))
return compiler.process(type_api.to_instance(value), **kw)


@compiles(Map, "duckdb") # type: ignore[misc]
def visit_map(instance: Map, compiler: PGTypeCompiler, **kw: Any) -> str:
return "MAP({}, {})".format(
process_type(instance.key_type, compiler),
process_type(instance.value_type, compiler),
process_type(instance.key_type, compiler, **kw),
process_type(instance.value_type, compiler, **kw),
)
23 changes: 23 additions & 0 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,29 @@ class Entry(base):
assert result.map == map_data


def test_double_nested_types(engine: Engine, session: Session) -> None:
"""Test for https://github.com/Mause/duckdb_engine/issues/1138"""
importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0
base = declarative_base()

class Entry(base):
__tablename__ = "test_struct"

id = Column(Integer, primary_key=True, default=0)
outer = Column(Struct({"inner": Struct({"val": Integer})}))

base.metadata.create_all(bind=engine)

outer = {"inner": {"val": 42}}

session.add(Entry(outer=outer)) # type: ignore[call-arg]
session.commit()

result = session.query(Entry).one()

assert result.outer == outer


def test_interval(engine: Engine, snapshot: SnapshotTest) -> None:
test_table = Table("test_table", MetaData(), Column("duration", Interval))

Expand Down

0 comments on commit 95d8cef

Please sign in to comment.