refactor(sqlalchemy): use exec_driver_sql everywhere
cpcloud authored and kszucs committed Jan 24, 2023
1 parent ee6d58a commit e8f96b6
Showing 20 changed files with 148 additions and 185 deletions.
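
For context: SQLAlchemy's `Connection.execute()` expects an executable construct such as `sa.text()`, while `Connection.exec_driver_sql()` hands a raw SQL string straight to the DBAPI driver with no compilation step. A minimal sketch of the two call styles this commit swaps between (the in-memory SQLite engine is purely illustrative, not part of the change):

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")  # illustrative engine, not from this commit

with engine.begin() as con:
    # before: wrap the raw string in sa.text() and go through execute()
    print(con.execute(sa.text("SELECT 1 AS x")).scalar())

    # after: send the raw string directly to the DBAPI driver
    print(con.exec_driver_sql("SELECT 1 AS x").scalar())
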
3 changes: 0 additions & 3 deletions ci/schema/postgresql.sql
@@ -1,6 +1,3 @@
DROP SEQUENCE IF EXISTS test_sequence;
CREATE SEQUENCE IF NOT EXISTS test_sequence;

CREATE EXTENSION IF NOT EXISTS hstore;
CREATE EXTENSION IF NOT EXISTS postgis;
CREATE EXTENSION IF NOT EXISTS plpython3u;
12 changes: 12 additions & 0 deletions ci/schema/trino.sql
@@ -0,0 +1,12 @@
DROP TABLE IF EXISTS map;
CREATE TABLE map (kv MAP<VARCHAR, BIGINT>);
INSERT INTO map VALUES
(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3])),
(MAP(ARRAY['d', 'e', 'f'], ARRAY[4, 5, 6]));

DROP TABLE IF EXISTS ts;
CREATE TABLE ts (x TIMESTAMP(3), y TIMESTAMP(6), z TIMESTAMP(9));
INSERT INTO ts VALUES
(TIMESTAMP '2023-01-07 13:20:05.561',
TIMESTAMP '2023-01-07 13:20:05.561021',
TIMESTAMP '2023-01-07 13:20:05.561000231');
16 changes: 10 additions & 6 deletions ibis/backends/conftest.py
@@ -137,16 +137,17 @@ def recreate_database(
engine = sa.create_engine(url.set(database=""), **kwargs)

if url.database is not None:
with engine.begin() as conn:
conn.execute(sa.text(f'DROP DATABASE IF EXISTS {database}'))
conn.execute(sa.text(f'CREATE DATABASE {database}'))
with engine.begin() as con:
con.exec_driver_sql(f"DROP DATABASE IF EXISTS {database}")
con.exec_driver_sql(f"CREATE DATABASE {database}")


def init_database(
url: sa.engine.url.URL,
database: str,
schema: TextIO | None = None,
recreate: bool = True,
isolation_level: str = "AUTOCOMMIT",
**kwargs: Any,
) -> sa.engine.Engine:
"""Initialise `database` at `url` with `schema`.
@@ -163,20 +164,23 @@ def init_database(
File object containing schema to use
recreate : bool
If true, drop the database if it exists
isolation_level : str
Transaction isolation_level
Returns
-------
sa.engine.Engine for the database created
sa.engine.Engine
SQLAlchemy engine object
"""
if recreate:
recreate_database(url, database, **kwargs)
recreate_database(url, database, isolation_level=isolation_level, **kwargs)

try:
url.database = database
except AttributeError:
url = url.set(database=database)

engine = sa.create_engine(url, **kwargs)
engine = sa.create_engine(url, isolation_level=isolation_level, **kwargs)

if schema:
with engine.begin() as conn:
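
A note on the new `isolation_level` parameter: passing `isolation_level="AUTOCOMMIT"` to `sa.create_engine()` makes every statement commit immediately, which `CREATE DATABASE`/`DROP DATABASE` on PostgreSQL require because they cannot run inside a transaction block. A rough sketch of how `recreate_database` is driven, with an illustrative URL and database name rather than the test suite's real config:

import sqlalchemy as sa

# illustrative connection details; the test suite builds these from its own config
url = sa.engine.make_url("postgresql://postgres:postgres@localhost:5432/postgres")

engine = sa.create_engine(url.set(database=""), isolation_level="AUTOCOMMIT")
with engine.begin() as con:
    # each statement autocommits, so CREATE/DROP DATABASE are allowed here
    con.exec_driver_sql("DROP DATABASE IF EXISTS ibis_testing")
    con.exec_driver_sql("CREATE DATABASE ibis_testing")
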
6 changes: 3 additions & 3 deletions ibis/backends/duckdb/__init__.py
@@ -141,8 +141,8 @@ def _load_extensions(self, extensions):
for extension in extensions:
if extension not in self._extensions:
with self.begin() as con:
con.execute(sa.text(f"INSTALL '{extension}'"))
con.execute(sa.text(f"LOAD '{extension}'"))
con.exec_driver_sql(f"INSTALL '{extension}'")
con.exec_driver_sql(f"LOAD '{extension}'")
self._extensions.add(extension)

def register(
@@ -449,7 +449,7 @@ def fetch_from_cursor(

def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
with self.begin() as con:
rows = con.execute(sa.text(f"DESCRIBE {query}"))
rows = con.exec_driver_sql(f"DESCRIBE {query}")

for name, type, null in toolz.pluck(
["column_name", "column_type", "null"], rows.mappings()
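
One behavioural note: `exec_driver_sql()` returns the same `CursorResult` as `execute()`, so `.mappings()` still yields dict-like rows keyed by column name, which is what the `toolz.pluck` call above relies on. A small standalone sketch of the DESCRIBE pattern, assuming the `duckdb-engine` dialect is installed:

import sqlalchemy as sa

engine = sa.create_engine("duckdb:///:memory:")  # requires the duckdb-engine package
with engine.begin() as con:
    rows = con.exec_driver_sql("DESCRIBE SELECT 1 AS x, 'a' AS y")
    for mapping in rows.mappings():
        print(mapping["column_name"], mapping["column_type"], mapping["null"])
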
8 changes: 4 additions & 4 deletions ibis/backends/mssql/__init__.py
@@ -43,8 +43,8 @@ def do_connect(
@contextlib.contextmanager
def begin(self):
with super().begin() as bind:
prev = bind.execute(sa.text('SELECT @@DATEFIRST')).scalar()
bind.execute(sa.text('SET DATEFIRST 1'))
prev = bind.exec_driver_sql("SELECT @@DATEFIRST").scalar()
bind.exec_driver_sql("SET DATEFIRST 1")
yield bind
bind.execute(sa.text("SET DATEFIRST :prev").bindparams(prev=prev))

@@ -53,8 +53,8 @@ def _metadata(self, query):
query = f"SELECT * FROM [{query}]"

with self.begin() as bind:
for column in bind.execute(
sa.text(f"EXEC sp_describe_first_result_set @tsql = N'{query}';")
for column in bind.exec_driver_sql(
f"EXEC sp_describe_first_result_set @tsql = N'{query}'"
).mappings():
yield column["name"], _type_from_result_set_info(column)

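
Worth noting: the `SET DATEFIRST :prev` restore above still uses `bind.execute(sa.text(...).bindparams(...))`. `exec_driver_sql()` skips SQLAlchemy's compilation, so `:name`-style parameters are not translated for the driver; only statements without bound parameters are switched over in this commit. A small hedged sketch of the difference:

import sqlalchemy as sa

# sa.text() knows how to render :name parameters for whatever driver is in use
stmt = sa.text("SET DATEFIRST :prev").bindparams(prev=7)
print(stmt.compile(compile_kwargs={"literal_binds": True}))  # SET DATEFIRST 7

# exec_driver_sql() passes the string through untouched; any parameters would
# have to be supplied in the DBAPI's own paramstyle (e.g. pyodbc's "?" placeholders)
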
5 changes: 2 additions & 3 deletions ibis/backends/mssql/tests/test_client.py
@@ -1,5 +1,4 @@
import pytest
import sqlalchemy as sa
from pytest import param

import ibis
@@ -84,12 +83,12 @@ def test_get_schema_from_query(con, server_type, expected_type):
expected_schema = ibis.schema(dict(x=expected_type))
try:
with con.begin() as c:
c.execute(sa.text(f"CREATE TABLE {name} (x {server_type})"))
c.exec_driver_sql(f"CREATE TABLE {name} (x {server_type})")
expected_schema = ibis.schema(dict(x=expected_type))
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
assert result_schema == expected_schema
t = con.table(raw_name)
assert t.schema() == expected_schema
finally:
with con.begin() as c:
c.execute(sa.text(f"DROP TABLE IF EXISTS {name}"))
c.exec_driver_sql(f"DROP TABLE IF EXISTS {name}")
13 changes: 6 additions & 7 deletions ibis/backends/mysql/__init__.py
@@ -106,19 +106,18 @@ def do_connect(
@contextlib.contextmanager
def begin(self):
with super().begin() as bind:
prev = bind.execute(sa.text('SELECT @@session.time_zone')).scalar()
prev = bind.exec_driver_sql('SELECT @@session.time_zone').scalar()
try:
bind.execute(sa.text("SET @@session.time_zone = 'UTC'"))
bind.exec_driver_sql("SET @@session.time_zone = 'UTC'")
except Exception as e: # noqa: BLE001
warnings.warn(f"Couldn't set MySQL timezone: {e}")

yield bind
stmt = sa.text("SET @@session.time_zone = :prev").bindparams(prev=prev)
try:
bind.execute(
sa.text("SET @@session.time_zone = :prev").bindparams(prev=prev)
)
bind.execute(stmt)
except Exception as e: # noqa: BLE001
warnings.warn(sa.text(f"Couldn't reset MySQL timezone: {e}"))
warnings.warn(f"Couldn't reset MySQL timezone: {e}")

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
if (
@@ -128,7 +127,7 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
query = f"({query})"

with self.begin() as con:
result = con.execute(sa.text(f"SELECT * FROM {query} _ LIMIT 0"))
result = con.exec_driver_sql(f"SELECT * FROM {query} _ LIMIT 0")
cursor = result.cursor
yield from (
(field.name, _type_from_cursor_info(descr, field))
2 changes: 1 addition & 1 deletion ibis/backends/mysql/tests/conftest.py
@@ -93,7 +93,7 @@ def _load_data(
"LINES TERMINATED BY '\\n'",
"IGNORE 1 LINES",
]
con.execute(sa.text("\n".join(lines)))
con.exec_driver_sql("\n".join(lines))

@staticmethod
def connect(_: Path):
32 changes: 18 additions & 14 deletions ibis/backends/mysql/tests/test_client.py
@@ -1,5 +1,4 @@
import pytest
import sqlalchemy as sa
from pytest import param

import ibis
@@ -64,21 +63,26 @@ def test_get_schema_from_query(con, mysql_type, expected_type):
# temporary tables get cleaned up by the db when the session ends, so we
# don't need to explicitly drop the table
with con.begin() as c:
c.execute(sa.text(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})"))
expected_schema = ibis.schema(dict(x=expected_type))
t = con.table(raw_name)
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
assert t.schema() == expected_schema
assert result_schema == expected_schema
c.exec_driver_sql(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})")
try:
expected_schema = ibis.schema(dict(x=expected_type))
t = con.table(raw_name)
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
assert t.schema() == expected_schema
assert result_schema == expected_schema
finally:
with con.begin() as c:
c.exec_driver_sql(f"DROP TABLE {name}")


@pytest.mark.parametrize(
"coltype",
["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"],
)
@pytest.mark.parametrize("coltype", ["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"])
def test_blob_type(con, coltype):
tmp = f"tmp_{ibis.util.guid()}"
with con.begin() as c:
c.execute(sa.text(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})"))
t = con.table(tmp)
assert t.schema() == ibis.schema({"a": dt.binary})
c.exec_driver_sql(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})")
try:
t = con.table(tmp)
assert t.schema() == ibis.schema({"a": dt.binary})
finally:
with con.begin() as c:
c.exec_driver_sql(f"DROP TABLE {tmp}")
15 changes: 7 additions & 8 deletions ibis/backends/postgres/__init__.py
@@ -110,19 +110,18 @@ def list_databases(self, like=None):
# http://dba.stackexchange.com/a/1304/58517
databases = [
row.datname
for row in con.execute(
sa.text('SELECT datname FROM pg_database WHERE NOT datistemplate')
)
for row in con.exec_driver_sql(
"SELECT datname FROM pg_database WHERE NOT datistemplate"
).mappings()
]
return self._filter_with_like(databases, like)

@contextlib.contextmanager
def begin(self):
with super().begin() as bind:
prev = bind.execute(sa.text('SHOW TIMEZONE')).scalar()
bind.execute(sa.text('SET TIMEZONE = UTC'))
# LOCAL takes effect for the current transaction only
bind.exec_driver_sql("SET LOCAL TIMEZONE = UTC")
yield bind
bind.execute(sa.text("SET TIMEZONE = :prev").bindparams(prev=prev))

def udf(
self,
@@ -186,12 +185,12 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
AND NOT attisdropped
ORDER BY attnum"""
with self.begin() as con:
con.execute(sa.text(f"CREATE TEMPORARY VIEW {name} AS {query}"))
con.exec_driver_sql(f"CREATE TEMPORARY VIEW {name} AS {query}")
type_info = con.execute(
sa.text(type_info_sql).bindparams(raw_name=raw_name)
)
yield from ((col, _get_type(typestr)) for col, typestr in type_info)
con.execute(sa.text(f"DROP VIEW IF EXISTS {name}"))
con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}")

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
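
The timezone handling above also changed shape: `SET LOCAL` in PostgreSQL only lasts until the enclosing transaction ends, so the old save-and-restore of `SHOW TIMEZONE` is no longer needed once each `begin()` block sets the zone locally. A rough sketch of that behaviour, with illustrative connection details:

import sqlalchemy as sa

engine = sa.create_engine("postgresql://postgres:postgres@localhost:5432/ibis_testing")

with engine.begin() as con:
    con.exec_driver_sql("SET LOCAL TIMEZONE = UTC")
    print(con.exec_driver_sql("SHOW TIMEZONE").scalar())  # UTC inside this transaction

with engine.begin() as con:
    # the setting did not leak: back to the session/server default here
    print(con.exec_driver_sql("SHOW TIMEZONE").scalar())
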
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/conftest.py
@@ -91,7 +91,7 @@ def _load_data(
with data_dir.joinpath(f'{table}.csv').open('r') as file:
cur.copy_expert(sql=sql, file=file)

con.execute(sa.text("VACUUM FULL ANALYZE"))
con.exec_driver_sql("VACUUM FULL ANALYZE")

@staticmethod
def connect(data_directory: Path):
7 changes: 2 additions & 5 deletions ibis/backends/postgres/tests/test_client.py
@@ -217,12 +217,9 @@ def test_create_and_drop_table(con, temp_table, params):
],
)
def test_get_schema_from_query(con, pg_type, expected_type):
raw_name = ibis.util.guid()
name = con._quote(raw_name)
name = con._quote(ibis.util.guid())
with con.begin() as c:
c.execute(
sa.text(f"CREATE TEMPORARY TABLE {name} (x {pg_type}, y {pg_type}[])")
)
c.exec_driver_sql(f"CREATE TEMP TABLE {name} (x {pg_type}, y {pg_type}[])")
expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type)))
result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}")
assert result_schema == expected_schema
28 changes: 12 additions & 16 deletions ibis/backends/postgres/tests/test_functions.py
@@ -1041,8 +1041,8 @@ def test_array_concat_mixed_types(array_types):
@pytest.fixture
def t(con, guid):
with con.begin() as c:
c.execute(
sa.text(f"CREATE TABLE \"{guid}\" (id SERIAL PRIMARY KEY, name TEXT)")
c.exec_driver_sql(
f"CREATE TABLE {con._quote(guid)} (id SERIAL PRIMARY KEY, name TEXT)"
)
return con.table(guid)

@@ -1053,27 +1053,24 @@ def s(con, t, guid, guid2):
assert t.op().name != guid2

with con.begin() as c:
c.execute(
sa.text(
f"""
CREATE TABLE \"{guid2}\" (
c.exec_driver_sql(
f"""
CREATE TABLE {con._quote(guid2)} (
id SERIAL PRIMARY KEY,
left_t_id INTEGER REFERENCES "{guid}",
left_t_id INTEGER REFERENCES {con._quote(guid)},
cost DOUBLE PRECISION
)
"""
)
)
return con.table(guid2)


@pytest.fixture
def trunc(con, guid):
quoted = con._quote(guid)
with con.begin() as c:
c.execute(
sa.text(f"CREATE TABLE \"{guid}\" (id SERIAL PRIMARY KEY, name TEXT)")
)
c.execute(sa.text(f"INSERT INTO \"{guid}\" (name) VALUES ('a'), ('b'), ('c')"))
c.exec_driver_sql(f"CREATE TABLE {quoted} (id SERIAL PRIMARY KEY, name TEXT)")
c.exec_driver_sql(f"INSERT INTO {quoted} (name) VALUES ('a'), ('b'), ('c')")
return con.table(guid)


@@ -1314,9 +1311,8 @@ def test_timestamp_with_timezone_select(tzone_compute, tz):


def test_timestamp_type_accepts_all_timezones(con):
query = 'SELECT name FROM pg_timezone_names'
with con.begin() as c:
cur = c.execute(sa.text(query)).fetchall()
cur = c.exec_driver_sql("SELECT name FROM pg_timezone_names").fetchall()
assert all(dt.Timestamp(row.name).timezone == row.name for row in cur)


@@ -1416,7 +1412,7 @@ def test_string_to_binary_cast(con):
"FROM functional_alltypes LIMIT 10"
)
with con.begin() as c:
cur = c.execute(sa.text(sql_string))
cur = c.exec_driver_sql(sql_string)
raw_data = [row[0][0] for row in cur]
expected = pd.Series(raw_data, name=name)
tm.assert_series_equal(result, expected)
@@ -1433,6 +1429,6 @@ def test_string_to_binary_round_trip(con):
"FROM functional_alltypes LIMIT 10"
)
with con.begin() as c:
cur = c.execute(sa.text(sql_string))
cur = c.exec_driver_sql(sql_string)
expected = pd.Series([row[0][0] for row in cur], name=name)
tm.assert_series_equal(result, expected)