fix(snowflake): ensure the correct compilation of tables from other databases and schemas
cpcloud committed Aug 9, 2023
1 parent 910c1d9 commit 0ee68e2
Showing 7 changed files with 174 additions and 46 deletions.
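For context, the user-visible behavior this commit fixes is referencing a table that lives in a database and schema other than the connection's current ones. A minimal sketch of that usage, with placeholder credentials (the `database="DB/SCHEMA"` form mirrors how the backend splits the connection URL's database on "/"):

import ibis

# hypothetical connection; every credential below is a placeholder
con = ibis.snowflake.connect(
    user="me",
    password="...",
    account="my_account",
    database="IBIS_TESTING/PUBLIC",
)

# a table in another database and schema; before this commit the
# generated SQL could mis-quote the fully qualified name
t = con.table("LINEITEM", schema="SNOWFLAKE_SAMPLE_DATA.TPCH_SF1")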
1 change: 1 addition & 0 deletions ibis/backends/base/sql/alchemy/__init__.py
@@ -409,6 +409,7 @@ def _table_from_schema(
*columns,
prefixes=[self._temporary_prefix] if temp else [],
quote=self.compiler.translator_class._quote_table_names,
schema=database,
**kwargs,
)

169 changes: 141 additions & 28 deletions ibis/backends/snowflake/__init__.py
@@ -15,7 +15,6 @@

import pyarrow as pa
import sqlalchemy as sa
import sqlalchemy.types as sat
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.sqlalchemy import ARRAY, DOUBLE, OBJECT, URL
from sqlalchemy.ext.compiler import compiles
@@ -236,14 +235,15 @@ def do_connect(
@sa.event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
"""Register UDFs on a `"connect"` event."""
dialect = engine.dialect
quote = dialect.preparer(dialect).quote_identifier
with dbapi_connection.cursor() as cur:
database, schema = cur.execute(
"SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()"
).fetchone()
try:
cur.execute("CREATE DATABASE IF NOT EXISTS ibis_udfs")
# snowflake activates a database on creation
cur.execute(f"USE SCHEMA {database}.{schema}")
cur.execute(f"USE SCHEMA {quote(database)}.{quote(schema)}")
for name, defn in _SNOWFLAKE_MAP_UDFS.items():
cur.execute(self._make_udf(name, defn))
except Exception as e: # noqa: BLE001
@@ -395,27 +395,38 @@ def _make_batch_iter(
for t in cur.cursor.fetch_arrow_batches()
)

@contextlib.contextmanager
def _use_schema(self, ident):
db = self.current_database
schema = self.current_schema
try:
with self.begin() as c:
c.exec_driver_sql(f"USE SCHEMA {ident}")
yield
finally:
with self.begin() as c:
c.exec_driver_sql(f"USE SCHEMA {self._quote(db)}.{self._quote(schema)}")
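The `_use_schema` context manager above is a save/restore pattern: it records the current database and schema, switches to the requested identifier, then unconditionally switches back in the `finally` block even if the body raises. A hedged usage sketch (the identifier is a placeholder and is assumed to be pre-quoted, as `_get_sqla_table` below does):

with self._use_schema('"OTHER_DB"."OTHER_SCHEMA"'):
    ...  # reflection in here runs against OTHER_DB.OTHER_SCHEMA
# on exit the original database.schema is active again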

def _get_sqla_table(
self,
name: str,
schema: str | None = None,
database: str | None = None,
autoload: bool = True,
**kwargs: Any,
**_: Any,
) -> sa.Table:
default_db, default_schema = self.con.url.database.split("/", 1)
current_db = self.current_database
current_schema = self.current_schema
if schema is None:
schema = default_schema
schema = current_schema
*db, schema = schema.split(".")
db = "".join(db) or database or default_db
db = "".join(db) or database or current_db
ident = ".".join(map(self._quote, filter(None, (db, schema))))
try:
result = super()._get_sqla_table(
name, schema=schema, autoload=autoload, database=db, **kwargs
)
except sa.exc.NoSuchTableError:
raise sa.exc.NoSuchTableError(name)

pairs = self._metadata(f"SELECT * FROM {ident}.{self._quote(name)}")
ibis_schema = ibis.schema(pairs)

with self._use_schema(ident):
result = self._table_from_schema(name, schema=ibis_schema)
result.schema = ident
return result
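The resolution above leans on extended iterable unpacking: a dotted `schema` argument such as "db.schema" carries its own database, which takes priority over the `database` parameter and the connection's current database. A self-contained sketch of just that parsing, with "IBIS_TESTING" standing in as a hypothetical current database:

def split_ident(schema, database=None, current_db="IBIS_TESTING"):
    # "db.schema" -> ("db", "schema"); a bare "schema" falls back to the
    # explicit database argument, then to the current database
    *db, schema = schema.split(".")
    return "".join(db) or database or current_db, schema

assert split_ident("OTHER_DB.OTHER_SCHEMA") == ("OTHER_DB", "OTHER_SCHEMA")
assert split_ident("PUBLIC") == ("IBIS_TESTING", "PUBLIC")
assert split_ident("PUBLIC", database="OTHER_DB") == ("OTHER_DB", "PUBLIC")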

@@ -451,10 +462,12 @@ def list_schemas(
def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
import pyarrow.parquet as pq

dialect = self.con.dialect
quote = dialect.preparer(dialect).quote_identifier
raw_name = op.name
table = quote(raw_name)
table = self._quote(raw_name)

current_db = self.current_database
current_schema = self.current_schema
ident = f"{self._quote(current_db)}.{self._quote(current_schema)}.{table}"

with self.begin() as con:
if con.exec_driver_sql(f"SHOW TABLES LIKE '{raw_name}'").scalar() is None:
@@ -491,22 +504,17 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

# 3. create a temporary table
schema = ", ".join(
"{name} {typ}".format(
name=quote(col),
typ=sat.to_instance(SnowflakeType.from_ibis(typ)).compile(
dialect=dialect
),
)
f"{self._quote(col)} {SnowflakeType.to_string(typ) + ' NOT NULL' * (not typ.nullable)}"
for col, typ in op.schema.items()
)
con.exec_driver_sql(f"CREATE TEMP TABLE {table} ({schema})")
con.exec_driver_sql(f"CREATE TEMP TABLE {ident} ({schema})")
# 4. copy the data into the table
columns = op.schema.names
column_names = ", ".join(map(quote, columns))
column_names = ", ".join(map(self._quote, columns))
parquet_column_names = ", ".join(f"$1:{col}" for col in columns)
con.exec_driver_sql(
f"""
COPY INTO {table} ({column_names})
COPY INTO {ident} ({column_names})
FROM (SELECT {parquet_column_names} FROM @{stage})
FILE_FORMAT = (TYPE = PARQUET COMPRESSION = AUTO)
PURGE = TRUE
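A small aside on the column-definition expression in step 3 above: multiplying a string by a bool repeats it zero or one times, so " NOT NULL" is appended exactly when the column is non-nullable. A minimal illustration:

for nullable in (True, False):
    print("x BIGINT" + " NOT NULL" * (not nullable))
# x BIGINT
# x BIGINT NOT NULL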
@@ -576,13 +584,118 @@ def drop_schema(
with self.begin() as con:
con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}")

def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
comment: str | None = None,
) -> ir.Table:
"""Create a table in Snowflake.

Parameters
----------
name
Name of the table to create
obj
The data with which to populate the table; optional, but at least
one of `obj` or `schema` must be specified
schema
The schema of the table to create; optional, but at least one of
`obj` or `schema` must be specified
database
The name of the database in which to create the table; if not
passed, the current database is used.
temp
Create a temporary table
overwrite
If `True`, replace the table if it already exists, otherwise fail
if the table exists
comment
Add a comment to the table
"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")

create_stmt = "CREATE"

if overwrite:
create_stmt += " OR REPLACE"

if temp:
create_stmt += " TEMPORARY"

ident = self._quote(name)
create_stmt += f" TABLE {ident}"

if schema is not None:
schema_sql = ", ".join(
f"{name} {SnowflakeType.to_string(typ) + ' NOT NULL' * (not typ.nullable)}"
for name, typ in zip(map(self._quote, schema.keys()), schema.values())
)
create_stmt += f" ({schema_sql})"

if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
else:
table = obj

self._run_pre_execute_hooks(table)

query = self.compile(table).compile(
dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True)
)
create_stmt += f" AS {query}"

if comment is not None:
create_stmt += f" COMMENT '{comment}'"

with self.begin() as con:
con.exec_driver_sql(create_stmt)

return self.table(name, schema=database)
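A hedged usage sketch of the new `create_table`, assuming `con` is a Snowflake connection as in the earlier sketch (table names are placeholders; the "!" prefix in an ibis schema string marks a column non-nullable):

import ibis

sch = ibis.schema(dict(a="int64", b="!string"))

# empty, replaceable temporary table from an explicit schema
t = con.create_table("tmp_scratch", schema=sch, temp=True, overwrite=True)

# table populated from an expression, with a COMMENT attached
t2 = con.create_table("tmp_copy", obj=t.limit(0), comment="created by ibis")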

def drop_table(
self, name: str, database: str | None = None, force: bool = False
) -> None:
name = self._quote(name)
# TODO: handle database quoting
if database is not None:
name = f"{database}.{name}"
drop_stmt = "DROP TABLE" + (" IF EXISTS" * force) + f" {name}"
with self.begin() as con:
con.exec_driver_sql(drop_stmt)
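For illustration, with `force=True` and `database="OTHER_DB"` the statement assembled above would be roughly the following (per the TODO, the database part is interpolated unquoted while the table name is quoted):

DROP TABLE IF EXISTS OTHER_DB."tmp_scratch"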


@compiles(sa.sql.Join, "snowflake")
def compile_join(element, compiler, **kw):
"""Override compilation of LATERAL joins.

Snowflake doesn't support lateral joins with ON clauses as of
https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
even if they are trivial boolean literals.
"""
result = compiler.visit_join(element, **kw)

# snowflake doesn't support lateral joins with ON clauses as of
# https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057
if element.right._is_lateral:
return re.sub(r"^(.+) ON true$", r"\1", result, flags=re.IGNORECASE)
return result
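A minimal reproduction of the substitution above, on a hypothetical lateral-join fragment:

import re

sql = "SELECT * FROM t, LATERAL FLATTEN(input => t.arr) ON true"
print(re.sub(r"^(.+) ON true$", r"\1", sql, flags=re.IGNORECASE))
# SELECT * FROM t, LATERAL FLATTEN(input => t.arr)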


@compiles(sa.Table, "snowflake")
def compile_table(element, compiler, **kw):
"""Override compilation of leaf tables.

The override is necessary because snowflake-sqlalchemy does not handle
quoting databases and schemas correctly.
"""
schema = element.schema
name = compiler.preparer.quote_identifier(element.name)
if schema is not None:
return f"{schema}.{name}"
return name
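Because `_get_sqla_table` above stores an already-quoted identifier in `element.schema`, this override emits the schema verbatim and quotes only the table name. Conceptually (values are illustrative; `preparer` is the dialect's identifier preparer):

schema = '"OTHER_DB"."OTHER_SCHEMA"'   # pre-quoted by _get_sqla_table
name = preparer.quote_identifier("t")  # -> '"t"'
f"{schema}.{name}"                     # -> '"OTHER_DB"."OTHER_SCHEMA"."t"'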
2 changes: 1 addition & 1 deletion ibis/backends/snowflake/datatypes.py
@@ -26,7 +26,7 @@ def compiles_nulltype(element, compiler, **kw):
"DATE": dt.date,
"TIMESTAMP": dt.timestamp,
"VARIANT": dt.json,
"TIMESTAMP_LTZ": dt.timestamp,
"TIMESTAMP_LTZ": dt.Timestamp("UTC"),
"TIMESTAMP_TZ": dt.Timestamp("UTC"),
"TIMESTAMP_NTZ": dt.timestamp,
"OBJECT": dt.Map(dt.string, dt.json),
22 changes: 17 additions & 5 deletions ibis/backends/snowflake/tests/test_client.py
@@ -14,23 +14,35 @@
@pytest.fixture
def temp_db(con):
db = gen_name("tmp_db")
con.raw_sql(f"CREATE DATABASE {db}")

con.create_database(db)
assert db in con.list_databases()

yield db
con.raw_sql(f"DROP DATABASE {db}")

con.drop_database(db)
assert db not in con.list_databases()


@pytest.fixture
def temp_schema(con, temp_db):
schema = gen_name("tmp_schema")
con.raw_sql(f"CREATE SCHEMA {temp_db}.{schema}")

con.create_schema(schema, database=temp_db)
assert schema in con.list_schemas(database=temp_db)

yield schema
con.raw_sql(f"DROP SCHEMA {temp_db}.{schema}")

con.drop_schema(schema, database=temp_db)
assert schema not in con.list_schemas(database=temp_db)


def test_cross_db_access(con, temp_db, temp_schema):
table = gen_name("tmp_table")
con.raw_sql(f'CREATE TABLE {temp_db}.{temp_schema}."{table}" ("x" INT)')
with con.begin() as c:
c.exec_driver_sql(
f'CREATE TABLE "{temp_db}"."{temp_schema}"."{table}" ("x" INT)'
)
t = con.table(table, schema=f"{temp_db}.{temp_schema}")
assert t.schema() == ibis.schema(dict(x="int"))
assert t.execute().empty
2 changes: 1 addition & 1 deletion ibis/backends/snowflake/tests/test_datatypes.py
@@ -16,7 +16,7 @@
("DATE", dt.date),
("TIMESTAMP", dt.Timestamp(scale=9)),
("VARIANT", dt.json),
("TIMESTAMP_LTZ", dt.Timestamp(scale=9)),
("TIMESTAMP_LTZ", dt.Timestamp(timezone="UTC", scale=9)),
("TIMESTAMP_TZ", dt.Timestamp(timezone="UTC", scale=9)),
("TIMESTAMP_NTZ", dt.Timestamp(scale=9)),
("OBJECT", dt.Map(dt.string, dt.json)),
22 changes: 12 additions & 10 deletions ibis/backends/tests/test_client.py
@@ -182,10 +182,13 @@ def test_sql(backend, con):


backend_type_mapping = {
# backends only implement int64
"bigquery": {
# backend only implements int64
dt.int32: dt.int64
}
dt.int32: dt.int64,
},
"snowflake": {
dt.int32: dt.int64,
},
}


@@ -293,11 +296,12 @@ def test_nullable_input_output(con, temp_table):
def test_create_drop_view(ddl_con, temp_view):
# setup
table_name = "functional_alltypes"
try:
expr = ddl_con.table(table_name)
except (KeyError, sa.exc.NoSuchTableError):
table_name = table_name.upper()
tables = ddl_con.list_tables()

if table_name in tables or (table_name := table_name.upper()) in tables:
expr = ddl_con.table(table_name)
else:
raise ValueError(f"table `{table_name}` does not exist")

expr = expr.limit(1)

@@ -541,7 +545,7 @@ def test_in_memory(alchemy_backend, alchemy_temp_table):
@pytest.mark.notyet(
["mssql", "mysql", "postgres", "snowflake", "sqlite", "trino"],
raises=TypeError,
reason="postgres, mysql and sqlite do not support unsigned integer types",
reason="backend does not support unsigned integer types",
)
def test_unsigned_integer_type(alchemy_con, alchemy_temp_table):
alchemy_con.create_table(
@@ -830,8 +834,6 @@ def test_agg_memory_table(con):
)
@pytest.mark.notimpl(["dask", "datafusion", "druid"])
def test_create_from_in_memory_table(backend, con, t, temp_table):
if backend.name() == "snowflake":
pytest.skip("snowflake is unreliable here")
con.create_table(temp_table, t)
assert temp_table in con.list_tables()

2 changes: 1 addition & 1 deletion ibis/tests/benchmarks/test_benchmarks.py
@@ -738,7 +738,7 @@ def test_snowflake_medium_sized_to_pandas(benchmark):
# LINEITEM at scale factor 1 is around 6MM rows, but we limit to 1,000,000
# to make the benchmark fast enough for development, yet large enough to show a
# difference if there's a performance hit
lineitem = con.table("LINEITEM", schema="snowflake_sample_data.tpch_sf1").limit(
lineitem = con.table("LINEITEM", schema="SNOWFLAKE_SAMPLE_DATA.TPCH_SF1").limit(
1_000_000
)

