diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py
index d62b98de5178..05fa65d7827f 100644
--- a/ibis/backends/snowflake/__init__.py
+++ b/ibis/backends/snowflake/__init__.py
@@ -365,62 +365,63 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
         from ibis.backends.snowflake.datatypes import to_sqla_type
 
-        t = op.data.to_pyarrow(schema=op.schema)
         dialect = self.con.dialect
         quote = dialect.preparer(dialect).quote_identifier
-        table = quote(op.name)
-        stage = util.gen_name("stage")
+        raw_name = op.name
+        table = quote(raw_name)
 
         with self.begin() as con:
-            # 1. create a temporary stage for holding parquet files
-            con.exec_driver_sql(f"CREATE TEMP STAGE {stage}")
-
-            with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
-                path = os.path.join(tmpdir, f"{op.name}.parquet")
-                # optimize for bandwidth so use zstd which typically compresses
-                # better than the other options without much loss in speed
-                pq.write_table(t, path, compression="zstd")
-
-                # 2. copy the parquet file into the stage
-                #
-                # disable the automatic compression to gzip because we've
-                # already compressed the data with zstd
-                #
-                # 99 is the limit on the number of threads use to upload data,
-                # who knows why?
+            if con.exec_driver_sql(f"SHOW TABLES LIKE '{raw_name}'").scalar() is None:
+                # 1. create a temporary stage for holding parquet files
+                stage = util.gen_name("stage")
+                con.exec_driver_sql(f"CREATE TEMP STAGE {stage}")
+
+                with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
+                    t = op.data.to_pyarrow(schema=op.schema)
+                    path = os.path.join(tmpdir, f"{raw_name}.parquet")
+                    # optimize for bandwidth so use zstd which typically compresses
+                    # better than the other options without much loss in speed
+                    pq.write_table(t, path, compression="zstd")
+
+                    # 2. copy the parquet file into the stage
+                    #
+                    # disable the automatic compression to gzip because we've
+                    # already compressed the data with zstd
+                    #
+                    # 99 is the limit on the number of threads used to upload data,
+                    # who knows why?
+                    con.exec_driver_sql(
+                        f"""
+                        PUT 'file://{path}' @{stage}
+                        PARALLEL = {min((os.cpu_count() or 2) // 2, 99)}
+                        AUTO_COMPRESS = FALSE
+                        """
+                    )
+
+                # 3. create a temporary table
+                schema = ", ".join(
+                    "{name} {typ}".format(
+                        name=quote(col),
+                        typ=sa.types.to_instance(to_sqla_type(dialect, typ)).compile(
+                            dialect=dialect
+                        ),
+                    )
+                    for col, typ in op.schema.items()
+                )
+                con.exec_driver_sql(f"CREATE TEMP TABLE {table} ({schema})")
+                # 4. copy the data into the table
+                columns = op.schema.names
+                column_names = ", ".join(map(quote, columns))
+                parquet_column_names = ", ".join(f"$1:{col}" for col in columns)
                 con.exec_driver_sql(
                     f"""
-                    PUT 'file://{path}' @{stage}
-                    PARALLEL={min((os.cpu_count() or 2) // 2, 99)}
-                    AUTO_COMPRESS=FALSE
+                    COPY INTO {table} ({column_names})
+                    FROM (SELECT {parquet_column_names} FROM @{stage})
+                    FILE_FORMAT = (TYPE = PARQUET COMPRESSION = AUTO)
+                    PURGE = TRUE
                     """
                 )
 
-            # 3. create a temporary table
-            schema = ", ".join(
-                "{name} {typ}".format(
-                    name=quote(col),
-                    typ=sa.types.to_instance(to_sqla_type(dialect, typ)).compile(
-                        dialect=dialect
-                    ),
-                )
-                for col, typ in op.schema.items()
-            )
-            con.exec_driver_sql(f"CREATE TEMP TABLE {table} ({schema})")
-
-            # 4. copy the data into the table
-            columns = op.schema.names
-            column_names = ", ".join(map(quote, columns))
-            parquet_column_names = ", ".join(f"$1:{col}" for col in columns)
-            con.exec_driver_sql(
-                f"""
-                COPY INTO {table} ({column_names})
-                FROM (SELECT {parquet_column_names} FROM @{stage})
-                FILE_FORMAT=(TYPE=PARQUET COMPRESSION=AUTO)
-                PURGE=TRUE
-                """
-            )
-
     def _get_temp_view_definition(
         self, name: str, definition: sa.sql.compiler.Compiled
     ) -> str:
diff --git a/ibis/backends/snowflake/tests/test_client.py b/ibis/backends/snowflake/tests/test_client.py
index bfa495af4f26..c11da4465ffe 100644
--- a/ibis/backends/snowflake/tests/test_client.py
+++ b/ibis/backends/snowflake/tests/test_client.py
@@ -55,3 +55,19 @@ def test_basic_memtable_registration(simple_con, data):
     t = ibis.memtable(data)
     result = simple_con.execute(t)
     tm.assert_frame_equal(result, expected)
+
+
+def test_repeated_memtable_registration(simple_con, mocker):
+    data = {"key": list("abc"), "value": [[1], [2], [3]]}
+    expected = pd.DataFrame(data)
+    t = ibis.memtable(data)
+
+    spy = mocker.spy(simple_con, "_register_in_memory_table")
+
+    n = 2
+
+    for _ in range(n):
+        tm.assert_frame_equal(simple_con.execute(t), expected)
+
+    # assert that we called _register_in_memory_table exactly n times
+    assert spy.call_count == n
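The backend change boils down to a check-then-create pattern: probe the catalog for the table name and only run the stage/PUT/COPY pipeline when nothing matches, so repeated executions of the same memtable don't re-upload the data. In the diff, `SHOW TABLES LIKE` yields no rows when the table is unregistered, so SQLAlchemy's `.scalar()` returns `None` and the whole load path runs; otherwise it is skipped. Below is a minimal, runnable sketch of that pattern using sqlite3's catalog as a stand-in for Snowflake's `SHOW TABLES`; the `register_once` helper, the table name `my_memtable`, and its schema are hypothetical, not ibis code.

```python
import sqlite3


def register_once(con: sqlite3.Connection, name: str) -> None:
    # Probe the catalog first, mirroring the `SHOW TABLES ... .scalar() is None`
    # guard: only pay for the create-and-load path when the table is missing.
    exists = con.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?",
        (name,),
    ).fetchone()
    if exists is None:
        con.execute(f'CREATE TABLE "{name}" (key TEXT, value INTEGER)')
        con.executemany(
            f'INSERT INTO "{name}" VALUES (?, ?)',
            [("a", 1), ("b", 2), ("c", 3)],
        )


con = sqlite3.connect(":memory:")
register_once(con, "my_memtable")
register_once(con, "my_memtable")  # second call finds the table and is a no-op
assert con.execute('SELECT COUNT(*) FROM "my_memtable"').fetchone()[0] == 3
```

Note that the diff also moves `op.data.to_pyarrow(...)` inside the guard, so the pyarrow conversion itself is only paid on first registration, not just the upload.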
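The new test leans on pytest-mock's `mocker.spy`, which wraps a method so it still executes normally while recording its calls. A self-contained sketch of that mechanism follows, with a stub class standing in for the backend; `FakeBackend` and its `_registered` set are hypothetical, illustrating only why `call_count` equals the number of `execute` calls even though real work happens once.

```python
class FakeBackend:
    def __init__(self):
        self._registered = set()

    def _register_in_memory_table(self, name):
        # idempotent: the guard lives inside the method, as in the diff
        self._registered.add(name)

    def execute(self, name):
        # registration is attempted on every execute
        self._register_in_memory_table(name)
        return name in self._registered


def test_spy_counts_registrations(mocker):
    backend = FakeBackend()
    spy = mocker.spy(backend, "_register_in_memory_table")

    n = 2
    for _ in range(n):
        assert backend.execute("t")

    # the spy delegates to the real method, so behavior is unchanged
    # while every invocation is counted
    assert spy.call_count == n
```

This is why the real test asserts `spy.call_count == n` rather than `1`: `_register_in_memory_table` is still invoked per execution; it is the expensive stage/PUT/COPY work inside it that now runs only once.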