Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): DataFrame write_database not passing down "engine_options" when using ADBC #18451

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3967,6 +3967,7 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
mode=mode,
catalog_name=catalog,
db_schema_name=db_schema,
**(engine_options or {}),
)
elif db_schema is not None:
adbc_str_version = ".".join(str(v) for v in adbc_version)
Expand Down
36 changes: 36 additions & 0 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,39 @@ def test_write_database_sa_commit(tmp_path: str, pass_connection: bool) -> None:
)

assert_frame_equal(result, df)


@pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
)
def test_write_database_adbc_temporary_table() -> None:
"""Confirm that execution_options are passed along to create temporary tables."""
df = pl.DataFrame({"colx": [1, 2, 3]})
temp_tbl_name = "should_be_temptable"
expected_temp_table_create_sql = (
"""CREATE TABLE "should_be_temptable" ("colx" INTEGER)"""
)

# test with sqlite in memory
conn = _open_adbc_connection("sqlite:///:memory:")
assert (
df.write_database(
temp_tbl_name,
connection=conn,
if_table_exists="fail",
engine_options={"temporary": True},
)
== 3
)
temp_tbl_sql_df = pl.read_database(
"select sql from sqlite_temp_master where type='table' and tbl_name = ?",
connection=conn,
execute_options={"parameters": [temp_tbl_name]},
)
assert temp_tbl_sql_df.shape[0] == 1, "no temp table created"
actual_temp_table_create_sql = temp_tbl_sql_df["sql"][0]
assert expected_temp_table_create_sql == actual_temp_table_create_sql

if hasattr(conn, "close"):
conn.close()