Skip to content

Commit

Permalink
make the sqlite test db setup a fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Mar 4, 2024
1 parent f738c90 commit 9729e04
Showing 1 changed file with 25 additions and 34 deletions.
59 changes: 25 additions & 34 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
from types import GeneratorType
from typing import TYPE_CHECKING, Any, NamedTuple

import pytest
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

import polars as pl
import pytest
from polars.exceptions import UnsuitableSQLError
from polars.io.database import _ARROW_DRIVER_REGISTRY_
from polars.testing import assert_frame_equal
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

if TYPE_CHECKING:
import pyarrow as pa

from polars.type_aliases import (
ConnectionOrCursor,
DbReadEngine,
Expand All @@ -34,24 +32,26 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any:
with suppress(ModuleNotFoundError): # not available on 3.8/windows
from adbc_driver_sqlite.dbapi import connect

args = [str(a) if isinstance(a, Path) else a for a in args]
return connect(*args, **kwargs)


def create_temp_sqlite_db(test_db: str) -> None:
Path(test_db).unlink(missing_ok=True)
@pytest.fixture()
def tmp_sqlite_db(tmp_path: Path) -> Path:
test_db = tmp_path / "test.db"
test_db.unlink(missing_ok=True)

def convert_date(val: bytes) -> date:
"""Convert ISO 8601 date to datetime.date object."""
return date.fromisoformat(val.decode())

sqlite3.register_converter("date", convert_date)

# NOTE: at the time of writing adbc/connectorx have weak SQLite support (poor or
# no bool/date/datetime dtypes, for example) and there is a bug in connectorx that
# causes float rounding < py 3.11, hence we are only testing/storing simple values
# in this test db for now (as support improves, we can add/test additional dtypes).

sqlite3.register_converter("date", convert_date)
conn = sqlite3.connect(test_db)

# ┌─────┬───────┬───────┬────────────┐
# │ id ┆ name ┆ value ┆ date │
# │ --- ┆ --- ┆ --- ┆ --- │
Expand All @@ -62,17 +62,19 @@ def convert_date(val: bytes) -> date:
# └─────┴───────┴───────┴────────────┘
conn.executescript(
"""
CREATE TABLE test_data (
CREATE TABLE IF NOT EXISTS test_data (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
value FLOAT,
date DATE
);
INSERT INTO test_data(name,value,date)
VALUES ('misc',100.0,'2020-01-01'), ('other',-99.5,'2021-12-31');
REPLACE INTO test_data(name,value,date)
VALUES ('misc',100.0,'2020-01-01'),
('other',-99.5,'2021-12-31');
"""
)
conn.close()
return test_db


class DatabaseReadTestParams(NamedTuple):
Expand Down Expand Up @@ -314,22 +316,19 @@ def test_read_database(
schema_overrides: SchemaDict | None,
batch_size: int | None,
tmp_path: Path,
tmp_sqlite_db: Path,
) -> None:
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / "test.db")
create_temp_sqlite_db(test_db)

if read_method == "read_database_uri":
# instantiate the connection ourselves, using connectorx/adbc
df = pl.read_database_uri(
uri=f"sqlite:///{test_db}",
uri=f"sqlite:///{tmp_sqlite_db}",
query="SELECT * FROM test_data",
engine=str(connect_using), # type: ignore[arg-type]
schema_overrides=schema_overrides,
)
elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
# externally instantiated adbc connections
with connect_using(test_db) as conn, conn.cursor():
with connect_using(tmp_sqlite_db) as conn, conn.cursor():
df = pl.read_database(
connection=conn,
query="SELECT * FROM test_data",
Expand All @@ -339,7 +338,7 @@ def test_read_database(
else:
# other user-supplied connections
df = pl.read_database(
connection=connect_using(test_db),
connection=connect_using(tmp_sqlite_db),
query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
schema_overrides=schema_overrides,
batch_size=batch_size,
Expand All @@ -350,13 +349,9 @@ def test_read_database(
assert df["date"].to_list() == expected_dates


def test_read_database_alchemy_selectable(tmp_path: Path) -> None:
# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))

def test_read_database_alchemy_selectable(tmp_path: Path, tmp_sqlite_db: Path) -> None:
# various flavours of alchemy connection
alchemy_engine = create_engine(f"sqlite:///{test_db}")
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

Expand All @@ -376,18 +371,14 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None:
)


def test_read_database_parameterised(tmp_path: Path) -> None:
def test_read_database_parameterised(tmp_path: Path, tmp_sqlite_db: Path) -> None:
supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32"

# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))
alchemy_engine = create_engine(f"sqlite:///{test_db}")

# raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
raw_conn: ConnectionOrCursor = sqlite3.connect(test_db)
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)

# establish parameterised queries and validate usage
query = """
Expand Down

0 comments on commit 9729e04

Please sign in to comment.