refactor(snowflake): enable multiple statements and clean up duplicated parameter setting code
cpcloud authored and jcrist committed Jul 5, 2023
1 parent c026d2d commit 75824a6
Showing 2 changed files with 50 additions and 38 deletions.
56 changes: 35 additions & 21 deletions ibis/backends/snowflake/__init__.py
@@ -10,7 +10,7 @@
 import tempfile
 import textwrap
 import warnings
-from typing import TYPE_CHECKING, Any, Iterable, Mapping
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping
 
 import pyarrow as pa
 import sqlalchemy as sa
@@ -188,14 +188,26 @@ def do_connect(
         if connect_args is None:
             connect_args = {}
 
-        connect_args.setdefault(
-            "session_parameters",
-            {
-                "JSON_INDENT": "0",
-                "PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "ARROW",
-                "STRICT_JSON_OUTPUT": "TRUE",
-            },
-        )
+        session_parameters = connect_args.setdefault("session_parameters", {})
+
+        # enable multiple SQL statements by default
+        session_parameters.setdefault("MULTI_STATEMENT_COUNT", "0")
+        # don't format JSON output by default
+        session_parameters.setdefault("JSON_INDENT", "0")
+
+        # overwrite session parameters that are required for ibis + snowflake
+        # to work
+        session_parameters.update(
+            dict(
+                # Use Arrow for query results
+                PYTHON_CONNECTOR_QUERY_RESULT_FORMAT="ARROW",
+                # JSON output must be strict for null versus undefined
+                STRICT_JSON_OUTPUT="TRUE",
+                # Timezone must be UTC
+                TIMEZONE="UTC",
+            ),
+        )
 
         if authenticator is not None:
             connect_args.setdefault("authenticator", authenticator)
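The setdefault/update split is what removes the duplicated parameter code: setdefault keeps anything the caller already chose, while the final update forces the values ibis requires. A minimal sketch of those merge semantics, with illustrative values:

    # Sketch: defaults-then-overrides merging of Snowflake session parameters.
    connect_args = {"session_parameters": {"JSON_INDENT": "2", "TIMEZONE": "America/New_York"}}

    session_parameters = connect_args.setdefault("session_parameters", {})
    session_parameters.setdefault("MULTI_STATEMENT_COUNT", "0")  # added: caller didn't set it
    session_parameters.setdefault("JSON_INDENT", "0")            # skipped: caller chose "2"
    session_parameters.update(TIMEZONE="UTC")                    # forced: overrides the caller

    assert session_parameters == {
        "JSON_INDENT": "2",
        "TIMEZONE": "UTC",
        "MULTI_STATEMENT_COUNT": "0",
    }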

@@ -209,7 +221,6 @@ def connect(dbapi_connection, connection_record):
             dialect = engine.dialect
             quote = dialect.preparer(dialect).quote_identifier
             with dbapi_connection.cursor() as cur:
-                cur.execute("ALTER SESSION SET TIMEZONE = 'UTC'")
                 (database, schema) = cur.execute(
                     "SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()"
                 ).fetchone()
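With TIMEZONE now pinned through session_parameters, this per-connection ALTER SESSION round trip becomes redundant. To confirm the setting on a live session, something like the following should work; this is a sketch that assumes the engine from the surrounding code, and SHOW PARAMETERS is standard Snowflake SQL:

    with engine.connect() as con:
        row = con.exec_driver_sql("SHOW PARAMETERS LIKE 'TIMEZONE' IN SESSION").fetchone()
        # row holds (key, value, default, level, description); value should be 'UTC'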
@@ -353,22 +364,25 @@ def to_pyarrow_batches(
         query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
         sql = query_ast.compile()
         target_schema = expr.as_table().schema().to_pyarrow()
-        target_columns = target_schema.names
-
-        def batch_producer(con):
-            with con.begin() as c, contextlib.closing(c.execute(sql)) as cur:
-                yield from itertools.chain.from_iterable(
-                    t.rename_columns(target_columns)
-                    .cast(target_schema)
-                    .to_batches(max_chunksize=chunk_size)
-                    # yields pyarrow.Table objects, which are then converted to record batches
-                    for t in cur.cursor.fetch_arrow_batches()
-                )
 
         return pa.RecordBatchReader.from_batches(
-            target_schema, batch_producer(self.con)
+            target_schema,
+            self._make_batch_iter(
+                sql, target_schema=target_schema, chunk_size=chunk_size
+            ),
         )
 
+    def _make_batch_iter(
+        self, sql: str, *, target_schema: sch.Schema, chunk_size: int
+    ) -> Iterator[pa.RecordBatch]:
+        with self.begin() as con, contextlib.closing(con.execute(sql)) as cur:
+            yield from itertools.chain.from_iterable(
+                t.rename_columns(target_schema.names)
+                .cast(target_schema)
+                .to_batches(max_chunksize=chunk_size)
+                for t in cur.cursor.fetch_arrow_batches()
+            )
+
     def _get_sqla_table(
         self,
         name: str,
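Because pa.RecordBatchReader.from_batches accepts any iterator of record batches, handing it the generator method defers all fetching until the reader is actually consumed. A self-contained sketch of that pattern with synthetic data:

    from typing import Iterator

    import pyarrow as pa

    schema = pa.schema({"x": pa.int64()})

    def make_batches() -> Iterator[pa.RecordBatch]:
        # Nothing here runs until the reader is drained.
        for chunk in ([1, 2, 3], [4, 5]):
            yield pa.RecordBatch.from_pydict({"x": chunk}, schema=schema)

    reader = pa.RecordBatchReader.from_batches(schema, make_batches())
    print(reader.read_all())  # consumes the generator, batch by batch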
32 changes: 15 additions & 17 deletions ibis/backends/snowflake/tests/conftest.py
@@ -3,7 +3,6 @@
 import concurrent.futures
 import functools
 import os
-from functools import partial
 from typing import TYPE_CHECKING, Any
 
 import pytest
@@ -12,7 +11,6 @@
 import ibis
 from ibis.backends.conftest import TEST_TABLES
 from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
-from ibis.util import consume
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -82,30 +80,30 @@ def _load_data(
         raw_url = sa.engine.make_url(snowflake_url)
         _, schema = raw_url.database.rsplit("/", 1)
         url = raw_url.set(database="")
-        con = sa.create_engine(url)
+        con = sa.create_engine(
+            url, connect_args={"session_parameters": {"MULTI_STATEMENT_COUNT": "0"}}
+        )
 
         dbschema = f"ibis_testing.{schema}"
 
-        stmts = [
-            "CREATE DATABASE IF NOT EXISTS ibis_testing",
-            f"CREATE SCHEMA IF NOT EXISTS {dbschema}",
-            f"USE SCHEMA {dbschema}",
-            *script_dir.joinpath("schema", "snowflake.sql").read_text().split(";"),
-        ]
-
         with con.begin() as c:
-            consume(map(c.exec_driver_sql, filter(None, map(str.strip, stmts))))
+            c.exec_driver_sql(
+                f"""\
+CREATE DATABASE IF NOT EXISTS ibis_testing;
+CREATE SCHEMA IF NOT EXISTS {dbschema};
+USE {dbschema};
+{script_dir.joinpath("schema", "snowflake.sql").read_text()}"""
+            )
 
         with con.begin() as c:
             # not much we can do to make this faster, but running these in
             # multiple threads seems to save about 2x
             with concurrent.futures.ThreadPoolExecutor() as exe:
-                for result in concurrent.futures.as_completed(
-                    map(
-                        partial(exe.submit, partial(copy_into, c, data_dir)),
-                        TEST_TABLES,
-                    )
+                for future in concurrent.futures.as_completed(
+                    exe.submit(copy_into, c, data_dir, table)
+                    for table in TEST_TABLES.keys()
                 ):
-                    result.result()
+                    future.result()
 
     @property
     def functional_alltypes(self) -> ir.Table:
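Setting MULTI_STATEMENT_COUNT to "0" is what lets the schema script run as one semicolon-separated batch in a single exec_driver_sql call. A hedged sketch of the same setup against a throwaway engine; the URL is a placeholder:

    import sqlalchemy as sa

    # Placeholder URL; "0" means any number of statements per execute call.
    engine = sa.create_engine(
        "snowflake://user:password@account/ibis_testing/public",
        connect_args={"session_parameters": {"MULTI_STATEMENT_COUNT": "0"}},
    )

    with engine.begin() as con:
        # Three statements, one round trip; with the default of one
        # statement per call, this would raise.
        con.exec_driver_sql(
            "CREATE TEMP TABLE t (x INT); INSERT INTO t VALUES (1), (2); SELECT * FROM t"
        )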
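The loader change also swaps the partial/map pipeline for a plain generator of futures. concurrent.futures.as_completed materializes its argument into a set first, so every submit happens up front and results are drained in completion order, with future.result() re-raising any worker exception. A minimal standalone sketch of the pattern, with illustrative names:

    import concurrent.futures

    def load_table(name: str) -> str:  # stand-in for the real copy_into loader
        return f"loaded {name}"

    tables = ["astronauts", "batting", "diamonds"]  # illustrative

    with concurrent.futures.ThreadPoolExecutor() as exe:
        for future in concurrent.futures.as_completed(
            exe.submit(load_table, name) for name in tables
        ):
            print(future.result())  # re-raises if the worker failed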
