From 9729e04f88525a948e812e7575b9c1860d1975bd Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 4 Mar 2024 23:29:31 +0400 Subject: [PATCH] make the `sqlite` test db setup a fixture --- py-polars/tests/unit/io/test_database_read.py | 59 ++++++++----------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index e82bffefb62b9..5ab6965be40e5 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,19 +9,17 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple -import pytest -from sqlalchemy import Integer, MetaData, Table, create_engine, func, select -from sqlalchemy.orm import sessionmaker -from sqlalchemy.sql.expression import cast as alchemy_cast - import polars as pl +import pytest from polars.exceptions import UnsuitableSQLError from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql.expression import cast as alchemy_cast if TYPE_CHECKING: import pyarrow as pa - from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -34,24 +32,26 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: with suppress(ModuleNotFoundError): # not available on 3.8/windows from adbc_driver_sqlite.dbapi import connect + args = [str(a) if isinstance(a, Path) else a for a in args] return connect(*args, **kwargs) -def create_temp_sqlite_db(test_db: str) -> None: - Path(test_db).unlink(missing_ok=True) +@pytest.fixture() +def tmp_sqlite_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test.db" + test_db.unlink(missing_ok=True) def convert_date(val: bytes) -> date: """Convert ISO 8601 date to datetime.date object.""" return date.fromisoformat(val.decode()) - sqlite3.register_converter("date", convert_date) - # NOTE: at the time of writing adcb/connectorx have weak SQLite support (poor or # no bool/date/datetime dtypes, for example) and there is a bug in connectorx that # causes float rounding < py 3.11, hence we are only testing/storing simple values # in this test db for now. as support improves, we can add/test additional dtypes). - + sqlite3.register_converter("date", convert_date) conn = sqlite3.connect(test_db) + # ┌─────┬───────┬───────┬────────────┐ # │ id ┆ name ┆ value ┆ date │ # │ --- ┆ --- ┆ --- ┆ --- │ @@ -62,17 +62,19 @@ def convert_date(val: bytes) -> date: # └─────┴───────┴───────┴────────────┘ conn.executescript( """ - CREATE TABLE test_data ( + CREATE TABLE IF NOT EXISTS test_data ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, value FLOAT, date DATE ); - INSERT INTO test_data(name,value,date) - VALUES ('misc',100.0,'2020-01-01'), ('other',-99.5,'2021-12-31'); + REPLACE INTO test_data(name,value,date) + VALUES ('misc',100.0,'2020-01-01'), + ('other',-99.5,'2021-12-31'); """ ) conn.close() + return test_db class DatabaseReadTestParams(NamedTuple): @@ -314,22 +316,19 @@ def test_read_database( schema_overrides: SchemaDict | None, batch_size: int | None, tmp_path: Path, + tmp_sqlite_db: Path, ) -> None: - tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / "test.db") - create_temp_sqlite_db(test_db) - if read_method == "read_database_uri": # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( - uri=f"sqlite:///{test_db}", + uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", engine=str(connect_using), # type: ignore[arg-type] schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: # externally instantiated adbc connections - with connect_using(test_db) as conn, conn.cursor(): + with connect_using(tmp_sqlite_db) as conn, conn.cursor(): df = pl.read_database( connection=conn, query="SELECT * FROM test_data", @@ -339,7 +338,7 @@ def test_read_database( else: # other user-supplied connections df = pl.read_database( - connection=connect_using(test_db), + connection=connect_using(tmp_sqlite_db), query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'", schema_overrides=schema_overrides, batch_size=batch_size, @@ -350,13 +349,9 @@ def test_read_database( assert df["date"].to_list() == expected_dates -def test_read_database_alchemy_selectable(tmp_path: Path) -> None: - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - +def test_read_database_alchemy_selectable(tmp_path: Path, tmp_sqlite_db: Path) -> None: # various flavours of alchemy connection - alchemy_engine = create_engine(f"sqlite:///{test_db}") + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() @@ -376,18 +371,14 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: ) -def test_read_database_parameterised(tmp_path: Path) -> None: +def test_read_database_parameterised(tmp_path: Path, tmp_sqlite_db: Path) -> None: supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32" - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - alchemy_engine = create_engine(f"sqlite:///{test_db}") - # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs - raw_conn: ConnectionOrCursor = sqlite3.connect(test_db) + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db) # establish parameterised queries and validate usage query = """