diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index bbfb5099816b..371affac47fc 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -839,7 +839,7 @@ def _finalize_cached_table(self, name: str) -> None: raise def _create_cached_table(self, name: str, expr: ir.Table) -> ir.Table: - return self.create_table(name, expr, temp=True) + return self.create_table(name, expr, schema=expr.schema(), temp=True) def _drop_cached_table(self, name: str) -> None: self.drop_table(name, force=True) diff --git a/ibis/backends/mssql/__init__.py b/ibis/backends/mssql/__init__.py index 9367adcd30e7..ef59caad3dec 100644 --- a/ibis/backends/mssql/__init__.py +++ b/ibis/backends/mssql/__init__.py @@ -692,15 +692,18 @@ def create_table( new = raw_this.sql(self.dialect) cur.execute(f"EXEC sp_rename '{old}', '{new}'") + if temp: + # If a temporary table, amend the output name/catalog/db accordingly + name = "##" + name + catalog = "tempdb" + db = "dbo" + if schema is None: # Clean up temporary memtable if we've created one # for in-memory reads if temp_memtable_view is not None: self.drop_table(temp_memtable_view) - return self.table( - "##" * temp + name, - database=("tempdb" * temp or catalog, "dbo" * temp or db), - ) + return self.table(name, database=(catalog, db)) # preserve the input schema if it was provided return ops.DatabaseTable( diff --git a/ibis/backends/mssql/tests/test_client.py b/ibis/backends/mssql/tests/test_client.py index 288eb3964e78..ec33cd3c16dc 100644 --- a/ibis/backends/mssql/tests/test_client.py +++ b/ibis/backends/mssql/tests/test_client.py @@ -216,6 +216,19 @@ def test_create_temp_table_from_obj(con): con.drop_table("fuhreal") +@pytest.mark.parametrize("explicit_schema", [False, True]) +def test_create_temp_table_from_expression(con, explicit_schema, temp_table): + t = ibis.memtable( + {"x": [1, 2, 3], "y": ["a", "b", "c"]}, schema={"x": "int64", "y": "str"} + ) + t2 = con.create_table( + temp_table, t, temp=True, schema=t.schema() if explicit_schema else None + ) + res = con.to_pandas(t.order_by("y")) + sol = con.to_pandas(t2.order_by("y")) + assert res.equals(sol) + + def test_from_url(): user = MSSQL_USER password = MSSQL_PASS diff --git a/ibis/backends/tests/test_expr_caching.py b/ibis/backends/tests/test_expr_caching.py index fb78484da0f6..75aba279553e 100644 --- a/ibis/backends/tests/test_expr_caching.py +++ b/ibis/backends/tests/test_expr_caching.py @@ -11,10 +11,6 @@ @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) -@mark.never( - ["mssql"], - reason="mssql supports support temporary tables through naming conventions", -) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], @@ -27,15 +23,12 @@ def test_persist_expression(backend, alltypes): ) persisted_table = non_persisted_table.cache() backend.assert_frame_equal( - non_persisted_table.to_pandas(), persisted_table.to_pandas() + non_persisted_table.order_by("id").to_pandas(), + persisted_table.order_by("id").to_pandas(), ) @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) -@mark.never( - ["mssql"], - reason="mssql supports support temporary tables through naming conventions", -) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"], @@ -48,16 +41,13 @@ def test_persist_expression_contextmanager(backend, con, alltypes): ) with non_cached_table.cache() as cached_table: backend.assert_frame_equal( - non_cached_table.to_pandas(), cached_table.to_pandas() + non_cached_table.order_by("id").to_pandas(), + cached_table.order_by("id").to_pandas(), ) assert non_cached_table.op() not in con._cache_op_to_entry @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) -@mark.never( - ["mssql"], - reason="mssql supports support temporary tables through naming conventions", -) @pytest.mark.never( ["risingwave"], raises=com.UnsupportedOperationError, @@ -81,7 +71,10 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): op = non_cached_table.op() cached_table = non_cached_table.cache() - backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas()) + backend.assert_frame_equal( + non_cached_table.order_by("id").to_pandas(), + cached_table.order_by("id").to_pandas(), + ) name = cached_table.op().name nested_cached_table = non_cached_table.cache() @@ -104,10 +97,6 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): @mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"]) -@mark.never( - ["mssql"], - reason="mssql supports support temporary tables through naming conventions", -) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @pytest.mark.never( ["risingwave"],