fix(snowflake): ensure that temp tables are only created once
cpcloud committed Apr 25, 2023
1 parent 8191529 commit 43b8152
Showing 2 changed files with 64 additions and 47 deletions.
ibis/backends/snowflake/__init__.py: 95 changes (48 additions, 47 deletions)
@@ -365,62 +365,63 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
 
         from ibis.backends.snowflake.datatypes import to_sqla_type
 
-        t = op.data.to_pyarrow(schema=op.schema)
         dialect = self.con.dialect
         quote = dialect.preparer(dialect).quote_identifier
-        table = quote(op.name)
-        stage = util.gen_name("stage")
+        raw_name = op.name
+        table = quote(raw_name)
 
         with self.begin() as con:
-            # 1. create a temporary stage for holding parquet files
-            con.exec_driver_sql(f"CREATE TEMP STAGE {stage}")
-
-            with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
-                path = os.path.join(tmpdir, f"{op.name}.parquet")
-                # optimize for bandwidth so use zstd which typically compresses
-                # better than the other options without much loss in speed
-                pq.write_table(t, path, compression="zstd")
-
-                # 2. copy the parquet file into the stage
-                #
-                # disable the automatic compression to gzip because we've
-                # already compressed the data with zstd
-                #
-                # 99 is the limit on the number of threads used to upload data,
-                # who knows why?
-                con.exec_driver_sql(
-                    f"""
-                    PUT 'file://{path}' @{stage}
-                    PARALLEL={min((os.cpu_count() or 2) // 2, 99)}
-                    AUTO_COMPRESS=FALSE
-                    """
-                )
-
-            # 3. create a temporary table
-            schema = ", ".join(
-                "{name} {typ}".format(
-                    name=quote(col),
-                    typ=sa.types.to_instance(to_sqla_type(dialect, typ)).compile(
-                        dialect=dialect
-                    ),
-                )
-                for col, typ in op.schema.items()
-            )
-            con.exec_driver_sql(f"CREATE TEMP TABLE {table} ({schema})")
-
-            # 4. copy the data into the table
-            columns = op.schema.names
-            column_names = ", ".join(map(quote, columns))
-            parquet_column_names = ", ".join(f"$1:{col}" for col in columns)
-            con.exec_driver_sql(
-                f"""
-                COPY INTO {table} ({column_names})
-                FROM (SELECT {parquet_column_names} FROM @{stage})
-                FILE_FORMAT=(TYPE=PARQUET COMPRESSION=AUTO)
-                PURGE=TRUE
-                """
-            )
+            if con.exec_driver_sql(f"SHOW TABLES LIKE '{raw_name}'").scalar() is None:
+                # 1. create a temporary stage for holding parquet files
+                stage = util.gen_name("stage")
+                con.exec_driver_sql(f"CREATE TEMP STAGE {stage}")
+
+                with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
+                    t = op.data.to_pyarrow(schema=op.schema)
+                    path = os.path.join(tmpdir, f"{raw_name}.parquet")
+                    # optimize for bandwidth so use zstd which typically compresses
+                    # better than the other options without much loss in speed
+                    pq.write_table(t, path, compression="zstd")
+
+                    # 2. copy the parquet file into the stage
+                    #
+                    # disable the automatic compression to gzip because we've
+                    # already compressed the data with zstd
+                    #
+                    # 99 is the limit on the number of threads used to upload data,
+                    # who knows why?
+                    con.exec_driver_sql(
+                        f"""
+                        PUT 'file://{path}' @{stage}
+                        PARALLEL = {min((os.cpu_count() or 2) // 2, 99)}
+                        AUTO_COMPRESS = FALSE
+                        """
+                    )
+
+                # 3. create a temporary table
+                schema = ", ".join(
+                    "{name} {typ}".format(
+                        name=quote(col),
+                        typ=sa.types.to_instance(to_sqla_type(dialect, typ)).compile(
+                            dialect=dialect
+                        ),
+                    )
+                    for col, typ in op.schema.items()
+                )
+                con.exec_driver_sql(f"CREATE TEMP TABLE {table} ({schema})")
+
+                # 4. copy the data into the table
+                columns = op.schema.names
+                column_names = ", ".join(map(quote, columns))
+                parquet_column_names = ", ".join(f"$1:{col}" for col in columns)
+                con.exec_driver_sql(
+                    f"""
+                    COPY INTO {table} ({column_names})
+                    FROM (SELECT {parquet_column_names} FROM @{stage})
+                    FILE_FORMAT = (TYPE = PARQUET COMPRESSION = AUTO)
+                    PURGE = TRUE
+                    """
+                )
 
     def _get_temp_view_definition(
         self, name: str, definition: sa.sql.compiler.Compiled
     ) -> str:
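The heart of the change is the SHOW TABLES LIKE '{raw_name}' probe: when the temp table already exists in the session, registration becomes a no-op, so executing the same memtable twice no longer recreates the stage, re-uploads the parquet file, or re-copies the data. Here is a minimal, self-contained sketch of that check-before-create pattern, using sqlite3 as a stand-in for Snowflake; the helper name register_once is invented, a sqlite_master lookup plays the role of SHOW TABLES, and a plain table replaces the stage/PUT/COPY pipeline.

import sqlite3


def register_once(con: sqlite3.Connection, name: str, rows: list[tuple]) -> None:
    # stand-in for `SHOW TABLES LIKE '{raw_name}'`: probe the catalog and
    # bail out if this table was already registered in the session
    exists = con.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?",
        (name,),
    ).fetchone()
    if exists is not None:
        return
    # stand-in for steps 1-4 (stage, PUT, CREATE TEMP TABLE, COPY INTO)
    con.execute(f'CREATE TABLE "{name}" (key TEXT, value INTEGER)')
    con.executemany(f'INSERT INTO "{name}" VALUES (?, ?)', rows)


con = sqlite3.connect(":memory:")
for _ in range(2):  # the second call must be a harmless no-op
    register_once(con, "t", [("a", 1), ("b", 2)])
assert con.execute('SELECT count(*) FROM "t"').fetchone()[0] == 2

The guard is also why stage = util.gen_name("stage") and the pyarrow conversion moved inside the if block: none of the upload work should run on the already-registered path.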
ibis/backends/snowflake/tests/test_client.py: 16 changes (16 additions, 0 deletions)
@@ -55,3 +55,19 @@ def test_basic_memtable_registration(simple_con, data):
     t = ibis.memtable(data)
     result = simple_con.execute(t)
     tm.assert_frame_equal(result, expected)
+
+
+def test_repeated_memtable_registration(simple_con, mocker):
+    data = {"key": list("abc"), "value": [[1], [2], [3]]}
+    expected = pd.DataFrame(data)
+    t = ibis.memtable(data)
+
+    spy = mocker.spy(simple_con, "_register_in_memory_table")
+
+    n = 2
+
+    for _ in range(n):
+        tm.assert_frame_equal(simple_con.execute(t), expected)
+
+    # assert that we called _register_in_memory_table exactly n times
+    assert spy.call_count == n
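The test relies on pytest-mock's mocker.spy, which wraps a method so that calls pass through to the real implementation while being recorded. For readers unfamiliar with that fixture, a rough standard-library equivalent is sketched below; the Backend class here is invented for illustration.

from unittest import mock


class Backend:
    def _register_in_memory_table(self, op):
        return f"registered {op}"  # illustrative stand-in for the real work


backend = Backend()
with mock.patch.object(
    Backend,
    "_register_in_memory_table",
    side_effect=Backend._register_in_memory_table,
    autospec=True,
) as spy:
    backend._register_in_memory_table("t")
    backend._register_in_memory_table("t")

# calls passed through to the original method and were counted, like mocker.spy
assert spy.call_count == 2

Note what the real test asserts: _register_in_memory_table is still invoked once per execute; it is the SHOW TABLES guard inside it that makes the repeated invocations harmless.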
